diff --git a/examples/model_cli_demo.py b/examples/model_cli_demo.py deleted file mode 100644 index 0cdd60fe..00000000 --- a/examples/model_cli_demo.py +++ /dev/null @@ -1,63 +0,0 @@ -from argparse import ArgumentParser - -from lagent.llms import HFTransformer -from lagent.llms.meta_template import INTERNLM2_META as META - - -def parse_args(): - parser = ArgumentParser(description='chatbot') - parser.add_argument( - '--path', - type=str, - default='internlm/internlm2-chat-20b', - help='The path to the model') - parser.add_argument( - '--mode', - type=str, - default='chat', - help='Completion through chat or generate') - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - # Initialize the HFTransformer-based Language Model (llm) - model = HFTransformer( - path=args.path, - meta_template=META, - max_new_tokens=1024, - top_p=0.8, - top_k=None, - temperature=0.1, - repetition_penalty=1.0, - stop_words=['<|im_end|>']) - - def input_prompt(): - print('\ndouble enter to end input >>> ', end='', flush=True) - sentinel = '' # ends when this string is seen - return '\n'.join(iter(input, sentinel)) - - history = [] - while True: - try: - prompt = input_prompt() - except UnicodeDecodeError: - print('UnicodeDecodeError') - continue - if prompt == 'exit': - exit(0) - history.append(dict(role='user', content=prompt)) - if args.mode == 'generate': - history = [dict(role='user', content=prompt)] - print('\nInternLm2:', end='') - current_length = 0 - for status, response, _ in model.stream_chat(history): - print(response[current_length:], end='', flush=True) - current_length = len(response) - history.append(dict(role='assistant', content=response)) - print('') - - -if __name__ == '__main__': - main() diff --git a/examples/run_agent_lmdeploy.py b/examples/run_agent_lmdeploy.py deleted file mode 100644 index 21b80592..00000000 --- a/examples/run_agent_lmdeploy.py +++ /dev/null @@ -1,177 +0,0 @@ -import json -import time - -from lagent.agents.aggregator import InternLMToolAggregator -from lagent.agents.stream import PLUGIN_CN, AgentForInternLM, MathCoder, get_plugin_prompt -from lagent.llms import INTERNLM2_META, LMDeployPipeline -from lagent.prompts.parsers import PluginParser - -model = LMDeployPipeline( - path='internlm/internlm2_5-7b-chat', - meta_template=INTERNLM2_META, - tp=1, - top_k=1, - temperature=1.0, - stop_words=['<|im_end|>', '<|action_end|>'], - max_new_tokens=1024, -) - -# ----------------------- interpreter ----------------------- -print('-' * 80, 'interpreter', '-' * 80) -coder = MathCoder( - llm=model, - aggregator=InternLMToolAggregator(few_shot=[[ - { - 'role': - 'user', - 'content': - "Mike paid $\\$1.25$ for a stamp three years ago. He was just offered double that amount for the stamp. Assuming the stamp's offer price doubles every three years, how many dollars will he be offered in 12 more years?" - }, - { - 'role': - 'language', - 'content': - "Let's break down the problem into smaller steps to understand how to approach it.\n\n**Step 1:** Determine the initial amount paid for the stamp, which is $1.25.\n\n**Step 2:** Understand the pattern of the offer price increase. The offer price doubles every three years.\n\n**Step 3:** Calculate the offer price in 12 more years.\n\nSince the offer price doubles every three years, in 12 years, the price will double 12 / 3 = 4 times.\n\nWe can represent the offer price after 12 more years as:\n\\[ \\text{Offer Price} = \\text{Initial Price} \\times 2^{(\\text{Number of Doublings})} \\]\n\nLet's calculate this using Python to get the exact figure.\n\n" - }, - { - 'role': 'tool', - 'content': - '```python\nfrom sympy import symbols, Eq, solve\n\n# Define the variables\ninitial_price = 1.25\ndoublings_in_12_years = 12 / 3\n\n# Calculate the offer price after 12 years\noffer_price_12_years = initial_price * 2**doublings_in_12_years\n\noffer_price_12_years\n```', - 'name': 'interpreter', - }, - { - 'role': 'environment', - 'content': '20.0', - 'name': 'interpreter' - }, - { - 'role': - 'language', - 'content': - 'After calculating the given formula, the offer price for the stamp in 12 more years is found to be $\\boxed{20.0}$ dollars.' - }, - ]])) -query = ( - 'Marie is thinking of a multiple of 63, while Jay is thinking of a factor ' - 'of 63. They happen to be thinking of the same number. There are two ' - 'possibilities for the number that each of them is thinking of, one ' - 'positive and one negative. Find the product of these two numbers.') -res = coder(query, session_id=0) -# print(res.model_dump_json()) -# print('-' * 120) -print(coder.get_steps(0)) -import json - -print( - json.dumps( - coder.agent.aggregator.aggregate(coder.agent.memory, coder.agent.name, - coder.agent.output_format), - ensure_ascii=False, - indent=4)) - -state_dict = { - 'memory': [{ - 'sender': 'user', - 'content': - 'Marie is thinking of a multiple of 63, while Jay is thinking of a factor of 63. They happen to be thinking of the same number. There are two possibilities for the number that each of them is thinking of, one positive and one negative. Find the product of these two numbers.', - 'formatted': None, - 'type': None, - 'receiver': None - }, { - 'sender': 'Agent', - 'content': - 'The product of the two numbers, \\(63\\) and \\(-63\\), is \\(-3969\\). Therefore, the product of the two numbers that Marie and Jay are thinking of is \\(\\boxed{-3969}\\).', - 'formatted': { - 'tool_type': None, - 'thought': - 'The product of the two numbers, \\(63\\) and \\(-63\\), is \\(-3969\\). Therefore, the product of the two numbers that Marie and Jay are thinking of is \\(\\boxed{-3969}\\).', - 'action': None, - 'status': 0 - }, - 'type': None, - 'receiver': None - }], - 'agent.memory': [{ - 'sender': 'user', - 'content': - 'Marie is thinking of a multiple of 63, while Jay is thinking of a factor of 63. They happen to be thinking of the same number. There are two possibilities for the number that each of them is thinking of, one positive and one negative. Find the product of these two numbers.', - 'formatted': None, - 'type': None, - 'receiver': None - }, { - 'sender': 'Agent', - 'content': - "To solve this problem, we first need to identify the factors of 63. Since Marie is thinking of a multiple of 63 and Jay is thinking of a factor of 63, and they are thinking of the same number, we need to find the factors of 63 and then determine the multiples of 63 that match these factors.\n\nLet's start by finding the factors of 63 using Python.<|action_start|><|interpreter|>\n```python\nfrom sympy import divisors\n\n# Find the factors of 63\nfactors_of_63 = divisors(63)\nfactors_of_63\n```", - 'formatted': { - 'tool_type': 'interpreter', - 'thought': - "To solve this problem, we first need to identify the factors of 63. Since Marie is thinking of a multiple of 63 and Jay is thinking of a factor of 63, and they are thinking of the same number, we need to find the factors of 63 and then determine the multiples of 63 that match these factors.\n\nLet's start by finding the factors of 63 using Python.", - 'action': - '```python\nfrom sympy import divisors\n\n# Find the factors of 63\nfactors_of_63 = divisors(63)\nfactors_of_63\n```', - 'status': 1 - }, - 'type': None, - 'receiver': None - }, { - 'sender': 'ActionExecutor', - 'content': '[1, 3, 7, 9, 21, 63]', - 'formatted': None, - 'type': None, - 'receiver': None - }, { - 'sender': 'Agent', - 'content': - "The factors of 63 are \\(1, 3, 7, 9, 21,\\) and \\(63\\). Since Marie is thinking of a multiple of 63 and Jay is thinking of a factor of 63, and they are thinking of the same number, the number they are thinking of must be one of the factors of 63.\n\nGiven that there are two possibilities for the number they are thinking of, one positive and one negative, we need to consider both \\(63\\) and \\(-63\\) as the number they might be thinking of.\n\nNow, let's calculate the product of these two numbers, \\(63\\) and \\(-63\\).<|action_start|><|interpreter|>\n```python\n# Calculate the product of 63 and -63\nproduct = 63 * (-63)\nproduct\n```", - 'formatted': { - 'tool_type': 'interpreter', - 'thought': - "The factors of 63 are \\(1, 3, 7, 9, 21,\\) and \\(63\\). Since Marie is thinking of a multiple of 63 and Jay is thinking of a factor of 63, and they are thinking of the same number, the number they are thinking of must be one of the factors of 63.\n\nGiven that there are two possibilities for the number they are thinking of, one positive and one negative, we need to consider both \\(63\\) and \\(-63\\) as the number they might be thinking of.\n\nNow, let's calculate the product of these two numbers, \\(63\\) and \\(-63\\).", - 'action': - '```python\n# Calculate the product of 63 and -63\nproduct = 63 * (-63)\nproduct\n```', - 'status': 1 - }, - 'type': None, - 'receiver': None - }, { - 'sender': 'ActionExecutor', - 'content': '-3969', - 'formatted': None, - 'type': None, - 'receiver': None - }, { - 'sender': 'Agent', - 'content': - 'The product of the two numbers, \\(63\\) and \\(-63\\), is \\(-3969\\). Therefore, the product of the two numbers that Marie and Jay are thinking of is \\(\\boxed{-3969}\\).', - 'formatted': { - 'tool_type': None, - 'thought': - 'The product of the two numbers, \\(63\\) and \\(-63\\), is \\(-3969\\). Therefore, the product of the two numbers that Marie and Jay are thinking of is \\(\\boxed{-3969}\\).', - 'action': None, - 'status': 0 - }, - 'type': None, - 'receiver': None - }] -} -coder.load_state_dict(state_dict) -print(coder.state_dict() == state_dict) -print(coder.state_dict()) - -# ----------------------- plugin ----------------------- -print('-' * 80, 'plugin', '-' * 80) -plugins = [dict(type='lagent.actions.ArxivSearch')] -agent = AgentForInternLM( - llm=model, - plugins=plugins, - output_format=dict( - type=PluginParser, - template=PLUGIN_CN, - prompt=get_plugin_prompt(plugins))) - -tic = time.time() -query = 'LLM智能体方向的最新论文有哪些?' -res = agent(query, session_id=0) -print(res.model_dump_json()) -print('-' * 120) -print(agent.get_steps(0)) -print(f'time elapsed: {time.time() - tic}') diff --git a/examples/run_agent_services.py b/examples/run_agent_services.py deleted file mode 100644 index 65fd8e9f..00000000 --- a/examples/run_agent_services.py +++ /dev/null @@ -1,110 +0,0 @@ -import asyncio -import time - -from lagent.agents.stream import PLUGIN_CN, get_plugin_prompt -from lagent.distributed import AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer -from lagent.llms import INTERNLM2_META -from lagent.schema import AgentMessage -from lagent.utils import create_object - -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) - -server = HTTPAgentServer( - '1', - { - 'type': 'lagent.agents.AsyncAgent', - 'llm': { - 'type': 'lagent.llms.AsyncLMDeployPipeline', - 'path': 'internlm/internlm2_5-7b-chat', - 'meta_template': INTERNLM2_META, - } - }, - port=8090, -) -print(server.is_alive) -message = AgentMessage(sender='user', content='hello') -result = server(message) -print(result) -server.shutdown() - -# math coder -server = AsyncHTTPAgentServer( - '1', - { - 'type': 'lagent.agents.AsyncMathCoder', - 'llm': { - 'type': 'lagent.llms.AsyncLMDeployPipeline', - 'path': 'internlm/internlm2_5-7b-chat', - 'meta_template': INTERNLM2_META, - 'tp': 1, - 'top_k': 1, - 'temperature': 1.0, - 'stop_words': ['<|im_end|>', '<|action_end|>'], - 'max_new_tokens': 1024, - }, - 'interpreter': { - 'type': 'lagent.actions.AsyncIPythonInterpreter', - 'max_kernels': 100 - }, - }, - port=8091, -) -message = AgentMessage( - sender='user', - content= - ('Marie is thinking of a multiple of 63, while Jay is thinking of a factor ' - 'of 63. They happen to be thinking of the same number. There are two ' - 'possibilities for the number that each of them is thinking of, one ' - 'positive and one negative. Find the product of these two numbers.')) -result = server(message) -print(loop.run_until_complete(result)) -print(server.state_dict()) - -client = AsyncHTTPAgentClient(port=8091) -result = client('hello', session_id=1) -print(loop.run_until_complete(result)) -print(client.state_dict(1)) - -client = HTTPAgentClient(port=8091) -print(client.state_dict(1)) -print(client('introduce yourself', session_id=1)) -print(client.state_dict(1)) -server.shutdown() - -# plugins -plugins = [dict(type='lagent.actions.AsyncArxivSearch')] -server_cfg = dict( - type='lagent.distributed.AsyncHTTPAgentServer', - gpu_id='1', - config={ - 'type': 'lagent.agents.AsyncAgentForInternLM', - 'llm': { - 'type': 'lagent.llms.AsyncLMDeployPipeline', - 'path': 'internlm/internlm2_5-7b-chat', - 'meta_template': INTERNLM2_META, - 'tp': 1, - 'top_k': 1, - 'temperature': 1.0, - 'stop_words': ['<|im_end|>', '<|action_end|>'], - 'max_new_tokens': 1024, - }, - 'plugins': plugins, - 'output_format': { - 'type': 'lagent.prompts.parsers.PluginParser', - 'template': PLUGIN_CN, - 'prompt': get_plugin_prompt(plugins), - } - }, - port=8091, -) -server = create_object(server_cfg) -tic = time.time() -coros = [ - server(query, session_id=i) - for i, query in enumerate(['LLM智能体方向的最新论文有哪些?'] * 50) -] -res = loop.run_until_complete(asyncio.gather(*coros)) -print('-' * 120) -print(f'time elapsed: {time.time() - tic}') -server.shutdown() diff --git a/examples/run_async_agent_lmdeploy.py b/examples/run_async_agent_lmdeploy.py deleted file mode 100644 index dbd2b15f..00000000 --- a/examples/run_async_agent_lmdeploy.py +++ /dev/null @@ -1,70 +0,0 @@ -import asyncio -import json -import time - -from datasets import load_dataset - -from lagent.agents.stream import PLUGIN_CN, AsyncAgentForInternLM, AsyncMathCoder, get_plugin_prompt -from lagent.llms import INTERNLM2_META -from lagent.llms.lmdeploy_wrapper import AsyncLMDeployPipeline -from lagent.prompts.parsers import PluginParser - -# set up the loop -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) -# initialize the model -model = AsyncLMDeployPipeline( - path='internlm/internlm2_5-7b-chat', - meta_template=INTERNLM2_META, - model_name='internlm-chat', - tp=1, - top_k=1, - temperature=1.0, - stop_words=['<|im_end|>', '<|action_end|>'], - max_new_tokens=1024, -) - -# ----------------------- interpreter ----------------------- -print('-' * 80, 'interpreter', '-' * 80) - -ds = load_dataset('lighteval/MATH', split='test') -problems = [item['problem'] for item in ds.select(range(0, 5000, 2))] - -coder = AsyncMathCoder( - llm=model, - interpreter=dict( - type='lagent.actions.AsyncIPythonInterpreter', max_kernels=300), - max_turn=11) -tic = time.time() -coros = [coder(query, session_id=i) for i, query in enumerate(problems)] -res = loop.run_until_complete(asyncio.gather(*coros)) -# print([r.model_dump_json() for r in res]) -print('-' * 120) -print(f'time elapsed: {time.time() - tic}') - -with open('./tmp_1.json', 'w') as f: - json.dump([coder.get_steps(i) for i in range(len(res))], - f, - ensure_ascii=False, - indent=4) - -# ----------------------- plugin ----------------------- -print('-' * 80, 'plugin', '-' * 80) -plugins = [dict(type='lagent.actions.AsyncArxivSearch')] -agent = AsyncAgentForInternLM( - llm=model, - plugins=plugins, - output_format=dict( - type=PluginParser, - template=PLUGIN_CN, - prompt=get_plugin_prompt(plugins))) - -tic = time.time() -coros = [ - agent(query, session_id=i) - for i, query in enumerate(['LLM智能体方向的最新论文有哪些?'] * 50) -] -res = loop.run_until_complete(asyncio.gather(*coros)) -# print([r.model_dump_json() for r in res]) -print('-' * 120) -print(f'time elapsed: {time.time() - tic}') diff --git a/examples/run_async_agent_lmdeploy_server.py b/examples/run_async_agent_lmdeploy_server.py deleted file mode 100644 index a3f78dea..00000000 --- a/examples/run_async_agent_lmdeploy_server.py +++ /dev/null @@ -1,86 +0,0 @@ -import asyncio -import json -import time - -from datasets import load_dataset - -from lagent.agents.stream import AsyncAgentForInternLM, AsyncMathCoder, get_plugin_prompt -from lagent.llms import INTERNLM2_META -from lagent.llms.lmdeploy_wrapper import AsyncLMDeployClient, AsyncLMDeployServer - -# set up the loop -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) -# initialize the model -model = AsyncLMDeployServer( - path='internlm/internlm2_5-7b-chat', - meta_template=INTERNLM2_META, - model_name='internlm-chat', - tp=1, - top_k=1, - temperature=1.0, - stop_words=['<|im_end|>', '<|action_end|>'], - max_new_tokens=1024, -) - -# ----------------------- interpreter ----------------------- -print('-' * 80, 'interpreter', '-' * 80) - -ds = load_dataset('lighteval/MATH', split='test') -problems = [item['problem'] for item in ds.select(range(50))] - - -# coder = AsyncMathCoder( -# llm=model, -# interpreter=dict(type='AsyncIPythonInterpreter', max_kernels=250)) -# tic = time.time() -# coros = [coder(query, session_id=i) for i, query in enumerate(problems)] -# res = loop.run_until_complete(asyncio.gather(*coros)) -# # print([r.model_dump_json() for r in res]) -# print('-' * 120) -# print(f'time elapsed: {time.time() - tic}') -# with open('./tmp_4.json', 'w') as f: -# json.dump([coder.get_steps(i) for i in range(len(res))], -# f, -# ensure_ascii=False, -# indent=4) - -# ----------------------- streaming chat ----------------------- -async def streaming(llm, problem): - async for out in llm.stream_chat([{'role': 'user', 'content': problem}]): - print(out) - - -tic = time.time() -client = AsyncLMDeployClient( - url='http://127.0.0.1:23333', - meta_template=INTERNLM2_META, - model_name='internlm2_5-7b-chat', - top_k=1, - temperature=1.0, - stop_words=['<|im_end|>', '<|action_end|>'], - max_new_tokens=1024, -) -# loop.run_until_complete(streaming(model, problems[0])) -loop.run_until_complete(streaming(client, problems[0])) -print(time.time() - tic) - -# ----------------------- plugin ----------------------- -# print('-' * 80, 'plugin', '-' * 80) -# plugins = [dict(type='AsyncArxivSearch')] -# agent = AsyncAgentForInternLM( -# llm=model, -# plugins=plugins, -# aggregator=dict( -# type='InternLMToolAggregator', -# plugin_prompt=get_plugin_prompt(plugins))) - -# tic = time.time() -# coros = [ -# agent(query, session_id=i) -# for i, query in enumerate(['LLM智能体方向的最新论文有哪些?'] * 50) -# ] -# res = loop.run_until_complete(asyncio.gather(*coros)) -# # print([r.model_dump_json() for r in res]) -# print('-' * 120) -# print(f'time elapsed: {time.time() - tic}') diff --git a/examples/run_async_agent_openai.py b/examples/run_async_agent_openai.py deleted file mode 100644 index a9952f27..00000000 --- a/examples/run_async_agent_openai.py +++ /dev/null @@ -1,61 +0,0 @@ -import asyncio -import json -import time - -from datasets import load_dataset - -from lagent.agents import AsyncMathCoder -from lagent.agents.aggregator import InternLMToolAggregator -from lagent.llms import AsyncGPTAPI -from lagent.prompts.parsers import ToolParser - -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) - -interpreter_prompt = ( - 'Below is a math problem. Please solve it step by step with the assistance of Python programming. Consider using Sympy or Numpy library ' - 'to facilitate your derivation, calculation and equation solving. Utilize the "pi" symbol and "Rational" from Sympy ' - 'for $$\pi$$ and fractions, and simplify all fractions and square roots without converting them to decimal values. ' - 'Please encapsulate each generated Jupyter Python code block with tags "{begin}" and "{end}". Conclude the ' - r'final answer when observations are sufficient and encapsulate the numerical result with LaTeX syntax "\boxed{{}}" ' - 'without any unit, and end your conclusion with the special token "[END]" to denote the completion of your response. ' - 'Keep the following points in mind:\n' - '- You must alternately use human and programming languages in the chain of thought;\n' - '- The number of your reasoning steps should not exceed **three**, which means you may merge some intermediate steps when the original answer is tedious.' -) - -async_llm = dict( - type=AsyncGPTAPI, - model='gpt-4o-2024-05-13', - retry=50, - key='', - max_new_tokens=2048, - stop_words=['', '\n```\n'], - max_new_tokens=1024, -) - -# ----------------------- interpreter ----------------------- -print('-' * 80, 'interpreter', '-' * 80) - -ds = load_dataset('lighteval/MATH', split='test') -problems = [item['problem'] for item in ds.select(range(50))] - -coder = AsyncMathCoder( - llm=model, - interpreter=dict( - type='lagent.actions.AsyncIPythonInterpreter', max_kernels=200), - output_format=ToolParser( - 'interpreter', - template= - ('Integrate step-by-step reasoning and Python code to solve math problems ' - 'using the following guidelines:\n' - '- Analyze the question and write jupyter code to solve the problem;\n' - r"- Present the final result in LaTeX using a '\boxed{{}}' without any " - 'units. \n'), - begin='\n```python\n', - end='\n```\n')) - -tic = time.time() -coros = [coder(query, session_id=i) for i, query in enumerate(problems)] -res = loop.run_until_complete(asyncio.gather(*coros)) -# print([r.model_dump_json() for r in res]) -print('-' * 120) -print(f'time elapsed: {time.time() - tic}') - -with open('./tmp_3.json', 'w') as f: - json.dump([coder.get_steps(i) for i in range(len(res))], - f, - ensure_ascii=False, - indent=4) diff --git a/examples/run_ray_async_agent_lmdeploy.py b/examples/run_ray_async_agent_lmdeploy.py deleted file mode 100644 index 530ca1b2..00000000 --- a/examples/run_ray_async_agent_lmdeploy.py +++ /dev/null @@ -1,65 +0,0 @@ -import asyncio -import json -import time - -import ray -from datasets import load_dataset - -from lagent.distributed.ray_serve import AsyncAgentRayActor -from lagent.llms import INTERNLM2_META -from lagent.llms.lmdeploy_wrapper import AsyncLMDeployPipeline - -ray.init() - -# set up the loop - -# initialize the model -loop = asyncio.new_event_loop() -asyncio.set_event_loop(loop) -model = dict( - type=AsyncLMDeployPipeline, - path='internlm/internlm2_5-7b-chat', - meta_template=INTERNLM2_META, - tp=1, - top_k=1, - temperature=1.0, - stop_words=['<|im_end|>', '<|action_end|>'], - max_new_tokens=1024, -) - -# ----------------------- interpreter ----------------------- -print('-' * 80, 'interpreter', '-' * 80) -ds = load_dataset('lighteval/MATH', split='test') -problems = [item['problem'] for item in ds.select(range(5000))] - -coder = dict( - type='lagent.agents.stream.AsyncMathCoder', - llm=model, - interpreter=dict(type='AsyncIPythonInterpreter', max_kernels=300), -) -tic = time.time() - -actor1 = AsyncAgentRayActor(coder.copy(), num_gpus=1) -actor2 = AsyncAgentRayActor(coder.copy(), num_gpus=1) -corots = [ - actor1(query, session_id=i) - for i, query in enumerate(problems[:len(problems) // 2]) -] -corots += [ - actor2(query, session_id=i) - for i, query in enumerate(problems[len(problems) // 2:]) -] -results = loop.run_until_complete(asyncio.gather(*corots)) - -print('-' * 120) -print(f'time elapsed: {time.time() - tic}') -all_step = ray.get([ - actor1.agent_actor.get_steps.remote(i) for i in range(len(problems) // 2) -]) -all_step += ray.get([ - actor2.agent_actor.get_steps.remote(i) - for i in range(len(problems[len(problems) // 2:])) -]) - -with open('./tmp_1.json', 'w') as f: - json.dump(all_step, f, ensure_ascii=False, indent=4) diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index b75a2262..f0c3a1d6 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -12,7 +12,13 @@ from .ppt import PPT, AsyncPPT from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter from .web_browser import AsyncWebBrowser, WebBrowser - +from .subagent import AsyncAgentAction +from .cron import CronAction, AsyncCronAction +from .task import TaskAction, AsyncTaskAction +from .send_message import SendMessageAction, AsyncSendMessageAction +from .external_agent import ExternalAgentAction +from .sandbox_executor import SandboxActionExecutor +from .hybrid_executor import HybridActionExecutor __all__ = [ 'BaseAction', 'ActionExecutor', @@ -44,4 +50,14 @@ 'TupleParser', 'tool_api', 'AsyncActionMixin', + 'AsyncAgentAction', + 'CronAction', + 'AsyncCronAction', + 'TaskAction', + 'AsyncTaskAction', + 'SendMessageAction', + 'AsyncSendMessageAction', + 'ExternalAgentAction', + 'SandboxActionExecutor', + 'HybridActionExecutor', ] diff --git a/lagent/actions/action_executor.py b/lagent/actions/action_executor.py index e28f32d5..30c7c831 100644 --- a/lagent/actions/action_executor.py +++ b/lagent/actions/action_executor.py @@ -96,11 +96,10 @@ def forward(self, name, parameters, **kwargs) -> ActionReturn: def __call__(self, message: AgentMessage, - session_id=0, **kwargs) -> AgentMessage: # message.receiver = self.name for hook in self._hooks.values(): - result = hook.before_action(self, message, session_id) + result = hook.before_action(self, message) if result: message = result @@ -123,7 +122,7 @@ def __call__(self, ) for hook in self._hooks.values(): - result = hook.after_action(self, response_message, session_id) + result = hook.after_action(self, response_message) if result: response_message = result return response_message @@ -158,14 +157,13 @@ async def forward(self, name, parameters, **kwargs) -> ActionReturn: async def __call__(self, message: AgentMessage, - session_id=0, **kwargs) -> AgentMessage: # message.receiver = self.name for hook in self._hooks.values(): if inspect.iscoroutinefunction(hook.before_action): - result = await hook.before_action(self, message, session_id) + result = await hook.before_action(self, message) else: - result = hook.before_action(self, message, session_id) + result = hook.before_action(self, message) if result: message = result @@ -189,10 +187,9 @@ async def __call__(self, for hook in self._hooks.values(): if inspect.iscoroutinefunction(hook.after_action): - result = await hook.after_action(self, response_message, - session_id) + result = await hook.after_action(self, response_message) else: - result = hook.after_action(self, response_message, session_id) + result = hook.after_action(self, response_message) if result: response_message = result return response_message diff --git a/lagent/actions/base_action.py b/lagent/actions/base_action.py index b42036af..2a74cd92 100644 --- a/lagent/actions/base_action.py +++ b/lagent/actions/base_action.py @@ -152,7 +152,10 @@ def _parse_tool(function): if doc.kind is DocstringSectionKind.parameters: for d in doc.value: d = d.as_dict() - d['type'] = _detect_type(d.pop('annotation').lower()) + if not d['annotation']: + d.pop('annotation') + else: + d['type'] = _detect_type(d.pop('annotation').lower()) args_doc[d['name']] = d if doc.kind is DocstringSectionKind.returns: for d in doc.value: diff --git a/lagent/actions/cron.py b/lagent/actions/cron.py new file mode 100644 index 00000000..ebe4664f --- /dev/null +++ b/lagent/actions/cron.py @@ -0,0 +1,197 @@ +"""CronAction — agent-facing toolkit for managing scheduled tasks. + +This is a short-lived Action created per-request by the Dispatcher. +It wraps the long-lived :class:`CronService` singleton and carries the +request-scoped ``channel`` / ``chat_id`` so the agent itself stays +completely channel-unaware. + +Usage:: + + cron_action = CronAction( + cron_service=app.cron_service, + channel="feishu", + chat_id="oc_xxxx", + ) + executor = AsyncActionExecutor(actions=[cron_action, ...]) +""" + +from __future__ import annotations + +from typing import Annotated, Optional, Type + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + +from ..services.cron import CronService, Schedule + + +class CronAction(BaseAction): + """Manage scheduled / recurring tasks. + + The agent can add, list, and remove cron jobs through this toolkit. + Each job will fire at the scheduled time and deliver a message back + to the originating channel and chat. + """ + + def __init__( + self, + cron_service: CronService, + channel: str | None = None, + chat_id: str | None = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + self._cron = cron_service + self._channel = channel + self._chat_id = chat_id + + # ── tool APIs ───────────────────────────────────────────────── + + @tool_api + def add( + self, + name: Annotated[str, "A short human-readable name for the job"], + message: Annotated[str, "The message / task to deliver when the job fires"], + schedule_kind: Annotated[ + str, + "One of: 'at' (one-shot ISO-8601 datetime), " + "'every' (interval in seconds), " + "'cron' (5-field cron expression)", + ] = "at", + at: Annotated[ + Optional[str], + "ISO-8601 datetime for one-shot schedule, e.g. '2025-01-15T09:00:00+08:00'", + ] = None, + every_seconds: Annotated[ + Optional[float], + "Interval in seconds for recurring schedule", + ] = None, + cron_expr: Annotated[ + Optional[str], + "5-field cron expression, e.g. '30 9 * * 1-5' for weekdays at 09:30", + ] = None, + timezone: Annotated[ + Optional[str], + "IANA timezone for cron expression, e.g. 'Asia/Shanghai'", + ] = None, + ) -> ActionReturn: + """Add a new scheduled task. + + Args: + name: A short human-readable name for the job. + message: The message / task to deliver when the job fires. + schedule_kind: Schedule type — 'at', 'every', or 'cron'. + at: ISO-8601 datetime for one-shot schedule. + every_seconds: Interval in seconds for recurring schedule. + cron_expr: 5-field cron expression for cron schedule. + timezone: IANA timezone for cron expression. + + Returns: + ActionReturn with the created job summary. + """ + schedule = Schedule( + kind=schedule_kind, + at=at, + every_seconds=every_seconds, + expr=cron_expr, + tz=timezone, + ) + if schedule_kind not in ("at", "every", "cron"): + return ActionReturn( + type=self.name, + errmsg=f"Invalid schedule_kind: {schedule_kind!r}. " + f"Must be 'at', 'every', or 'cron'.", + state=ActionStatusCode.ARGS_ERROR, + ) + try: + job = self._cron.add_job( + name=name, + schedule=schedule, + message=message, + channel=self._channel, + chat_id=self._chat_id, + delete_after_run=(schedule_kind == "at"), + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to add cron job: {exc}", + state=ActionStatusCode.API_ERROR, + ) + return ActionReturn( + type=self.name, + result=[ + dict( + type="text", + content=( + f"✅ Job created: **{job.name}** (id: `{job.id}`)\n" + f" Schedule: {schedule_kind}" + f"{f' at {at}' if at else ''}" + f"{f' every {every_seconds}s' if every_seconds else ''}" + f"{f' cron {cron_expr}' if cron_expr else ''}" + ), + ) + ], + ) + + @tool_api + def list(self) -> ActionReturn: + """List all active scheduled tasks. + + Returns: + ActionReturn with a summary of all active jobs. + """ + jobs = self._cron.list_jobs(include_disabled=False) + if not jobs: + return ActionReturn( + type=self.name, + result=[dict(type="text", content="No active scheduled tasks.")], + ) + lines = [] + for j in jobs: + sched = j.schedule + sched_desc = ( + f"at {sched.at}" if sched.kind == "at" + else f"every {sched.every_seconds}s" if sched.kind == "every" + else f"cron `{sched.expr}`" + ) + lines.append( + f"- **{j.name}** (`{j.id}`) — {sched_desc}\n" + f" message: {j.payload.get('message', '?')}" + ) + return ActionReturn( + type=self.name, + result=[dict(type="text", content="\n".join(lines))], + ) + + @tool_api + def remove( + self, + job_id: Annotated[str, "The ID of the job to remove"], + ) -> ActionReturn: + """Remove a scheduled task by its ID. + + Args: + job_id: The ID of the job to remove. + + Returns: + ActionReturn confirming removal or reporting not found. + """ + removed = self._cron.remove_job(job_id) + if removed: + return ActionReturn( + type=self.name, + result=[dict(type="text", content=f"✅ Job `{job_id}` removed.")], + ) + return ActionReturn( + type=self.name, + errmsg=f"Job `{job_id}` not found.", + state=ActionStatusCode.API_ERROR, + ) + + +class AsyncCronAction(AsyncActionMixin, CronAction): + """Async version of :class:`CronAction`.""" + pass diff --git a/lagent/actions/external_agent.py b/lagent/actions/external_agent.py new file mode 100644 index 00000000..1ce6f1a1 --- /dev/null +++ b/lagent/actions/external_agent.py @@ -0,0 +1,128 @@ +"""ExternalAgentAction — expose external agent adapters as @tool_api. + +Follows the same pattern as :class:`AsyncAgentAction` in +``lagent/actions/subagent.py``: a short-lived Action that wraps +a set of registered external agent adapters, allowing a PolicyAgent +to delegate tasks to external frameworks via tool calls. + +Usage:: + + from lagent.adapters.cli_adapter import CLIAgentAdapter + + claude = CLIAgentAdapter( + name="claude-code", + command_template="claude -p '{task}' --output-format text", + ) + action = ExternalAgentAction(adapters={"claude-code": claude}) + executor = AsyncActionExecutor(actions=[action, ...]) +""" + +from typing import Annotated, Any, Dict, Optional, Type + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + + +class ExternalAgentAction(AsyncActionMixin, BaseAction): + """Toolkit that lets a PolicyAgent delegate tasks to external agents. + + The main agent can call external framework agents (CLI or SDK-based) + through this toolkit. Each external agent is registered by name. + + Args: + adapters: Dict mapping adapter names to adapter instances. + description: Optional description override. + parser: Parser class. Default: JsonParser. + """ + + def __init__( + self, + adapters: Optional[Dict[str, Any]] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + self._adapters: Dict[str, Any] = adapters or {} + + def register_adapter(self, name: str, adapter: Any) -> None: + """Register an external agent adapter by name.""" + self._adapters[name] = adapter + + @tool_api + async def run_agent( + self, + agent_name: Annotated[ + str, + "Name of the external agent to invoke. Must be one of the " + "registered adapter names (use list_agents to see available).", + ], + task: Annotated[ + str, + "A clear, self-contained description of the task. Include " + "all necessary context — the external agent has no access " + "to the current conversation.", + ], + ) -> ActionReturn: + """Delegate a task to an external agent framework. + + Runs the specified external agent with the given task and + returns its output. + + Args: + agent_name: Registered name of the external agent adapter. + task: Self-contained task description. + + Returns: + ActionReturn with the external agent's result. + """ + adapter = self._adapters.get(agent_name) + if adapter is None: + available = list(self._adapters.keys()) + return ActionReturn( + type=self.name, + errmsg=f"Unknown agent: {agent_name!r}. Available: {available}", + state=ActionStatusCode.API_ERROR, + ) + + try: + # Go through Agent.__call__() so hooks and memory work + response = await adapter(task) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"External agent '{agent_name}' failed: {exc}", + state=ActionStatusCode.API_ERROR, + ) + + return ActionReturn( + type=self.name, + result=[dict( + type='text', + content=( + f"**External agent `{agent_name}` completed.**\n\n" + f"{response.content}" + ), + )], + ) + + @tool_api + async def list_agents(self) -> ActionReturn: + """List all available external agent adapters and their descriptions. + + Returns: + ActionReturn with a summary of available external agents. + """ + if not self._adapters: + return ActionReturn( + type=self.name, + result=[dict(type='text', content='No external agents registered.')], + ) + lines = [] + for name, adapter in self._adapters.items(): + desc = getattr(adapter, 'description', None) or '(no description)' + lines.append(f"- **{name}**: {desc}") + return ActionReturn( + type=self.name, + result=[dict(type='text', content='\n'.join(lines))], + ) diff --git a/lagent/actions/filesystem.py b/lagent/actions/filesystem.py new file mode 100644 index 00000000..0696cc04 --- /dev/null +++ b/lagent/actions/filesystem.py @@ -0,0 +1,252 @@ +import difflib +from pathlib import Path +from typing import Any, Optional, Type + +from asyncer import asyncify + +from lagent.actions.base_action import BaseAction, tool_api, AsyncActionMixin +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + + +def _resolve_path( + path: str, workspace: Optional[Path] = None, allowed_dir: Optional[Path] = None +) -> Path: + """Resolve path against workspace (if relative) and enforce directory restriction.""" + p = Path(path).expanduser() + if not p.is_absolute() and workspace: + p = workspace / p + resolved = p.resolve() + if allowed_dir: + try: + resolved.relative_to(allowed_dir.resolve()) + except ValueError: + raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}") + return resolved + + +class ReadFileAction(AsyncActionMixin, BaseAction): + """Tool to read file contents.""" + + _MAX_CHARS = 128_000 + + def __init__( + self, + workspace: Optional[str] = None, + allowed_dir: Optional[str] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self._workspace = Path(workspace) if workspace else None + self._allowed_dir = Path(allowed_dir) if allowed_dir else None + + @tool_api + @asyncify + def run(self, path: str) -> ActionReturn: + """Read the contents of a file at the given path. + + Args: + path (str): The file path to read + """ + try: + file_path = _resolve_path(path, self._workspace, self._allowed_dir) + if not file_path.exists(): + return ActionReturn( + type=self.name, + errmsg=f"Error: File not found: {path}", + state=ActionStatusCode.API_ERROR, + ) + if not file_path.is_file(): + return ActionReturn( + type=self.name, + errmsg=f"Error: Not a file: {path}", + state=ActionStatusCode.API_ERROR, + ) + + size = file_path.stat().st_size + if size > self._MAX_CHARS * 4: + return ActionReturn( + type=self.name, + errmsg=f"Error: File too large ({size:,} bytes). Use exec tool with head/tail/grep to read portions.", + state=ActionStatusCode.API_ERROR, + ) + + content = file_path.read_text(encoding="utf-8") + if len(content) > self._MAX_CHARS: + content = content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})" + + return ActionReturn( + type=self.name, + result=[dict(type='text', content=content)], + state=ActionStatusCode.SUCCESS, + ) + except PermissionError as e: + return ActionReturn( + type=self.name, + errmsg=f"PermissionError: {e}", + state=ActionStatusCode.API_ERROR, + ) + except Exception as e: + return ActionReturn( + type=self.name, + errmsg=f"Error reading file: {str(e)}", + state=ActionStatusCode.API_ERROR, + ) + + +class WriteFileAction(AsyncActionMixin, BaseAction): + """Tool to write content to a file.""" + + def __init__( + self, + workspace: Optional[str] = None, + allowed_dir: Optional[str] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self._workspace = Path(workspace) if workspace else None + self._allowed_dir = Path(allowed_dir) if allowed_dir else None + + @tool_api + @asyncify + def run(self, path: str, content: str) -> ActionReturn: + """Write content to a file at the given path. Creates parent directories if needed. + + Args: + path (str): The file path to write to. + content (str): The content to write. + """ + try: + file_path = _resolve_path(path, self._workspace, self._allowed_dir) + + needs_newline = content and not content.endswith("\n") + if needs_newline: + content += "\n" + + verb = "Updated" if file_path.exists() else "Created" + + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content, encoding="utf-8") + + return ActionReturn( + type=self.name, + result=[dict(type='text', content=f"Successfully {verb.lower()} file at {path}")], + state=ActionStatusCode.SUCCESS, + ) + except PermissionError as e: + return ActionReturn( + type=self.name, + errmsg=f"PermissionError: {e}", + state=ActionStatusCode.API_ERROR, + ) + except Exception as e: + return ActionReturn( + type=self.name, + errmsg=f"Error writing file: {str(e)}", + state=ActionStatusCode.API_ERROR, + ) + + +class EditFileAction(AsyncActionMixin, BaseAction): + """Tool to edit a file using search and replace blocks.""" + + def __init__( + self, + workspace: Optional[str] = None, + allowed_dir: Optional[str] = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self._workspace = Path(workspace) if workspace else None + self._allowed_dir = Path(allowed_dir) if allowed_dir else None + + @tool_api + @asyncify + def run(self, path: str, search: str, replace: str) -> ActionReturn: + """Edit a file by replacing a specific block of text. + + Args: + path (str): The file path to edit. + search (str): The exact text string to search for and replace. Must match the file content exactly, including whitespace. + replace (str): The new text to replace the search block with. + """ + try: + file_path = _resolve_path(path, self._workspace, self._allowed_dir) + if not file_path.exists(): + return ActionReturn( + type=self.name, + errmsg=f"Error: File not found: {path}", + state=ActionStatusCode.API_ERROR, + ) + if not file_path.is_file(): + return ActionReturn( + type=self.name, + errmsg=f"Error: Not a file: {path}", + state=ActionStatusCode.API_ERROR, + ) + + content = file_path.read_text(encoding="utf-8") + + count = content.count(search) + if count == 0: + return ActionReturn( + type=self.name, + errmsg="Error: Search text not found exactly in file. Ensure exact whitespace matching.", + state=ActionStatusCode.API_ERROR, + ) + if count > 1: + return ActionReturn( + type=self.name, + errmsg="Error: Search text matched multiple times. Provide more context to make it unique.", + state=ActionStatusCode.API_ERROR, + ) + + new_content = content.replace(search, replace) + file_path.write_text(new_content, encoding="utf-8") + + return ActionReturn( + type=self.name, + result=[dict(type='text', content=f"Successfully edited file at {path}")], + state=ActionStatusCode.SUCCESS, + ) + except PermissionError as e: + return ActionReturn( + type=self.name, + errmsg=f"PermissionError: {e}", + state=ActionStatusCode.API_ERROR, + ) + except Exception as e: + return ActionReturn( + type=self.name, + errmsg=f"Error editing file: {str(e)}", + state=ActionStatusCode.API_ERROR, + ) + + +if __name__ == "__main__": + # Example usage + workspace = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent/tests/test_actions" + read_action = ReadFileAction(workspace=workspace) + write_action = WriteFileAction(workspace=workspace) + edit_action = EditFileAction(workspace=workspace) + async def test_actions(): + # Test writing a file + write_result = await write_action.run(path="test.txt", content="Hello, world!") + print(write_result) + + # Test reading the file + read_result = await read_action.run(path="test.txt") + print(read_result) + + # Test editing the file + edit_result = await edit_action.run(path="test.txt", search="world", replace="universe") + print(edit_result) + + # Read the file again to see changes + read_result_after_edit = await read_action.run(path="test.txt") + print(read_result_after_edit) + import asyncio + asyncio.run(test_actions()) \ No newline at end of file diff --git a/lagent/actions/hybrid_executor.py b/lagent/actions/hybrid_executor.py new file mode 100644 index 00000000..b3921b05 --- /dev/null +++ b/lagent/actions/hybrid_executor.py @@ -0,0 +1,111 @@ +"""HybridActionExecutor — routes actions to local or sandbox execution. + +Some actions (subagent, save_memory) must run locally because they need +access to the LLM, parent agent memory, or external services. Environment +actions (shell, ipython, file ops) run in the sandbox. + +Usage:: + + sandbox = SandboxActionExecutor( + sandbox_client=my_client, + actions_config=[{"type": "..ShellAction"}, {"type": "..IPythonInterpreter"}], + ) + await sandbox.connect() + + executor = HybridActionExecutor( + local_actions=[AsyncAgentAction(), AsyncSaveMemoryAction()], + sandbox_executor=sandbox, + ) + + # forward() automatically routes by name + await executor.forward("shell", {"command": "ls"}) # → sandbox + await executor.forward("AgentAction.spawn", {"task": "..."}) # → local +""" + +from __future__ import annotations + +import inspect +from typing import Dict, List, Optional, Union + +from lagent.actions.action_executor import AsyncActionExecutor +from lagent.actions.base_action import BaseAction +from lagent.actions.sandbox_executor import SandboxActionExecutor +from lagent.schema import ActionReturn, ActionValidCode, AgentMessage + + +class HybridActionExecutor(AsyncActionExecutor): + """Routes action calls to either local execution or a remote sandbox. + + Inherits from ``AsyncActionExecutor`` for local actions and delegates + sandbox actions to a ``SandboxActionExecutor``. + + Parameters + ---------- + local_actions : list + Actions to run locally (subagent, save_memory, etc.). + sandbox_executor : SandboxActionExecutor + Executor that routes to the sandbox daemon. + **kwargs + Passed to ``AsyncActionExecutor.__init__`` (hooks, finish_action, etc.). + """ + + def __init__( + self, + local_actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]] = None, + sandbox_executor: SandboxActionExecutor = None, + **kwargs, + ): + super().__init__(actions=local_actions or [], **kwargs) + self.sandbox_executor = sandbox_executor + + @property + def _sandbox_actions(self) -> Dict[str, object]: + if self.sandbox_executor is None: + return {} + return self.sandbox_executor.actions + + def description(self) -> List[Dict]: + descs = super().description() + if self.sandbox_executor is not None: + descs = descs + self.sandbox_executor.description() + return descs + + def __contains__(self, name: str) -> bool: + return super().__contains__(name) or ( + self.sandbox_executor is not None and name in self.sandbox_executor + ) + + def keys(self) -> List[str]: + k = super().keys() + if self.sandbox_executor is not None: + k = k + self.sandbox_executor.keys() + return k + + async def forward(self, name: str, parameters: dict, **kwargs) -> ActionReturn: + action_name = name.split(".")[0] if "." in name else name + + # 1. Local actions take priority + if action_name in self.actions: + return await super().forward(name, parameters, **kwargs) + + # 2. Sandbox actions + if self.sandbox_executor is not None and action_name in self.sandbox_executor: + return await self.sandbox_executor.forward(name, parameters, **kwargs) + + # 3. Built-in fallbacks (NoAction, FinishAction, InvalidAction) + if name == self.no_action.name: + return self.no_action(parameters) + elif name == self.finish_action.name: + return self.finish_action(parameters) + else: + return self.invalid_action(parameters) + + async def connect(self) -> None: + """Connect the sandbox executor (if present).""" + if self.sandbox_executor is not None: + await self.sandbox_executor.connect() + + async def close(self) -> None: + """Close the sandbox executor (if present).""" + if self.sandbox_executor is not None: + await self.sandbox_executor.close() diff --git a/lagent/actions/ipython_interpreter.py b/lagent/actions/ipython_interpreter.py index 68e9a0de..2c52f1f5 100644 --- a/lagent/actions/ipython_interpreter.py +++ b/lagent/actions/ipython_interpreter.py @@ -167,8 +167,6 @@ class IPythonInterpreter(BaseAction): action's inputs and outputs. Defaults to :class:`JsonParser`. """ - _KERNEL_CLIENTS = {} - def __init__( self, timeout: int = 20, @@ -205,20 +203,29 @@ def start_kernel(): def initialize(self): if self._initialized: return - pid = os.getpid() - if pid not in self._KERNEL_CLIENTS: - self._KERNEL_CLIENTS[pid] = self.start_kernel() - self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid] + + if getattr(self, 'kernel_manager', None) is None or getattr(self, 'kernel_client', None) is None: + self.kernel_manager, self.kernel_client = self.start_kernel() + self._initialized = True self._call(START_CODE.format(self.user_data_dir), None) def reset(self): if not self._initialized: - self.initialize() + return else: code = "get_ipython().run_line_magic('reset', '-f')\n" + \ START_CODE.format(self.user_data_dir) self._call(code, None) + + def close(self): + if hasattr(self, 'kernel_client') and self.kernel_client: + self.kernel_client.stop_channels() + if hasattr(self, 'kernel_manager') and self.kernel_manager: + self.kernel_manager.shutdown_kernel() + self._initialized = False + self.kernel_client = None + self.kernel_manager = None def _call(self, command: str, @@ -366,114 +373,89 @@ class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter): Defaults to `ENV`. work_dir (str, optional): Specify which directory to save output images to. Defaults to ``'./work_dir/tmp_dir'``. + connection_dir (str, optional): Connection directory for the kernel manager. + kernel_backend (AsyncMultiKernelManager, optional): A shared backend to spawn kernels from. + If not provided, a default one will be created. description (dict): The description of the action. Defaults to ``None``. parser (Type[BaseParser]): The parser class to process the action's inputs and outputs. Defaults to :class:`JsonParser`. """ - _UNBOUND_KERNEL_CLIENTS = asyncio.Queue() - def __init__( self, timeout: int = 20, user_data_dir: str = 'ENV', work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'), - max_kernels: Optional[int] = None, - reuse_kernel: bool = True, - startup_rate: bool = 32, connection_dir: str = tempfile.gettempdir(), + kernel_backend: Optional[AsyncMultiKernelManager] = None, description: Optional[dict] = None, parser: Type[BaseParser] = JsonParser, ): super().__init__(timeout, user_data_dir, work_dir, description, parser) - from traitlets.config import Config - - c = Config() - c.KernelManager.transport = 'ipc' - self._amkm = AsyncMultiKernelManager( - config=c, connection_dir=connection_dir) - self._max_kernels = max_kernels - self._reuse_kernel = reuse_kernel - self._sem = asyncio.Semaphore(startup_rate) + + if kernel_backend is None: + from traitlets.config import Config + c = Config() + c.KernelManager.transport = 'ipc' + self._amkm = AsyncMultiKernelManager(config=c, connection_dir=connection_dir) + else: + self._amkm = kernel_backend + self._lock = asyncio.Lock() - - async def initialize(self, session_id: str): - session_id = str(session_id) - while True: - if session_id in self._KERNEL_CLIENTS: - return self._KERNEL_CLIENTS[session_id] - if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty(): - self._KERNEL_CLIENTS[ - session_id] = await self._UNBOUND_KERNEL_CLIENTS.get() - return self._KERNEL_CLIENTS[session_id] - async with self._sem: - if self._max_kernels is None or len( - self._KERNEL_CLIENTS - ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels: - kernel_id = None - try: - kernel_id = await self._amkm.start_kernel() - kernel = self._amkm.get_kernel(kernel_id) - client = kernel.client() - _, error_stacktrace, stream_text = await async_run_code( - kernel, - START_CODE.format(self.user_data_dir), - shutdown_kernel=False) - # check if the output of START_CODE meets expectations - if not (error_stacktrace is None - and stream_text == ''): - raise RuntimeError - except Exception as e: - print(f'Starting kernel error: {e}') - if kernel_id: - await self._amkm.shutdown_kernel(kernel_id) - self._amkm.remove_kernel(kernel_id) - await asyncio.sleep(1) - continue - if self._max_kernels is None: - self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel, - client) - return kernel_id, kernel, client - async with self._lock: - if len(self._KERNEL_CLIENTS - ) + self._UNBOUND_KERNEL_CLIENTS.qsize( - ) < self._max_kernels: - self._KERNEL_CLIENTS[session_id] = (kernel_id, - kernel, client) - return kernel_id, kernel, client - await self._amkm.shutdown_kernel(kernel_id) - self._amkm.remove_kernel(kernel_id) - await asyncio.sleep(1) - - async def reset(self, session_id: str): - session_id = str(session_id) - if session_id not in self._KERNEL_CLIENTS: + self._kernel_id = None + self._kernel = None + self._client = None + + async def initialize(self): + if self._kernel_id is not None: + return self._kernel, self._client + + async with self._lock: + if self._kernel_id is None: + try: + self._kernel_id = await self._amkm.start_kernel() + self._kernel = self._amkm.get_kernel(self._kernel_id) + self._client = self._kernel.client() + _, error_stacktrace, stream_text = await async_run_code( + self._kernel, + START_CODE.format(self.user_data_dir), + shutdown_kernel=False) + # check if the output of START_CODE meets expectations + if not (error_stacktrace is None + and stream_text == ''): + raise RuntimeError("Failed to initialize kernel with START_CODE") + except Exception as e: + print(f'Starting kernel error: {e}') + if self._kernel_id: + await self._amkm.shutdown_kernel(self._kernel_id) + self._amkm.remove_kernel(self._kernel_id) + self._kernel_id = None + self._kernel = None + self._client = None + raise e + return self._kernel, self._client + + async def reset(self): + if self._kernel_id is None: return - _, kernel, _ = self._KERNEL_CLIENTS[session_id] code = "get_ipython().run_line_magic('reset', '-f')\n" + \ START_CODE.format(self.user_data_dir) - await async_run_code(kernel, code, shutdown_kernel=False) - - async def shutdown(self, session_id: str): - session_id = str(session_id) - if session_id in self._KERNEL_CLIENTS: - kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id) - await self._amkm.shutdown_kernel(kernel_id) - self._amkm.remove_kernel(kernel_id) - del self._KERNEL_CLIENTS[session_id] - - async def close_session(self, session_id: str): - session_id = str(session_id) - if self._reuse_kernel: - if session_id in self._KERNEL_CLIENTS: - await self.reset(session_id) - await self._UNBOUND_KERNEL_CLIENTS.put( - self._KERNEL_CLIENTS.pop(session_id)) - else: - await self.shutdown(session_id) - - async def _call(self, command, timeout=None, session_id=None): - _, kernel, _ = await self.initialize(str(session_id)) + await async_run_code(self._kernel, code, shutdown_kernel=False) + + async def shutdown(self): + if self._kernel_id is not None: + await self._amkm.shutdown_kernel(self._kernel_id) + self._amkm.remove_kernel(self._kernel_id) + self._kernel_id = None + self._kernel = None + self._client = None + + async def close_session(self): + # Kept for backward compatibility, but now it just shuts down the single kernel + await self.shutdown() + + async def _call(self, command, timeout=None): + kernel, _ = await self.initialize() result = await async_run_code( kernel, extract_code(command), @@ -494,8 +476,7 @@ async def _call(self, command, timeout=None, session_id=None): @tool_api async def run(self, command: str, - timeout: Optional[int] = None, - session_id: Optional[str] = None) -> ActionReturn: + timeout: Optional[int] = None) -> ActionReturn: r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail. Args: @@ -504,7 +485,7 @@ async def run(self, """ tool_return = ActionReturn(url=None, args=None, type=self.name) tool_return.args = dict(text=command) - succeed, result = await self._call(command, timeout, session_id) + succeed, result = await self._call(command, timeout) if succeed: text = result['text'] image = result.get('image', []) @@ -512,8 +493,6 @@ async def run(self, if image: resp.extend([dict(type='image', content=im) for im in image]) tool_return.result = resp - # tool_return.result = dict( - # text=result['text'], image=result.get('image', [])[0]) tool_return.state = ActionStatusCode.SUCCESS else: tool_return.errmsg = result.get('text', '') if isinstance( @@ -579,6 +558,97 @@ def get_multiline_input(hint): if __name__ == '__main__': + # ======================================================== + # Test 1: Synchronous IPython Interpreter (Single Kernel) + # ======================================================== + print("--- Testing Synchronous IPythonInterpreter ---") code_interpreter = IPythonInterpreter() - while True: - print(code_interpreter(get_multiline_input('Enter python code:'))) + + # Test simple execution + res1 = code_interpreter(inputs=json.dumps({"command": "print('Hello from Sync Interpreter!')"})) + print(f"Result 1: {res1.result[0]['content']}\n") + + # Test statefulness (variable persistence) + code_interpreter(inputs=json.dumps({"command": "x = 42"})) + res2 = code_interpreter(inputs=json.dumps({"command": "print(f'x is {x}')"})) + print(f"Result 2: {res2.result[0]['content']}\n") + + # Cleanup + code_interpreter.close() + + + # ======================================================== + # Test 2: Asynchronous IPython Interpreter (Single Kernel) + # ======================================================== + async def test_async_interpreter(): + print("--- Testing AsyncIPythonInterpreter ---") + async_interpreter = AsyncIPythonInterpreter() + + # Test simple execution + res1 = await async_interpreter(inputs=json.dumps({"command": "print('Hello from Async Interpreter!')"})) + print(f"Result 1: {res1.result[0]['content']}\n") + + # Test statefulness + await async_interpreter(inputs=json.dumps({"command": "y = 100"})) + res2 = await async_interpreter(inputs=json.dumps({"command": "print(f'y is {y}')"})) + print(f"Result 2: {res2.result[0]['content']}\n") + + # Cleanup + await async_interpreter.close_session() + + asyncio.run(test_async_interpreter()) + + + # ======================================================== + # Test 3: Shared vs Non-Shared Backend Performance + # ======================================================== + async def test_shared_backend_performance(): + print("--- Testing Shared vs Non-Shared Backend Performance ---") + import time + from traitlets.config import Config + + num_kernels = 10 + + # Test 1: Non-Shared Backend (Each instance creates its own AsyncMultiKernelManager) + print(f"\nStarting {num_kernels} Non-Shared kernels...") + start_time = time.time() + non_shared_agents = [] + for _ in range(num_kernels): + agent = AsyncIPythonInterpreter() + # Force initialization + await agent.initialize() + non_shared_agents.append(agent) + + non_shared_duration = time.time() - start_time + print(f"Non-Shared initialization took {non_shared_duration:.2f} seconds") + + # Cleanup + for agent in non_shared_agents: + await agent.close_session() + + # Test 2: Shared Backend (All instances share one AsyncMultiKernelManager) + print(f"\nStarting {num_kernels} Shared kernels...") + start_time = time.time() + + c = Config() + c.KernelManager.transport = 'ipc' + shared_backend = AsyncMultiKernelManager(config=c, connection_dir=tempfile.gettempdir()) + + shared_agents = [] + for _ in range(num_kernels): + agent = AsyncIPythonInterpreter(kernel_backend=shared_backend) + # Force initialization + await agent.initialize() + shared_agents.append(agent) + + shared_duration = time.time() - start_time + print(f"Shared initialization took {shared_duration:.2f} seconds") + + # Cleanup + for agent in shared_agents: + await agent.close_session() + await shared_backend.shutdown_all() + + print(f"\nPerformance Improvement: {(non_shared_duration/shared_duration - 1)*100:.1f}% faster with Shared Backend") + + asyncio.run(test_shared_backend_performance()) diff --git a/lagent/actions/ipython_manager.py b/lagent/actions/ipython_manager.py index 4e8a2fdb..e5d928b8 100644 --- a/lagent/actions/ipython_manager.py +++ b/lagent/actions/ipython_manager.py @@ -132,8 +132,7 @@ def __init__( self.out_queue = Queue() def __call__(self, - commands: Union[str, List[str]], - session_ids: Union[int, List[int]] = None): + commands: Union[str, List[str]]): if isinstance(commands, list): batch_size = len(commands) is_batch = True @@ -141,14 +140,10 @@ def __call__(self, batch_size = 1 commands = [commands] is_batch = False - if session_ids is None: - session_ids = range(batch_size) - elif isinstance(session_ids, int): - session_ids = [session_ids] - if len(session_ids) != batch_size or len(session_ids) != len( - set(session_ids)): - raise ValueError( - 'the size of `session_ids` must equal that of `commands`') + + # We just generate some dummy ids for the batch processing internally + session_ids = range(batch_size) + try: exec_results = self.run_code_blocks([ (session_id, command) diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py new file mode 100644 index 00000000..b9c27772 --- /dev/null +++ b/lagent/actions/mcp_client.py @@ -0,0 +1,612 @@ +import asyncio +import logging +import random +import threading +import time +from collections import deque +from contextlib import AsyncExitStack, nullcontext +from typing import Deque, Literal, Optional, TypeAlias + +from lagent.actions.base_action import AsyncActionMixin, BaseAction +from lagent.actions.parser import JsonParser, ParseError +from lagent.schema import ActionReturn, ActionStatusCode + +ServerType: TypeAlias = Literal["stdio", "sse", "http"] + +logger = logging.getLogger(__name__) +_loop = None + + +def _get_event_loop(): + try: + event_loop = asyncio.get_event_loop() + except Exception: + logger.warning('Can not found event loop in current thread. Create a new event loop.') + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + if event_loop.is_running(): + global _loop + if _loop: + return _loop + + from threading import Thread + + def _start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + event_loop = asyncio.new_event_loop() + Thread(target=_start_loop, args=(event_loop,), daemon=True).start() + _loop = event_loop + return event_loop + + +logger = logging.getLogger(__file__) + + +class TokenBucket: + def __init__(self, rate_limit: float): + self.rate_limit = rate_limit # tokens per second + self.tokens = rate_limit + self.last_update = time.time() + self.lock = threading.Lock() + + def acquire(self) -> bool: + with self.lock: + now = time.time() + # Add new tokens based on time elapsed + new_tokens = (now - self.last_update) * self.rate_limit + self.tokens = min(self.rate_limit, self.tokens + new_tokens) + self.last_update = now + + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + +class AsyncTokenBucket: + def __init__(self, rate_limit: float): + self.rate_limit = rate_limit + self.capacity = rate_limit + self.tokens = rate_limit + self.last_update = time.monotonic() + self._lock = asyncio.Lock() + + def _refill(self): + now = time.monotonic() + elapsed = now - self.last_update + if elapsed <= 0: + return + self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_limit) + self.last_update = now + + async def acquire(self): + while True: + async with self._lock: + self._refill() + if self.tokens >= 1: + self.tokens -= 1 + return + missing = 1 - self.tokens + wait_time = missing / self.rate_limit + await asyncio.sleep(wait_time) + + +class FairAsyncTokenBucket: + def __init__(self, rate_limit: float, capacity: Optional[float] = None): + """ + rate_limit: 每秒生成多少个 token + capacity: 桶容量(最大可累积多少 token),默认和 rate_limit 一样 + """ + self.rate_limit = float(rate_limit) + self.capacity = float(capacity) if capacity is not None else float(rate_limit) + + self.tokens = self.capacity + self.last_update = time.monotonic() + + self._lock = asyncio.Lock() + self._waiters: Deque[asyncio.Future] = deque() + self._drainer_running = False # 是否已有后台协程在发 token + + # ---------- 内部工具方法 ---------- + + def _refill_unlocked(self) -> None: + """ + 在不持锁的前提下不要调用。 + 根据时间流逝计算当前 token 数。 + """ + now = time.monotonic() + elapsed = now - self.last_update + if elapsed <= 0: + return + self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_limit) + self.last_update = now + + async def _drain_waiters(self) -> None: + """ + 后台协程:按 FIFO 顺序给排队的协程发 token。 + - 没 token 时,就 sleep 到下一个 token 产生的时间点。 + - 有 token 且有排队,就唤醒队头的一个,再继续循环。 + """ + try: + while True: + fut_to_wake: Optional[asyncio.Future] = None + sleep_time: Optional[float] = None + + async with self._lock: + self._refill_unlocked() + + # 队列空了,没什么好做的了,退出 drainer + if not self._waiters: + self._drainer_running = False + return + + if self.tokens >= 1: + # 有 token,按 FIFO 唤醒一个排队的协程 + self.tokens -= 1 + fut_to_wake = self._waiters.popleft() + sleep_time = 0.0 + else: + # 没 token,算一下距离下一个 token 的时间 + missing = 1.0 - self.tokens # 还差多少 token 才能发下一枚 + sleep_time = max(0.0, missing / self.rate_limit) + + # 出锁之后再唤醒,避免在锁里执行用户代码 / 回调 + if fut_to_wake is not None and not fut_to_wake.done(): + fut_to_wake.set_result(None) + + # 如果刚刚唤醒了一个协程,立刻回到循环,看是否还能继续发 + if sleep_time == 0.0: + continue + + # 没 token,就等到有 token 再继续 + await asyncio.sleep(sleep_time) + finally: + # 兜底,避免异常时 drainer_running 一直是 True 导致无法重启 + async with self._lock: + self._drainer_running = False + + # ---------- 对外接口 ---------- + + async def acquire(self) -> None: + """ + 获取一个 token(公平:排队 FIFO) + """ + loop = asyncio.get_running_loop() + + # 先尝试直接拿 token(快速路径) + async with self._lock: + self._refill_unlocked() + + # 如果有 token 且没有历史排队的协程,直接拿走返回 + if self.tokens >= 1 and not self._waiters: + self.tokens -= 1 + return + + # 否则需要排队 + fut = loop.create_future() + self._waiters.append(fut) + + # 启动 drainer(只要一个就够了) + if not self._drainer_running: + self._drainer_running = True + asyncio.create_task(self._drain_waiters()) + + # 等待被 drainer 唤醒,唤醒后说明自己拿到了 token + await fut + + +# --- 复用你原本的辅助工具 --- +_loop = None + + +def _get_event_loop(): + try: + event_loop = asyncio.get_event_loop() + except Exception: + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + if event_loop.is_running(): + global _loop + if _loop: + return _loop + from threading import Thread + + def _start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + event_loop = asyncio.new_event_loop() + Thread(target=_start_loop, args=(event_loop,), daemon=True).start() + _loop = event_loop + return event_loop + + +class AsyncMCPClient(AsyncActionMixin, BaseAction): + """ + Standard Lagent Action that wraps a SINGLE tool from an MCP Server. + + Robustness Fix: + Creates a new connection for every request and closes it immediately after. + This prevents connection leaks and 'ConnectTimeout' in high-concurrency RL environments. + """ + + + def __init__( + self, + server_type: ServerType, + rate_limit: float = None, + max_concurrency: int = None, + # 注意:这里的 name 主要用于 Lagent 注册,但工具的实际元数据来自 MCP Server + name: Optional[str] = None, + **server_params, + ): + self._is_toolkit = False + self.server_type = server_type + self.server_params = server_params + + # 并发控制组件 + self.rate_limiter = FairAsyncTokenBucket(rate_limit) if rate_limit is not None else None + self._sem = asyncio.Semaphore(max_concurrency) if max_concurrency is not None else nullcontext() + + # 1. 临时连接获取工具元数据 (Metadata) + # 必须在 __init__ 完成,因为 Lagent 需要 self.description + loop = _get_event_loop() + if loop.is_running(): + fut = asyncio.run_coroutine_threadsafe(self._fetch_tool_metadata(), loop) + tools = fut.result() + else: + tools = loop.run_until_complete(self._fetch_tool_metadata()) + + # Single Action 约束:一个 Action 实例对应一个 MCP 工具 + if len(tools) != 1: + logger.warning( + f"MCP Server returned {len(tools)} tools, but AsyncMCPAction is designed for a Single Action. " + f"Using the first one: {tools[0].name}" + ) + + self.tool_info = tools[0] + tool_name = self.tool_info.name + description = { + 'name': tool_name, + 'description': self.tool_info.description, + 'parameters': [ + {'name': k, 'type': v['type'].upper(), 'description': v.get('description', '')} + for k, v in self.tool_info.inputSchema['properties'].items() + ], + 'required': self.tool_info.inputSchema.get('required', []), + } + # 2. 初始化父类 BaseAction + super().__init__( + description=description, + parser=JsonParser, + ) + self._is_toolkit = False + + async def _connect(self, stack: AsyncExitStack): + """ + 内部辅助:建立连接并注册关闭回调。 + 所有网络资源都注册到 `stack` 中,确保自动释放。 + """ + from mcp import ClientSession, StdioServerParameters + + # --- Transport Layer --- + if self.server_type == "stdio": + from mcp.client.stdio import stdio_client + + logger.info( + f"Connecting to stdio MCP server with command: {self.server_params['command']} " + f"{self.server_params.get('args', [])}" + ) + client_kwargs = {"command": self.server_params["command"]} + for key in ["args", "env", "cwd"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + + server_params_obj = StdioServerParameters(**client_kwargs) + read, write = await stack.enter_async_context(stdio_client(server_params_obj)) + + elif self.server_type == "sse": + from mcp.client.sse import sse_client + + logger.info(f"Connecting to SSE MCP server at: {self.server_params['url']}") + + url = self.server_params["url"] + target_url = random.choice(url) if isinstance(url, list) else url + + client_kwargs = {"url": target_url} + for key in ["headers", "timeout", "sse_read_timeout"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + + read, write = await stack.enter_async_context(sse_client(**client_kwargs)) + + elif self.server_type == "http": + from mcp.client.streamable_http import streamablehttp_client + + logger.info(f"Connecting to StreamableHTTP MCP server at: {self.server_params['url']}") + + url = self.server_params["url"] + target_url = random.choice(url) if isinstance(url, list) else url + + client_kwargs = {"url": target_url} + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + + read, write, _ = await stack.enter_async_context(streamablehttp_client(**client_kwargs)) + + else: + raise ValueError(f"Unsupported server type: {self.server_type}") + + # --- Protocol Layer --- + session = await stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + return session + + async def _fetch_tool_metadata(self): + """在 init 阶段使用一次性连接获取工具定义""" + async with AsyncExitStack() as stack: + session = await self._connect(stack) + result = await session.list_tools() + return result.tools + + async def run(self, **kwargs) -> ActionReturn: + """ + Standard Lagent Action Entrypoint. + """ + fallback_args = kwargs.copy() + + try: + # 1. 并发/速率控制 + async with self._sem: + if self.rate_limiter is not None: + await self.rate_limiter.acquire() + + # 2. 执行逻辑 (Critical Resource Scope) + # 使用 AsyncExitStack 确保本次请求结束后,HTTP连接/进程管道被彻底关闭 + async with AsyncExitStack() as stack: + session = await self._connect(stack) + + # 调用 MCP 工具 + # 注意:Lagent 传入的是 kwargs 字典,MCP call_tool 正好接受字典 + outputs_obj = await session.call_tool(self.tool_info.name, kwargs) + + # 提取文本结果 + if outputs_obj.content and hasattr(outputs_obj.content[0], 'text'): + outputs = outputs_obj.content[0].text + else: + outputs = str(outputs_obj) + + except ParseError as exc: + return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR) + except Exception as exc: + # 记录详细堆栈以便调试 RL 过程中的错误 + logger.warning(f"MCP Action {self.name} failed: {exc}") + return ActionReturn(fallback_args, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR) + + # 3. 结果封装 + if isinstance(outputs, ActionReturn): + action_return = outputs + if not action_return.args: + action_return.args = kwargs + if not action_return.type: + action_return.type = self.name + else: + # 尝试使用 JsonParser 解析结果(如果 MCP 返回的是 JSON 字符串) + # 否则直接作为字符串返回 + try: + result = self._parser.parse_outputs(outputs) + except: + result = str(outputs) + + action_return = ActionReturn(fallback_args, type=self.name, result=result) + + return action_return + + + +class AsyncMCPClientSandbox(AsyncMCPClient): + """ + Sandbox MCP Action that wraps a SINGLE tool from an MCP Server. + + Maintains a persistent connection to a remote sandbox environment. + One instance = one persistent connection (no session pool). + Call ``connect()`` to establish the connection and ``close()`` to tear it down. + The connection is also lazily created on the first ``run()`` call. + """ + + + def __init__( + self, + server_type: ServerType, + rate_limit: float = None, + max_concurrency: int = None, + name: Optional[str] = None, + init_dir: Optional[str] = None, + **server_params, + ): + # 1. 初始化父类(会自动获取 metadata 并初始化 BaseAction) + super().__init__( + server_type=server_type, + rate_limit=rate_limit, + max_concurrency=max_concurrency, + name=name, + **server_params + ) + + # 2. 单连接状态 + self.init_dir = init_dir + self._session: object | None = None # MCP ClientSession + self._request_queue: asyncio.Queue | None = None + self._worker_task: asyncio.Task | None = None + self._lock = asyncio.Lock() + self._connected = False + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect(self): + """Establish the persistent MCP connection (idempotent).""" + async with self._lock: + if self._connected: + return self._session + + logger.info(f"Creating persistent MCP connection for {self.name}") + + # anyio 要求进入和退出 CancelScope 在同一个 Task, + # 所以用后台 worker Task 持有 AsyncExitStack 的生命周期。 + request_queue = asyncio.Queue() + response_queue: asyncio.Queue = asyncio.Queue() + + async def _connection_worker(): + try: + async with AsyncExitStack() as stack: + session = await self._connect(stack) + await response_queue.put(session) + # 等待关闭信号 + while True: + msg = await request_queue.get() + if msg == "close": + break + except Exception as e: + await response_queue.put(e) + + worker_task = asyncio.create_task(_connection_worker()) + + result = await response_queue.get() + if isinstance(result, Exception): + raise result + + self._session = result + self._request_queue = request_queue + self._worker_task = worker_task + self._connected = True + + # 初始化远程工作目录 + if self.init_dir: + await self._initialize_dir(self._session) + + return self._session + + async def _initialize_dir(self, session): + """将本地目录打包并同步到远程 Session""" + import tarfile + import io + import base64 + import os + + if not self.init_dir or not os.path.exists(self.init_dir): + logger.warning(f"Init dir {self.init_dir} not found or not set, skipping initialization.") + return + + logger.info(f"Initializing MCP session with directory {self.init_dir}") + + buf = io.BytesIO() + dir_name = os.path.basename(os.path.normpath(self.init_dir)) + with tarfile.open(fileobj=buf, mode='w:gz') as tar: + tar.add(self.init_dir, arcname=dir_name) + + encoded = base64.b64encode(buf.getvalue()).decode('utf-8') + init_cmd = f"echo '{encoded}' | base64 -d | tar -xz" + + try: + await session.call_tool(self.tool_info.name, {"command": init_cmd}) + logger.info("MCP session initialized successfully.") + except Exception as e: + logger.error(f"Failed to initialize MCP session: {e}") + + async def close(self): + """关闭连接,释放服务端资源""" + async with self._lock: + if not self._connected: + return + logger.info(f"Closing persistent MCP connection for {self.name}") + self._connected = False + session = self._session + self._session = None + + if self._request_queue is not None: + await self._request_queue.put("close") + if self._worker_task is not None: + try: + await self._worker_task + except Exception as e: + logger.warning(f"Error while closing MCP connection: {e}") + self._request_queue = None + self._worker_task = None + + # backward compat + async def close_session(self, session_id: str | None = None): + """Deprecated — use ``close()`` instead.""" + await self.close() + + # ------------------------------------------------------------------ + # run — no session_id + # ------------------------------------------------------------------ + + async def run(self, **kwargs) -> ActionReturn: + """ + Standard Lagent Action Entrypoint for stateful execution. + """ + fallback_args = kwargs.copy() + + try: + async with self._sem: + if self.rate_limiter is not None: + await self.rate_limiter.acquire() + + # 懒连接 + session = await self.connect() + + outputs_obj = await session.call_tool(self.tool_info.name, kwargs) + + if outputs_obj.content and hasattr(outputs_obj.content[0], 'text'): + outputs = outputs_obj.content[0].text + else: + outputs = str(outputs_obj) + + except ParseError as exc: + return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR) + except Exception as exc: + logger.warning(f"MCP Action {self.name} failed: {exc}") + return ActionReturn(fallback_args, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR) + + # 结果封装 + if isinstance(outputs, ActionReturn): + action_return = outputs + if not action_return.args: + action_return.args = kwargs + if not action_return.type: + action_return.type = self.name + else: + try: + result = self._parser.parse_outputs(outputs) + except: + result = str(outputs) + + action_return = ActionReturn(fallback_args, type=self.name, result=result) + + return action_return +if __name__ == '__main__': + import asyncio + import json + from lagent.agents.internclaw_agent import get_tool_prompt + + action = AsyncMCPClientSandbox("http", url='http://simple-shell.ailab.ailab.ai/mcp', init_dir="/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent/workspace/") + print(get_tool_prompt([action])) + + async def test(): + res = await action.run(command='ls -a') + home_path = json.loads(res.result[0]['content'])['cwd'] + print(res) + res = await action.run(command='echo hello') + print(res) + await action.close() + asyncio.run(test()) + action1 = AsyncMCPClientSandbox('http', url="http://hb-3d-scan-calc.ailab.ailab.ai") + \ No newline at end of file diff --git a/lagent/actions/sandbox_executor.py b/lagent/actions/sandbox_executor.py new file mode 100644 index 00000000..c0e76be2 --- /dev/null +++ b/lagent/actions/sandbox_executor.py @@ -0,0 +1,382 @@ +"""SandboxActionExecutor — drop-in replacement for ``AsyncActionExecutor`` +that routes action calls to an :class:`~lagent.serving.sandbox.daemon.ActionDaemon` +running inside a remote sandbox. + +Communication goes through the sandbox's bash execution channel:: + + SandboxActionExecutor.forward("shell", {"command": "ls"}) + → sandbox_client.execute('python -m lagent.serving.sandbox.daemon call ...') + → daemon inside sandbox executes ShellAction locally + → JSON result flows back through stdout + +Usage:: + + from lagent.serving.sandbox.executor import SandboxActionExecutor + + executor = SandboxActionExecutor( + sandbox_client=sandbox_client, # your SandboxClient instance + actions_config=[ + {"type": "lagent.actions.shell.ShellAction"}, + {"type": "lagent.actions.ipython_interpreter.AsyncIPythonInterpreter"}, + ], + ) + await executor.connect() + # Now use as AsyncActionExecutor + +The ``sandbox_client`` can be **sync or async**. If ``execute()`` is a +regular (sync) method it will be called via ``asyncio.to_thread`` so it +never blocks the event loop. The return value may be a ``str`` (raw +stdout) or a ``dict`` with a ``"stdout"`` key — both are handled. +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import logging +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Union + +from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction +from lagent.hooks import Hook, RemovableHandle +from lagent.schema import ( + ActionReturn, + ActionStatusCode, + ActionValidCode, + AgentMessage, + FunctionCall, +) +from lagent.utils import create_object + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _deserialize_action_return(text: str) -> ActionReturn: + """Deserialize a JSON string back into an ActionReturn.""" + data = json.loads(text) + if data.get("state") is not None: + data["state"] = ActionStatusCode(data["state"]) + if data.get("valid") is not None: + data["valid"] = ActionValidCode(data["valid"]) + return ActionReturn(**data) + + +class _ToolDescriptionStub: + """Minimal stub that quacks like a BaseAction for ``get_tool_prompt()``. + + ``create_object()`` returns non-dict inputs as-is, so these stubs + pass through correctly. + """ + + def __init__(self, desc: dict): + self._desc = desc + + @property + def name(self) -> str: + return self._desc["name"] + + @property + def is_toolkit(self) -> bool: + return False + + @property + def description(self) -> dict: + return self._desc + + +class _ToolDescriptionProxy: + """Makes ``executor.actions`` behave like ``Dict[str, BaseAction]`` + for code that iterates ``.values()`` and reads ``.name`` / ``.description``. + """ + + def __init__(self, descriptions: Dict[str, dict]): + self._descriptions = descriptions + + def __contains__(self, name: str) -> bool: + return name in self._descriptions + + def values(self): + return [_ToolDescriptionStub(d) for d in self._descriptions.values()] + + def keys(self): + return list(self._descriptions.keys()) + + def items(self): + return [(k, _ToolDescriptionStub(v)) for k, v in self._descriptions.items()] + + def __getitem__(self, key: str): + return _ToolDescriptionStub(self._descriptions[key]) + + +# --------------------------------------------------------------------------- +# SandboxActionExecutor +# --------------------------------------------------------------------------- + + +class SandboxActionExecutor: + """Drop-in replacement for ``AsyncActionExecutor`` that routes action + calls to an :class:`ActionDaemon` running inside a sandbox. + + Parameters + ---------- + sandbox_client + Object with an ``execute(command, ...)`` method that runs bash + inside the sandbox. May be sync or async. The return value can + be a plain ``str`` (stdout) **or** a ``dict`` with a ``"stdout"`` + key (like the real ``SandboxClient.execute``). + actions_config : list of dict + Action configurations (same format as ActionDaemon's config file). + sock_path : str + Unix socket path inside the sandbox. + cwd : str + Working directory for bash commands inside the sandbox. + daemon_module : str + Python module path for the daemon entry point. + """ + + def __init__( + self, + sandbox_client, + actions_config: List[Dict], + sock_path: str = "/tmp/lagent_action.sock", + cwd: str = "/root", + daemon_module: str = "lagent.serving.sandbox.daemon", + invalid_action=dict(type=InvalidAction), + no_action=dict(type=NoAction), + finish_action=dict(type=FinishAction), + hooks: Optional[List[Dict]] = None, + ): + self.sandbox_client = sandbox_client + self.actions_config = actions_config + self.sock_path = sock_path + self.cwd = cwd + self.daemon_module = daemon_module + + # Built-in actions handled locally + self.invalid_action = create_object(invalid_action) + self.no_action = create_object(no_action) + self.finish_action = create_object(finish_action) + + # Hooks + self._hooks: Dict[int, Hook] = OrderedDict() + if hooks: + for hook in hooks: + hook = create_object(hook) + self.register_hook(hook) + + # Connection state + self._connected = False + self._lock = asyncio.Lock() + + # Tool descriptions (populated on connect) + self._tool_descriptions: Dict[str, dict] = {} + + # -- sandbox communication helpers -- + + async def _exec(self, command: str, **kwargs) -> str: + """Execute a command via the sandbox client. + + Handles both sync and async clients, and normalises the return + value to a plain stdout string. + """ + kwargs.setdefault("cwd", self.cwd) + execute_fn = self.sandbox_client.execute + if inspect.iscoroutinefunction(execute_fn): + result = await execute_fn(command, **kwargs) + else: + result = await asyncio.to_thread(execute_fn, command, **kwargs) + + # Normalise: dict with "stdout" → str + if isinstance(result, dict): + stdout = result.get("stdout", "") + if not stdout.strip() and result.get("stderr", "").strip(): + raise RuntimeError(f"Command stderr: {result['stderr'][:500]}") + return stdout + return result + + # -- connection lifecycle -- + + async def connect(self) -> None: + """Start the daemon inside the sandbox (idempotent) and fetch tool list.""" + async with self._lock: + if self._connected: + return + + # 1. Write actions config to sandbox + config_json = json.dumps(self.actions_config, ensure_ascii=False) + escaped_config = config_json.replace("'", "'\\''") + await self._exec( + f"echo '{escaped_config}' > /tmp/lagent_actions_config.json" + ) + + # 2. Check if daemon already running, start if not + check = await self._exec( + f"pgrep -f '[l]agent.serving.sandbox.daemon.*--sock {self.sock_path}'" + f" > /dev/null 2>&1 && echo 'running' || echo 'stopped'" + ) + if "stopped" in check: + await self._exec( + f"nohup python -m {self.daemon_module} start " + f"--sock {self.sock_path} " + f"--actions-config /tmp/lagent_actions_config.json " + f"> /tmp/lagent_daemon.log 2>&1 &" + ) + + # 3. Wait for socket to be ready + for _ in range(30): + try: + output = await self._exec( + f"test -S {self.sock_path} && echo 'ready' || echo 'waiting'" + ) + if "ready" in output: + break + except Exception: + pass + await asyncio.sleep(0.5) + else: + raise TimeoutError( + f"ActionDaemon did not start within 15s. " + f"Check /tmp/lagent_daemon.log inside sandbox." + ) + + # 4. Fetch tool list from daemon + result = await self._daemon_call({"cmd": "list_tools"}) + tools = result.get("tools", []) + self._tool_descriptions = {t["name"]: t for t in tools} + self._connected = True + logger.info( + "SandboxActionExecutor connected, %d tools: %s", + len(tools), + list(self._tool_descriptions.keys()), + ) + + async def close(self) -> None: + """Shut down the daemon.""" + if not self._connected: + return + try: + await self._daemon_call({"cmd": "shutdown"}) + except Exception: + pass + self._connected = False + self._tool_descriptions.clear() + logger.info("SandboxActionExecutor closed") + + # -- daemon communication -- + + async def _daemon_call(self, request: dict, timeout_sec: int = 300) -> dict: + """Send a JSON request to the daemon via bash and parse response.""" + request_json = json.dumps(request, ensure_ascii=False) + escaped = request_json.replace("'", "'\\''") + output = await self._exec( + f"python -m {self.daemon_module} call " + f"--sock {self.sock_path} " + f"'{escaped}'", + timeout_sec=timeout_sec, + ) + if not output.strip(): + raise RuntimeError("Daemon returned empty response (may have crashed or timed out)") + return json.loads(output.strip()) + + # -- AsyncActionExecutor-compatible interface -- + + @property + def actions(self) -> _ToolDescriptionProxy: + return _ToolDescriptionProxy(self._tool_descriptions) + + def __contains__(self, name: str) -> bool: + return name in self._tool_descriptions + + def keys(self) -> List[str]: + return list(self._tool_descriptions.keys()) + + def description(self) -> List[Dict]: + return list(self._tool_descriptions.values()) + + async def forward(self, name: str, parameters: dict, **kwargs) -> ActionReturn: + action_name = name.split(".")[0] if "." in name else name + + # Built-in actions stay local + if action_name not in self: + if name == self.no_action.name: + return self.no_action(parameters) + elif name == self.finish_action.name: + return self.finish_action(parameters) + else: + return self.invalid_action(parameters) + + # Ensure daemon is running + if not self._connected: + await self.connect() + + # Route to daemon + try: + result = await self._daemon_call( + {"name": name, "parameters": parameters} + ) + if "error" in result: + return ActionReturn( + args=parameters, + type=name, + errmsg=result["error"], + state=ActionStatusCode.API_ERROR, + ) + action_return = _deserialize_action_return(json.dumps(result)) + action_return.valid = ActionValidCode.OPEN + return action_return + except Exception as exc: + logger.warning("SandboxActionExecutor: %s failed: %s", name, exc) + return ActionReturn( + args=parameters, + type=name, + errmsg=str(exc), + state=ActionStatusCode.API_ERROR, + ) + + async def __call__(self, message: AgentMessage, **kwargs) -> AgentMessage: + for hook in self._hooks.values(): + if inspect.iscoroutinefunction(hook.before_action): + result = await hook.before_action(self, message) + else: + result = hook.before_action(self, message) + if result: + message = result + + assert isinstance(message.content, FunctionCall) or ( + isinstance(message.content, dict) + and "name" in message.content + and "parameters" in message.content + ) + if isinstance(message.content, dict): + name = message.content.get("name") + parameters = message.content.get("parameters") + else: + name = message.content.name + parameters = message.content.parameters + + response_message = await self.forward(name=name, parameters=parameters, **kwargs) + if not isinstance(response_message, AgentMessage): + response_message = AgentMessage( + sender=self.__class__.__name__, + content=response_message, + ) + + for hook in self._hooks.values(): + if inspect.iscoroutinefunction(hook.after_action): + result = await hook.after_action(self, response_message) + else: + result = hook.after_action(self, response_message) + if result: + response_message = result + return response_message + + def register_hook(self, hook): + handle = RemovableHandle(self._hooks) + self._hooks[handle.id] = hook + return handle diff --git a/lagent/actions/save_memory.py b/lagent/actions/save_memory.py new file mode 100644 index 00000000..ebb1b7ce --- /dev/null +++ b/lagent/actions/save_memory.py @@ -0,0 +1,127 @@ +"""SaveMemoryAction — OpenClaw memory write action. + +Writes to MEMORY.md + HISTORY.md in the workspace memory directory. +Operates independently from ``OpenClawMemoryProvider`` (which reads). +Both are initialized with the same workspace path by the caller. + +Usage:: + + provider = OpenClawMemoryProvider(workspace_path) + save_action = SaveMemoryAction(workspace_path) + + env = AsyncEnvAgent( + actions=[save_action, ...], + long_term_memory=provider, + ) +""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Optional, Type + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + + +class SaveMemoryAction(BaseAction): + """Write consolidated memory to OpenClaw's MEMORY.md + HISTORY.md. + + Parameters + ---------- + workspace : Path + Workspace directory containing the ``memory/`` subdirectory. + """ + + def __init__( + self, + workspace: Path, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + memory_dir = Path(workspace) / "memory" + memory_dir.mkdir(parents=True, exist_ok=True) + self._memory_file = memory_dir / "MEMORY.md" + self._history_file = memory_dir / "HISTORY.md" + + async def _save(self, history_entry: str, memory_update: str) -> None: + if memory_update: + await asyncio.to_thread( + self._memory_file.write_text, memory_update, encoding="utf-8" + ) + if history_entry: + def _append(): + with open(self._history_file, "a", encoding="utf-8") as f: + f.write(history_entry.rstrip() + "\n\n") + await asyncio.to_thread(_append) + + @tool_api + def run( + self, + history_entry: str = '', + memory_update: str = '', + ) -> ActionReturn: + """Save memory consolidation result to persistent storage. + + Args: + history_entry: A paragraph (2-5 sentences) summarizing key + events/decisions/topics. Start with [YYYY-MM-DD HH:MM]. + Include detail useful for grep search. + memory_update: Full updated long-term memory as markdown. + Include all existing facts plus new ones. Return + unchanged if nothing new. + + Returns: + ActionReturn with confirmation message. + """ + try: + asyncio.run(self._save(history_entry, memory_update)) + return ActionReturn( + type=self.name, + result=[dict(type="text", content="Memory saved successfully.")], + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to save memory: {exc}", + state=ActionStatusCode.API_ERROR, + ) + + +class AsyncSaveMemoryAction(AsyncActionMixin, SaveMemoryAction): + """Async version of SaveMemoryAction.""" + + @tool_api + async def run( + self, + history_entry: str = '', + memory_update: str = '', + ) -> ActionReturn: + """Save memory consolidation result to persistent storage. + + Args: + history_entry: A paragraph (2-5 sentences) summarizing key + events/decisions/topics. Start with [YYYY-MM-DD HH:MM]. + Include detail useful for grep search. + memory_update: Full updated long-term memory as markdown. + Include all existing facts plus new ones. Return + unchanged if nothing new. + + Returns: + ActionReturn with confirmation message. + """ + try: + await self._save(history_entry, memory_update) + return ActionReturn( + type=self.name, + result=[dict(type="text", content="Memory saved successfully.")], + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to save memory: {exc}", + state=ActionStatusCode.API_ERROR, + ) diff --git a/lagent/actions/send_message.py b/lagent/actions/send_message.py new file mode 100644 index 00000000..6e447eef --- /dev/null +++ b/lagent/actions/send_message.py @@ -0,0 +1,157 @@ +"""SendMessageAction -- agent-facing toolkit for inter-agent messaging. + +Wraps :class:`~lagent.services.mailbox.Mailbox` to expose message +sending as ``@tool_api`` methods. + +Usage:: + + mailbox = Mailbox() + action = SendMessageAction(mailbox, agent_name="coder") + executor = AsyncActionExecutor(actions=[action, ...]) +""" + +from __future__ import annotations + +import logging +from typing import Annotated, Optional, Type + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode, AgentMessage +from lagent.services.mailbox import Mailbox + +logger = logging.getLogger("lagent.actions.send_message") + + +class SendMessageAction(BaseAction): + """Send and receive messages to/from other agents. + + The ``allowed_receivers`` parameter controls the communication + topology: + + * ``None`` — fully distributed, can message any agent. + * ``["lead"]`` — centralised, can only message the lead. + * ``["lead", "tester"]`` — partial mesh. + """ + + def __init__( + self, + mailbox: Mailbox, + agent_name: str, + allowed_receivers: list[str] | None = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + self._mailbox = mailbox + self._agent_name = agent_name + self._allowed = allowed_receivers + + # Ensure this agent is registered in the mailbox + self._mailbox.register(agent_name) + + @tool_api + def send( + self, + to: Annotated[ + str, + "Target agent name, or '*' to broadcast to all agents", + ], + content: Annotated[ + str, + "The message content to send", + ], + ) -> ActionReturn: + """Send a message to another agent. + + Use this to ask questions, share findings, request help, or + coordinate work with teammates. The recipient will see the + message in their next execution cycle. + + Args: + to: Recipient agent name or '*' for broadcast. + content: Message text. + + Returns: + ActionReturn confirming delivery. + """ + # Topology enforcement + if ( + self._allowed is not None + and to != "*" + and to not in self._allowed + ): + return ActionReturn( + type=self.name, + errmsg=( + f"Cannot send to {to!r}. " + f"Allowed recipients: {self._allowed}" + ), + state=ActionStatusCode.API_ERROR, + ) + + try: + msg = AgentMessage( + sender=self._agent_name, + receiver=to, + content=content, + ) + self._mailbox.send(msg) + + if to == "*": + targets = [ + n for n in self._mailbox.agents + if n != self._agent_name + ] + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=f"Broadcast sent to {len(targets)} agents.", + )], + ) + + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=f"Message sent to @{to}.", + )], + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to send message: {exc}", + state=ActionStatusCode.API_ERROR, + ) + + @tool_api + def check_inbox(self) -> ActionReturn: + """Check for pending messages without consuming them. + + Use this to see if any teammate has sent you a message. + Messages will be fully delivered through the environment + context automatically; this tool is for a quick preview. + + Returns: + ActionReturn listing pending messages or 'no messages'. + """ + msgs = self._mailbox.peek(self._agent_name) + if not msgs: + return ActionReturn( + type=self.name, + result=[dict(type="text", content="No new messages.")], + ) + lines = [] + for m in msgs: + lines.append(f"- From @{m.sender}: {m.content}") + header = f"{len(msgs)} pending message(s):\n" + return ActionReturn( + type=self.name, + result=[dict(type="text", content=header + "\n".join(lines))], + ) + + +class AsyncSendMessageAction(AsyncActionMixin, SendMessageAction): + """Async version of :class:`SendMessageAction`.""" + pass diff --git a/lagent/actions/shell.py b/lagent/actions/shell.py new file mode 100644 index 00000000..c57c6935 --- /dev/null +++ b/lagent/actions/shell.py @@ -0,0 +1,164 @@ +import asyncio +import os +import re +from pathlib import Path +from typing import Any, List, Optional, Type + +from asyncer import asyncify + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + +class ShellAction(AsyncActionMixin, BaseAction): + """Tool to execute shell commands.""" + + def __init__( + self, + timeout: int = 60, + working_dir: Optional[str] = None, + deny_patterns: Optional[List[str]] = None, + allow_patterns: Optional[List[str]] = None, + restrict_to_workspace: bool = False, + path_append: str = "", + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ): + super().__init__(description, parser) + self.timeout = timeout + self.working_dir = working_dir + self.deny_patterns = deny_patterns or [ + r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr + r"\bdel\s+/[fq]\b", # del /f, del /q + r"\brmdir\s+/s\b", # rmdir /s + r"(?:^|[;&|]\s*)format\b", # format (as standalone command only) + r"\b(mkfs|diskpart)\b", # disk operations + r"\bdd\s+if=", # dd + r">\s*/dev/sd", # write to disk + r"\b(shutdown|reboot|poweroff)\b", # system power + r":\(\)\s*\{.*\};\s*:", # fork bomb + ] + self.allow_patterns = allow_patterns or [] + self.restrict_to_workspace = restrict_to_workspace + self.path_append = path_append + + def _guard_command(self, command: str, cwd: str) -> Optional[str]: + """Best-effort safety guard for potentially destructive commands.""" + cmd = command.strip() + lower = cmd.lower() + + for pattern in self.deny_patterns: + if re.search(pattern, lower): + return "Error: Command blocked by safety guard (dangerous pattern detected)" + + if self.allow_patterns: + if not any(re.search(p, lower) for p in self.allow_patterns): + return "Error: Command blocked by safety guard (not in allowlist)" + + if self.restrict_to_workspace: + if "..\\" in cmd or "../" in cmd: + return "Error: Command blocked by safety guard (path traversal detected)" + + cwd_path = Path(cwd).resolve() + # A rough implementation of checking absolute paths + paths = re.findall(r'(/[^\s]+)', cmd) + for raw in paths: + try: + p = Path(raw.strip()).resolve() + except Exception: + continue + if p.is_absolute() and cwd_path not in p.parents and p != cwd_path: + return "Error: Command blocked by safety guard (path outside working dir)" + return None + + @tool_api + @asyncify + def run(self, command: str, working_dir: Optional[str] = None) -> ActionReturn: + """Execute a shell command and return its output. Use with caution. + + Args: + command (str): The shell command to execute. + working_dir (str, optional): Optional working directory for the command. + """ + cwd = working_dir or self.working_dir or os.getcwd() + guard_error = self._guard_command(command, cwd) + if guard_error: + return ActionReturn( + type=self.name, + errmsg=guard_error, + state=ActionStatusCode.API_ERROR + ) + + env = os.environ.copy() + if self.path_append: + env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append + + # We need an event loop, since run needs to be synchronously resolvable but internally we use async logic. + # Since it's decorated with @asyncify, we can actually write sync code, but since the original used + # asyncio.create_subprocess_shell, we'll wrap it. + + try: + result = asyncio.run(self._execute_async(command, cwd, env)) + return ActionReturn( + type=self.name, + result=[dict(type='text', content=result)], + state=ActionStatusCode.SUCCESS + ) + except Exception as e: + return ActionReturn( + type=self.name, + errmsg=f"Error executing command: {str(e)}", + state=ActionStatusCode.API_ERROR + ) + + async def _execute_async(self, command: str, cwd: str, env: dict) -> str: + process = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=self.timeout + ) + except asyncio.TimeoutError: + process.kill() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + pass + raise TimeoutError(f"Command timed out after {self.timeout} seconds") + + output_parts = [] + + if stdout: + output_parts.append(stdout.decode("utf-8", errors="replace")) + + if stderr: + stderr_text = stderr.decode("utf-8", errors="replace") + if stderr_text.strip(): + output_parts.append(f"STDERR:\n{stderr_text}") + + if process.returncode != 0: + output_parts.append(f"\nExit code: {process.returncode}") + + result = "\n".join(output_parts) if output_parts else "(no output)" + + # Truncate very long output + max_len = 10000 + if len(result) > max_len: + result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)" + + return result + + +if __name__ == "__main__": + import asyncio + action = ShellAction(timeout=10, restrict_to_workspace=True) + async def test(): + result = await action.run("echo Hello World && ls -la && sleep 2 && echo Done") + asyncio.run(test()) \ No newline at end of file diff --git a/lagent/actions/subagent.py b/lagent/actions/subagent.py new file mode 100644 index 00000000..b14bc4bf --- /dev/null +++ b/lagent/actions/subagent.py @@ -0,0 +1,353 @@ +"""AgentAction — agent-facing toolkit for managing sub-agents. + +Short-lived Action created per-request by the Dispatcher. Wraps the +long-lived :class:`AgentService` singleton, analogous to how +:class:`CronAction` wraps :class:`CronService`. + +The agent can spawn, list, query, and resume sub-agents through this +toolkit. Each ``@tool_api`` method maps to an ``AgentService`` operation. + +Usage:: + + agent_action = AgentAction( + agent_service=app.agent_service, + ) + executor = AsyncActionExecutor(actions=[agent_action, ...]) +""" + +from typing import Annotated, Optional, Type + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode + +from ..services.agent import AgentService, AgentStatus + + +class AsyncAgentAction(AsyncActionMixin, BaseAction): + """Manage sub-agents: spawn, list, query, and resume. + + The main agent can delegate tasks to specialised sub-agents through + this toolkit. Each sub-agent runs independently with its own tools + and context, and returns a result when finished. + + Parameters + ---------- + agent_service : AgentService + The service managing agent lifecycle. + parent_actions : list, optional + Action instances available to the parent agent. When a + dynamic sub-agent is spawned, it inherits (a subset of) these. + If ``None``, dynamic agents get no tools. + """ + + def __init__( + self, + agent_service: AgentService, + parent_actions: list | None = None, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + self._service = agent_service + self._parent_actions = parent_actions or [] + + # ── tool APIs ───────────────────────────────────────────────── + + @tool_api + async def spawn( + self, + task: Annotated[ + str, + "A clear, self-contained description of the task for the " + "sub-agent. Include all necessary context — the sub-agent " + "has no access to the current conversation history.", + ], + agent_type: Annotated[ + Optional[str], + "The type of agent to spawn. Must be one of the available " + "agent types. Omit to create a dynamic agent with " + "system_prompt.", + ] = None, + system_prompt: Annotated[ + Optional[str], + "System prompt for a dynamically created agent. Only used " + "when agent_type is not specified.", + ] = None, + tools: Annotated[ + Optional[str], + "Comma-separated tool names to give the sub-agent " + "(e.g. 'Shell,Read,Grep'). Only used for dynamic agents. " + "Omit to inherit the parent's full tool set.", + ] = None, + label: Annotated[ + str, + "A short human-readable label for the task. " + "Defaults to the first 40 characters of the task.", + ] = "", + mode: Annotated[ + str, + "Execution mode: 'sync' blocks until done (default), " + "'async' runs in the background.", + ] = "sync", + ) -> ActionReturn: + """Spawn a sub-agent to execute a task. + + Two ways to create the sub-agent: + + 1. Provide ``agent_type`` to use a pre-registered agent template + (with its own prompt, tools, and configuration). + 2. Omit ``agent_type`` and optionally provide ``system_prompt`` + and ``tools`` to dynamically create a lightweight agent. + + Args: + task: Self-contained task description. + agent_type: Type of pre-registered agent to spawn. + system_prompt: Custom system prompt for dynamic agents. + tools: Comma-separated tool name whitelist for dynamic agents. + label: Optional short label. + mode: 'sync' (default) or 'async'. + + Returns: + ActionReturn with the result or acknowledgement. + """ + # Resolve tool instances from names + resolved_tools = None + if tools and not agent_type: + allowed = {s.strip() for s in tools.split(",") if s.strip()} + resolved_tools = [ + a for a in self._parent_actions + if type(a).__name__ in allowed + ] + + # Build kwargs for spec.acreate() + spec_kwargs: dict = {} + if self._default_llm is not None: + spec_kwargs["llm"] = self._default_llm + if resolved_tools is not None: + spec_kwargs["actions"] = resolved_tools + elif self._parent_actions: + spec_kwargs["actions"] = list(self._parent_actions) + if system_prompt: + spec_kwargs["system_prompt"] = system_prompt + + try: + entry = await self._service.spawn( + task=task, + agent_type=agent_type or "default", + label=label, + mode=mode, + **spec_kwargs, + ) + except ValueError as exc: + return ActionReturn( + type=self.name, + errmsg=str(exc), + state=ActionStatusCode.API_ERROR, + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to spawn agent: {exc}", + state=ActionStatusCode.API_ERROR, + ) + + if mode == "sync": + if entry.status == AgentStatus.FAILED: + return ActionReturn( + type=self.name, + errmsg=f"Agent failed: {entry.error}", + state=ActionStatusCode.API_ERROR, + ) + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=( + f"**Agent `{entry.agent_type}` (id: `{entry.id}`) completed.**\n\n" + f"{entry.result or '(no output)'}" + ), + )], + ) + else: + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=( + f"✅ Agent spawned in background:\n" + f" - **type**: `{entry.agent_type}`\n" + f" - **id**: `{entry.id}`\n" + f" - **label**: {entry.label}\n" + f" - **status**: {entry.status}\n\n" + f"Use `list_agents` to check progress, or " + f"`query_agent` with the ID to get details." + ), + )], + ) + + @tool_api + async def list_agents( + self, + status: Annotated[ + Optional[str], + "Filter by status: 'running', 'stopped', 'failed', or 'pending'. " + "Omit to list all.", + ] = None, + ) -> ActionReturn: + """List all managed sub-agents and their status. + + Args: + status: Optional status filter. + + Returns: + ActionReturn with a summary of all agents. + """ + entries = self._service.list(status=status) + if not entries: + msg = "No sub-agents" + if status: + msg += f" with status '{status}'" + msg += " found." + return ActionReturn( + type=self.name, + result=[dict(type="text", content=msg)], + ) + + lines = [] + for e in entries: + status_icon = { + AgentStatus.PENDING: "⏳", + AgentStatus.RUNNING: "🔄", + AgentStatus.STOPPED: "✅", + AgentStatus.FAILED: "❌", + }.get(e.status, "❓") + lines.append( + f"- {status_icon} **{e.label}** (`{e.id}`) — " + f"type: `{e.agent_type}`, status: {e.status}" + ) + return ActionReturn( + type=self.name, + result=[dict(type="text", content="\n".join(lines))], + ) + + @tool_api + async def query_agent( + self, + agent_id: Annotated[str, "The ID of the agent to query"], + ) -> ActionReturn: + """Get detailed information about a specific sub-agent. + + Args: + agent_id: The agent's ID. + + Returns: + ActionReturn with the agent's full details. + """ + entry = self._service.get(agent_id) + if entry is None: + return ActionReturn( + type=self.name, + errmsg=f"Agent `{agent_id}` not found.", + state=ActionStatusCode.API_ERROR, + ) + + info = ( + f"**Agent `{entry.id}`**\n" + f"- **type**: `{entry.agent_type}`\n" + f"- **label**: {entry.label}\n" + f"- **status**: {entry.status}\n" + f"- **task**: {entry.task}\n" + ) + if entry.result: + info += f"\n**Result:**\n{entry.result}" + if entry.error: + info += f"\n**Error:** {entry.error}" + + return ActionReturn( + type=self.name, + result=[dict(type="text", content=info)], + ) + + @tool_api + async def resume_agent( + self, + agent_id: Annotated[str, "The ID of the stopped agent to resume"], + message: Annotated[ + str, + "The new message to send to the agent. The agent will " + "resume with its full previous context plus this message.", + ], + ) -> ActionReturn: + """Resume a stopped sub-agent with a new message. + + The agent picks up where it left off, retaining its previous + conversation history. + + Args: + agent_id: ID of the stopped agent. + message: New message to send. + + Returns: + ActionReturn with the agent's new result. + """ + try: + entry = await self._service.resume(agent_id, message) + except ValueError as exc: + return ActionReturn( + type=self.name, + errmsg=str(exc), + state=ActionStatusCode.API_ERROR, + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to resume agent: {exc}", + state=ActionStatusCode.API_ERROR, + ) + + if entry.status == AgentStatus.FAILED: + return ActionReturn( + type=self.name, + errmsg=f"Agent failed after resume: {entry.error}", + state=ActionStatusCode.API_ERROR, + ) + + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=( + f"**Agent `{entry.id}` resumed and completed.**\n\n" + f"{entry.result or '(no output)'}" + ), + )], + ) + + @tool_api + async def stop_agent( + self, + agent_id: Annotated[str, "The ID of the running agent to stop"], + ) -> ActionReturn: + """Stop a running background sub-agent. + + Args: + agent_id: ID of the agent to stop. + + Returns: + ActionReturn confirming the stop. + """ + stopped = await self._service.stop(agent_id) + if stopped: + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=f"✅ Agent `{agent_id}` stopped.", + )], + ) + return ActionReturn( + type=self.name, + errmsg=f"Agent `{agent_id}` not found or not running.", + state=ActionStatusCode.API_ERROR, + ) diff --git a/lagent/actions/task.py b/lagent/actions/task.py new file mode 100644 index 00000000..4976fb25 --- /dev/null +++ b/lagent/actions/task.py @@ -0,0 +1,400 @@ +"""TaskAction -- agent-facing toolkit for managing a task board. + +Wraps :class:`~lagent.services.task.TaskBoard` to expose task CRUD as +``@tool_api`` methods. Follows the same pattern as +:class:`~lagent.actions.cron.CronAction`. + +Usage:: + + board = TaskBoard() + task_action = TaskAction(board) + executor = AsyncActionExecutor(actions=[task_action, ...]) +""" + +from __future__ import annotations + +import json +import logging +from typing import Annotated, Optional, Type + +from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser +from lagent.schema import ActionReturn, ActionStatusCode +from lagent.services.task import TaskBoard + +logger = logging.getLogger("lagent.actions.task") + + +def _parse_id_list(value: str | None) -> list[str]: + """Parse a comma-separated string of task IDs into a list.""" + if not value: + return [] + return [s.strip() for s in value.split(",") if s.strip()] + + +def _parse_metadata(value: str | None) -> dict | None: + """Parse a JSON string into a dict, or return None.""" + if not value: + return None + try: + result = json.loads(value) + if isinstance(result, dict): + return result + return None + except (json.JSONDecodeError, TypeError): + return None + + +def _format_task_detail(task) -> str: + """Format a Task into a detailed multi-line string.""" + lines = [ + f"Task #{task.id}", + f" Subject: {task.subject}", + f" Status: {task.status}", + f" Description: {task.description}", + ] + if task.active_form: + lines.append(f" Active Form: {task.active_form}") + if task.owner: + lines.append(f" Owner: {task.owner}") + if task.blocks: + refs = ", ".join(f"#{bid}" for bid in task.blocks) + lines.append(f" Blocks: {refs}") + if task.blocked_by: + refs = ", ".join(f"#{bid}" for bid in task.blocked_by) + lines.append(f" Blocked By: {refs}") + if task.metadata: + lines.append(f" Metadata: {json.dumps(task.metadata, ensure_ascii=False)}") + return "\n".join(lines) + + +class TaskAction(BaseAction): + """Manage a structured task board with dependency tracking. + + Use this toolkit to create, update, list, and inspect tasks. + Tasks support status tracking (pending / in_progress / completed), + ownership, and dependency relationships (blocks / blocked_by). + """ + + def __init__( + self, + task_board: TaskBoard, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + ) -> None: + super().__init__(description, parser) + self._board = task_board + + @tool_api + def create( + self, + subject: Annotated[ + str, + "A brief, actionable title in imperative form " + "(e.g. 'Fix authentication bug in login flow')", + ], + description: Annotated[ + str, + "Full details of what needs to be done", + ], + active_form: Annotated[ + Optional[str], + "Present continuous form shown when in_progress " + "(e.g. 'Fixing authentication bug'). " + "If omitted, the subject is used instead.", + ] = None, + blocked_by: Annotated[ + Optional[str], + "Comma-separated task IDs that must complete before this " + "task can start (e.g. '1,3')", + ] = None, + metadata: Annotated[ + Optional[str], + "JSON string of arbitrary key-value metadata", + ] = None, + ) -> ActionReturn: + """Create a new task on the board. + + Use this tool to break down complex work into trackable steps. + Create tasks when work requires 3 or more distinct steps, when + the user provides multiple things to do, or after receiving new + instructions that should be captured immediately. + + Do NOT create tasks for single trivial actions or purely + conversational/informational exchanges. + + All tasks start with status ``pending``. After creating tasks, + use ``update`` to set dependencies if needed, and check + ``list`` first to avoid duplicates. + + Args: + subject: Brief imperative title. + description: Full task details. + active_form: Present continuous label for progress display. + blocked_by: Comma-separated prerequisite task IDs. + metadata: JSON metadata string. + + Returns: + ActionReturn with the created task summary. + """ + try: + dep_ids = _parse_id_list(blocked_by) + meta = _parse_metadata(metadata) + + task = self._board.create( + subject=subject, + description=description, + active_form=active_form, + blocked_by=dep_ids, + metadata=meta, + ) + + content = f"Created task #{task.id}: {task.subject}" + if dep_ids: + refs = ", ".join(f"#{d}" for d in dep_ids) + content += f" (blocked by {refs})" + + return ActionReturn( + type=self.name, + result=[dict(type="text", content=content)], + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to create task: {exc}", + state=ActionStatusCode.API_ERROR, + ) + + @tool_api + def update( + self, + task_id: Annotated[str, "The ID of the task to update"], + status: Annotated[ + Optional[str], + "New status: 'pending', 'in_progress', 'completed', or " + "'deleted' (permanently removes the task)", + ] = None, + subject: Annotated[ + Optional[str], "New subject for the task" + ] = None, + description: Annotated[ + Optional[str], "New description for the task" + ] = None, + active_form: Annotated[ + Optional[str], + "Present continuous form for progress display " + "(e.g. 'Running tests')", + ] = None, + owner: Annotated[ + Optional[str], "New owner (agent name)" + ] = None, + metadata: Annotated[ + Optional[str], + "JSON string of metadata keys to merge " + "(set a key to null to delete it)", + ] = None, + add_blocks: Annotated[ + Optional[str], + "Comma-separated task IDs that cannot start until this " + "one completes", + ] = None, + add_blocked_by: Annotated[ + Optional[str], + "Comma-separated task IDs that must complete before this " + "one can start", + ] = None, + ) -> ActionReturn: + """Update a task's status, details, or dependencies. + + Mark ``in_progress`` BEFORE beginning work on a task, not after. + Only mark ``completed`` when work is FULLY done -- tests pass, + no partial implementation, no unresolved errors. + + If blocked, create a new task describing what needs to be + resolved rather than leaving the current task stuck. + + After completing a task, call ``list`` to find the next + available work. + + Use ``status='deleted'`` to permanently remove a task. + + Args: + task_id: ID of the task to update. + status: New status value. + subject: New subject. + description: New description. + active_form: New progress display label. + owner: New owner. + metadata: JSON metadata to merge. + add_blocks: Task IDs this task blocks. + add_blocked_by: Task IDs blocking this task. + + Returns: + ActionReturn with the updated task summary. + """ + try: + kwargs: dict = {} + if status is not None: + kwargs["status"] = status + if subject is not None: + kwargs["subject"] = subject + if description is not None: + kwargs["description"] = description + if active_form is not None: + kwargs["active_form"] = active_form + if owner is not None: + kwargs["owner"] = owner + + meta = _parse_metadata(metadata) + if meta is not None: + kwargs["metadata"] = meta + + blocks_list = _parse_id_list(add_blocks) + if blocks_list: + kwargs["add_blocks"] = blocks_list + + blocked_by_list = _parse_id_list(add_blocked_by) + if blocked_by_list: + kwargs["add_blocked_by"] = blocked_by_list + + if not kwargs: + return ActionReturn( + type=self.name, + errmsg="No fields to update were provided.", + state=ActionStatusCode.ARGS_ERROR, + ) + + task = self._board.update(task_id, **kwargs) + + # status="deleted" returns None + if task is None and status == "deleted": + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=f"Task #{task_id} has been deleted.", + )], + ) + + if task is None: + return ActionReturn( + type=self.name, + errmsg=f"Task #{task_id} not found.", + state=ActionStatusCode.API_ERROR, + ) + + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=f"Updated task #{task.id}: [{task.status}] {task.subject}", + )], + ) + except Exception as exc: + return ActionReturn( + type=self.name, + errmsg=f"Failed to update task: {exc}", + state=ActionStatusCode.API_ERROR, + ) + + @tool_api + def get( + self, + task_id: Annotated[str, "The ID of the task to retrieve"], + ) -> ActionReturn: + """Retrieve full details of a specific task. + + Use this before starting work on a task to understand the + complete requirements. Also useful for checking dependency + relationships (what it blocks, what blocks it). + + Read the task's latest state before updating it to avoid + stale overwrites. + + Args: + task_id: The task ID. + + Returns: + ActionReturn with the task's full details including + dependencies. + """ + task = self._board.get(task_id) + if task is None: + return ActionReturn( + type=self.name, + errmsg=f"Task #{task_id} not found.", + state=ActionStatusCode.API_ERROR, + ) + return ActionReturn( + type=self.name, + result=[dict(type="text", content=_format_task_detail(task))], + ) + + @tool_api + def list( + self, + status: Annotated[ + Optional[str], + "Filter by status: 'pending', 'in_progress', or " + "'completed'. Omit to list all tasks.", + ] = None, + ) -> ActionReturn: + """List all tasks with a status summary. + + Use this to see what tasks are available (pending, not blocked, + no owner), check overall progress, or find newly unblocked + work after completing a task. + + Prefer working on tasks in ID order (lowest first) when + multiple tasks are available, as earlier tasks often set up + context for later ones. + + Args: + status: Optional status filter. + + Returns: + ActionReturn with a summary of tasks. + """ + summary = self._board.get_summary() + if status is not None: + tasks = self._board.list(status=status) + if not tasks: + return ActionReturn( + type=self.name, + result=[dict( + type="text", + content=f"No tasks with status '{status}'.", + )], + ) + # Rebuild summary with filter + completed_ids = { + t.id for t in self._board.list(status="completed") + } + lines = [] + for t in tasks: + line = f"#{t.id}. [{t.status}] {t.subject}" + if t.owner: + line += f" @{t.owner}" + active_blockers = [ + bid for bid in t.blocked_by + if bid not in completed_ids + ] + if active_blockers: + refs = ", ".join(f"#{bid}" for bid in active_blockers) + line += f" ▶ blocked by {refs}" + lines.append(line) + content = f"Tasks ({status}): {len(tasks)}\n\n" + "\n".join(lines) + return ActionReturn( + type=self.name, + result=[dict(type="text", content=content)], + ) + + return ActionReturn( + type=self.name, + result=[dict(type="text", content=summary)], + ) + + +class AsyncTaskAction(AsyncActionMixin, TaskAction): + """Async version of :class:`TaskAction`.""" + pass diff --git a/lagent/adapters/__init__.py b/lagent/adapters/__init__.py new file mode 100644 index 00000000..fda0830b --- /dev/null +++ b/lagent/adapters/__init__.py @@ -0,0 +1,16 @@ +from .base import AsyncExternalAgent, BaseExternalAgent +from .cli_adapter import CLIAgentAdapter +from .claude_code import ClaudeCodeAdapter +from .claude_code_sdk import ClaudeCodeSDKAdapter +from .proxy import LLMProxyRecorder +from .sdk_adapter import SDKAgentAdapter + +__all__ = [ + 'BaseExternalAgent', + 'AsyncExternalAgent', + 'CLIAgentAdapter', + 'ClaudeCodeAdapter', + 'ClaudeCodeSDKAdapter', + 'SDKAgentAdapter', + 'LLMProxyRecorder', +] diff --git a/lagent/adapters/base.py b/lagent/adapters/base.py new file mode 100644 index 00000000..476f83c2 --- /dev/null +++ b/lagent/adapters/base.py @@ -0,0 +1,214 @@ +"""Base classes for external agent adapters. + +Provides ``BaseExternalAgent`` and ``AsyncExternalAgent`` — abstract base +classes that wrap external agent frameworks (CLI tools or Python SDKs) +into lagent's Agent protocol. + +These adapters are fully compatible with lagent's ecosystem: they can be +used as ``agent_engine`` in InterclawApp, placed into Sequential chains, +managed by AgentService, and discovered by AgentLoader. + +Key design: + - ``forward()`` = execute external agent, return final output as AgentMessage + - ``state_dict()`` = memory + LLM trace from Proxy (if enabled) + - ``llm=None`` because external frameworks bring their own reasoning engine + - ``setup()`` is lazy (called on first forward, not at init) + - Proxy integration via ``_build_env()`` injects base_url + session key + +Usage:: + + class MyCLIAgent(AsyncExternalAgent): + async def setup(self): + ... # verify binary exists + + async def run_external_async(self, task, **kwargs): + ... # subprocess call, return stdout + + agent = MyCLIAgent(name="my-agent", timeout=300) + result = await agent("Fix the bug in main.py") + trace = agent.state_dict().get('llm_trace', []) +""" + +import os +from abc import abstractmethod +from typing import Any, Dict, Optional, Union +from uuid import uuid4 + +from lagent.agents.agent import Agent, AsyncAgentMixin +from lagent.schema import AgentMessage + + +class BaseExternalAgent(Agent): + """Abstract base for wrapping external agent frameworks as lagent Agents. + + Subclasses implement ``setup()`` and ``run_external()``. + The ``forward()`` method handles the lifecycle: + setup → build env → run → wrap output as AgentMessage. + + This class does NOT require an LLM, memory, or aggregator from lagent. + The external framework provides its own reasoning engine. + + Args: + name: Agent name, used as AgentMessage.sender. + description: Human-readable description. + working_dir: Working directory for the external agent. + env_vars: Extra environment variables for the external agent. + timeout: Maximum execution time in seconds. None = no limit. + proxy: Optional LLMProxyRecorder for trajectory capture. + hooks: Optional hooks (same as Agent). + """ + + def __init__( + self, + name: Optional[str] = None, + description: Optional[str] = None, + working_dir: Optional[str] = None, + env_vars: Optional[Dict[str, str]] = None, + timeout: Optional[float] = None, + proxy: Any = None, + **kwargs, + ): + # Don't pass llm, template, output_format, aggregator — not needed + super().__init__( + llm=None, + name=name, + description=description, + hooks=kwargs.pop('hooks', None), + ) + self.working_dir = working_dir + self.env_vars = env_vars or {} + self.timeout = timeout + self.session_id = uuid4().hex[:8] + self.proxy = proxy + self._setup_done = False + + @abstractmethod + def setup(self) -> None: + """One-time initialization (verify binary / check SDK import). + + Called lazily on first ``forward()`` call. Must be idempotent. + """ + + @abstractmethod + def run_external(self, task: str, **kwargs) -> str: + """Execute the external agent synchronously. + + Args: + task: The task/prompt string. + + Returns: + The external agent's textual output. + + Raises: + RuntimeError: If the external agent fails. + TimeoutError: If execution exceeds self.timeout. + """ + + def _build_env(self) -> dict: + """Build environment variables dict with proxy injection.""" + env = os.environ.copy() + env.update(self.env_vars) + if self.proxy: + session_key = f"sk-proxy-{self.session_id}" + env.update({ + 'OPENAI_BASE_URL': self.proxy.url, + 'OPENAI_API_KEY': session_key, + 'ANTHROPIC_BASE_URL': self.proxy.url, + 'ANTHROPIC_API_KEY': session_key, + }) + return env + + def _extract_task(self, messages: tuple) -> str: + """Join multiple AgentMessage contents into a single task string.""" + parts = [] + for m in messages: + if isinstance(m, AgentMessage): + parts.append(str(m.content)) + else: + parts.append(str(m)) + return '\n'.join(parts) + + def forward(self, *message: AgentMessage, **kwargs) -> Union[AgentMessage, str]: + """Lagent Agent protocol implementation. + + Extracts task from messages, runs external agent, wraps result. + """ + task = self._extract_task(message) + if not self._setup_done: + self.setup() + self._setup_done = True + + try: + output = self.run_external(task, **kwargs) + except Exception as exc: + return AgentMessage( + sender=self.name, + content=f"External agent failed: {exc}", + extra_info={'error': str(exc), 'adapter': self.__class__.__name__}, + ) + + return AgentMessage( + sender=self.name, + content=output, + extra_info={'adapter': self.__class__.__name__, 'session_id': self.session_id}, + ) + + def state_dict(self, prefix='', destination=None) -> Dict: + dest = super().state_dict(prefix=prefix, destination=destination) + if self.proxy: + dest[prefix + 'llm_trace'] = self.proxy.get_records(self.session_id) + return dest + + def load_state_dict(self, state_dict: Dict): + # Filter out llm_trace keys before passing to parent + filtered = { + k: v for k, v in state_dict.items() + if not k.endswith('llm_trace') + } + # Parent expects exact key match, add missing memory key if needed + if not any(k.endswith('memory') for k in filtered): + filtered['' + 'memory'] = [] + super().load_state_dict(filtered) + + +class AsyncExternalAgent(AsyncAgentMixin, BaseExternalAgent): + """Async variant of BaseExternalAgent. + + Subclasses implement ``run_external_async()`` instead of + ``run_external()``. + """ + + @abstractmethod + async def run_external_async(self, task: str, **kwargs) -> str: + """Async version of run_external.""" + + def run_external(self, task: str, **kwargs) -> str: + """Sync fallback — not used in async path.""" + raise NotImplementedError( + "Use run_external_async() for AsyncExternalAgent" + ) + + async def forward(self, *message: AgentMessage, **kwargs) -> Union[AgentMessage, str]: + task = self._extract_task(message) + if not self._setup_done: + self.setup() + self._setup_done = True + + # Lazily start proxy if present + if self.proxy and not self.proxy.is_running: + await self.proxy.start() + + try: + output = await self.run_external_async(task, **kwargs) + except Exception as exc: + return AgentMessage( + sender=self.name, + content=f"External agent failed: {exc}", + extra_info={'error': str(exc), 'adapter': self.__class__.__name__}, + ) + + return AgentMessage( + sender=self.name, + content=output, + extra_info={'adapter': self.__class__.__name__, 'session_id': self.session_id}, + ) diff --git a/lagent/adapters/claude_code.py b/lagent/adapters/claude_code.py new file mode 100644 index 00000000..34f841b8 --- /dev/null +++ b/lagent/adapters/claude_code.py @@ -0,0 +1,188 @@ +"""Claude Code CLI adapter — wraps the claude CLI as a lagent Agent. + +Supports real multi-turn via Claude Code's ``--continue`` flag: +each ``forward()`` call after the first automatically resumes +the previous session. + +Usage:: + + from lagent.adapters.claude_code import ClaudeCodeAdapter + + agent = ClaudeCodeAdapter(timeout=120) + r1 = await agent("Read main.py and explain what it does") + r2 = await agent("Now refactor the error handling") # continues same session + print(r2.content) + + # With proxy for trajectory capture: + from lagent.adapters.proxy import LLMProxyRecorder + proxy = LLMProxyRecorder(real_api_key="...", real_base_url="...") + agent = ClaudeCodeAdapter(proxy=proxy, timeout=120) + result = await agent("Fix the bug") + trace = agent.state_dict()['llm_trace'] +""" + +import asyncio +import json +import os +import shlex +import shutil +from typing import Callable, Dict, List, Optional + +from .base import AsyncExternalAgent + + +def _parse_claude_stream_json(stdout: str, stderr: str) -> str: + """Parse claude --output-format=stream-json output.""" + result_parts = [] + for line in stdout.strip().split('\n'): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + msg_type = data.get('type', '') + if msg_type == 'result': + text = data.get('result', '') + if text: + result_parts.append(text) + elif msg_type == 'assistant': + msg = data.get('message', {}) + content = msg.get('content', []) + for block in content if isinstance(content, list) else []: + if isinstance(block, dict) and block.get('type') == 'text': + result_parts.append(block.get('text', '')) + elif isinstance(block, str): + result_parts.append(block) + if result_parts: + return '\n'.join(result_parts) + return stdout.strip() + + +def _parse_claude_text(stdout: str, stderr: str) -> str: + """Parse claude --output-format=text output.""" + return stdout.strip() + + +class ClaudeCodeAdapter(AsyncExternalAgent): + """Wraps Claude Code CLI as a lagent Agent with real multi-turn support. + + Uses Claude Code's ``--continue`` flag to resume sessions across + ``forward()`` calls. Each call spawns a new subprocess but the + conversation history is preserved by Claude Code internally. + + Args: + output_format: "text" or "stream-json". Default: "text". + max_turns: Maximum number of agent turns per call. Default: None. + permission_mode: Permission mode. Default: "default". + extra_flags: Additional CLI flags as a list of strings. + model: Model name override. + parse_output: Custom output parser ``(stdout, stderr) -> str``. + max_output_chars: Truncate output beyond this limit. + **kwargs: Passed to AsyncExternalAgent (name, working_dir, + env_vars, timeout, proxy, hooks). + """ + + def __init__( + self, + output_format: str = 'text', + max_turns: Optional[int] = None, + permission_mode: str = 'default', + extra_flags: Optional[List[str]] = None, + model: Optional[str] = None, + parse_output: Optional[Callable[[str, str], str]] = None, + max_output_chars: int = 50000, + **kwargs, + ): + # Build env vars + env_vars = kwargs.pop('env_vars', {}) or {} + if model: + env_vars['ANTHROPIC_MODEL'] = model + for key in ('ANTHROPIC_AUTH_TOKEN', 'ANTHROPIC_API_KEY', + 'ANTHROPIC_BASE_URL'): + if key not in env_vars and os.environ.get(key): + env_vars[key] = os.environ[key] + + kwargs.setdefault('name', 'claude-code') + kwargs.setdefault('description', 'Claude Code CLI agent') + + super().__init__(env_vars=env_vars, **kwargs) + + self.output_format = output_format + self.max_turns = max_turns + self.permission_mode = permission_mode + self.extra_flags = extra_flags or [] + self.max_output_chars = max_output_chars + self._call_count = 0 + + if parse_output: + self.parse_output = parse_output + elif output_format == 'stream-json': + self.parse_output = _parse_claude_stream_json + else: + self.parse_output = _parse_claude_text + + def setup(self) -> None: + if not shutil.which('claude'): + raise RuntimeError( + "Claude Code CLI ('claude') not found on PATH. " + "Install it with: npm install -g @anthropic-ai/claude-code" + ) + + def _build_command(self, task: str) -> str: + """Build the CLI command, adding --continue for subsequent calls.""" + cmd_parts = ['claude', '--print', '-p', shlex.quote(task)] + cmd_parts.append(f'--output-format={self.output_format}') + if self.permission_mode != 'default': + cmd_parts.append(f'--permission-mode={self.permission_mode}') + if self.max_turns is not None: + cmd_parts.append(f'--max-turns={self.max_turns}') + + # Real multi-turn: --continue on subsequent calls + if self._call_count > 0: + cmd_parts.append('--continue') + + if self.extra_flags: + cmd_parts.extend(self.extra_flags) + + return ' '.join(cmd_parts) + + async def run_external_async(self, task: str, **kwargs) -> str: + cmd = self._build_command(task) + env = self._build_env() + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self.working_dir, + env=env, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + process.communicate(), + timeout=self.timeout, + ) + except asyncio.TimeoutError: + process.kill() + await process.wait() + raise TimeoutError( + f"Claude Code timed out after {self.timeout}s" + ) + + stdout = stdout_bytes.decode('utf-8', errors='replace') + stderr = stderr_bytes.decode('utf-8', errors='replace') + + if process.returncode != 0: + raise RuntimeError( + f"Claude Code exited with code {process.returncode}.\n" + f"stderr: {stderr[:2000]}" + ) + + if len(stdout) > self.max_output_chars: + stdout = stdout[:self.max_output_chars] + '\n...(truncated)' + + self._call_count += 1 + return self.parse_output(stdout, stderr) diff --git a/lagent/adapters/claude_code_sdk.py b/lagent/adapters/claude_code_sdk.py new file mode 100644 index 00000000..53e24277 --- /dev/null +++ b/lagent/adapters/claude_code_sdk.py @@ -0,0 +1,239 @@ +"""Claude Code SDK adapter — wraps claude-agent-sdk as a lagent Agent. + +Uses the Python SDK instead of CLI subprocess, providing: +- Real multi-turn via session_id (no --continue hack) +- Structured message access (TextBlock, ThinkingBlock, ToolUseBlock) +- Runtime hooks (PreToolUse, PostToolUse, etc.) +- Full usage/cost tracking per turn + +Usage:: + + from lagent.adapters.claude_code_sdk import ClaudeCodeSDKAdapter + + agent = ClaudeCodeSDKAdapter(max_turns=5, timeout=120) + r1 = await agent("Read main.py") + r2 = await agent("Now fix the bug") # real multi-turn, same session + print(r2.content) + + # All messages captured structurally + trace = agent.state_dict()['sdk_trace'] +""" + +import asyncio +import os +from dataclasses import asdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from .base import AsyncExternalAgent + + +class ClaudeCodeSDKAdapter(AsyncExternalAgent): + """Wraps claude-agent-sdk as a lagent Agent with real multi-turn. + + Each ``forward()`` call uses the same ``session_id``, so Claude Code + maintains full conversation history internally. No Proxy needed for + trace capture — the SDK yields structured messages directly. + + Args: + max_turns: Max agent turns per call. Default: None (unlimited). + permission_mode: Permission mode. Default: "default". + model: Model name override. + system_prompt: Custom system prompt. + allowed_tools: List of allowed tool names. + disallowed_tools: List of disallowed tool names. + cwd: Working directory for Claude Code. + effort: Reasoning effort level ("low", "medium", "high", "max"). + thinking: Thinking config dict. Default: adaptive. + **kwargs: Passed to AsyncExternalAgent (name, timeout, proxy, hooks). + """ + + def __init__( + self, + max_turns: Optional[int] = None, + permission_mode: str = 'default', + model: Optional[str] = None, + system_prompt: Optional[str] = None, + allowed_tools: Optional[List[str]] = None, + disallowed_tools: Optional[List[str]] = None, + cwd: Optional[str] = None, + effort: Optional[str] = None, + thinking: Optional[dict] = None, + **kwargs, + ): + kwargs.setdefault('name', 'claude-code-sdk') + kwargs.setdefault('description', 'Claude Code SDK agent') + super().__init__(**kwargs) + + self.max_turns = max_turns + self.permission_mode = permission_mode + self.model = model + self.system_prompt = system_prompt + self.allowed_tools = allowed_tools or [] + self.disallowed_tools = disallowed_tools or [] + self.cwd = cwd or self.working_dir + self.effort = effort + self.thinking = thinking + + self._session_id: Optional[str] = None + self._sdk_trace: List[dict] = [] + self._call_count = 0 + + def setup(self) -> None: + try: + import claude_agent_sdk + except ImportError: + raise RuntimeError( + "claude-agent-sdk is required. " + "Install with: pip install claude-agent-sdk" + ) + + async def run_external_async(self, task: str, **kwargs) -> str: + from claude_agent_sdk import ( + AssistantMessage, + ClaudeAgentOptions, + ResultMessage, + StreamEvent, + SystemMessage, + UserMessage, + query, + ) + + options = ClaudeAgentOptions( + permission_mode=self.permission_mode, + max_turns=self.max_turns, + ) + + if self.model: + options.model = self.model + if self.system_prompt: + options.system_prompt = self.system_prompt + if self.allowed_tools: + options.allowed_tools = self.allowed_tools + if self.disallowed_tools: + options.disallowed_tools = self.disallowed_tools + if self.cwd: + options.cwd = self.cwd + if self.effort: + options.effort = self.effort + if self.thinking: + options.thinking = self.thinking + + # Multi-turn: resume session on subsequent calls + if self._session_id: + options.resume = self._session_id + + # Inject proxy env if present + if self.proxy: + session_key = f"sk-proxy-{self.session_id}" + options.env = { + 'ANTHROPIC_BASE_URL': self.proxy.url, + 'ANTHROPIC_API_KEY': session_key, + } + + # Collect messages + messages = [] + result_text = '' + result_msg = None + + try: + async for message in query(prompt=task, options=options): + record = { + 'timestamp': datetime.now().isoformat(), + 'type': type(message).__name__, + 'call_index': self._call_count, + } + + if isinstance(message, AssistantMessage): + blocks = [] + for block in message.content: + block_dict = asdict(block) + block_dict['block_type'] = type(block).__name__ + blocks.append(block_dict) + record['content'] = blocks + record['model'] = message.model + record['usage'] = message.usage + record['stop_reason'] = message.stop_reason + record['message_id'] = message.message_id + + # Extract text from content blocks + for block in message.content: + if hasattr(block, 'text'): + result_text = block.text + + # Try to capture session_id from AssistantMessage + if hasattr(message, 'session_id') and message.session_id: + self._session_id = message.session_id + + elif isinstance(message, UserMessage): + if isinstance(message.content, str): + record['content'] = message.content + else: + record['content'] = [asdict(b) for b in message.content] + + elif isinstance(message, ResultMessage): + result_msg = message + record['result'] = message.result + record['session_id'] = message.session_id + record['usage'] = message.usage + record['total_cost_usd'] = message.total_cost_usd + record['num_turns'] = message.num_turns + record['is_error'] = message.is_error + record['stop_reason'] = message.stop_reason + + elif isinstance(message, SystemMessage): + record['subtype'] = message.subtype + record['data'] = message.data + + messages.append(record) + except Exception as exc: + # SDK may error after yielding some messages. + # Log but don't lose what we already captured. + import logging + logging.getLogger(__name__).warning( + f"SDK query error (captured {len(messages)} events): {exc}" + ) + + self._sdk_trace.extend(messages) + + # Capture session_id for multi-turn + if result_msg and result_msg.session_id: + self._session_id = result_msg.session_id + + # Fallback: try to get session_id from trace events + if not self._session_id: + for evt in reversed(messages): + sid = evt.get('session_id') + if sid: + self._session_id = sid + break + + self._call_count += 1 + + # Return the final result + if result_msg and result_msg.result: + return result_msg.result + return result_text or '(no output)' + + def state_dict(self, prefix='', destination=None) -> dict: + dest = super().state_dict(prefix=prefix, destination=destination) + dest[prefix + 'sdk_trace'] = list(self._sdk_trace) + if self._session_id: + dest[prefix + 'claude_session_id'] = self._session_id + return dest + + def load_state_dict(self, state_dict: dict): + filtered = { + k: v for k, v in state_dict.items() + if not k.endswith(('sdk_trace', 'claude_session_id')) + } + if not any(k.endswith('memory') for k in filtered): + filtered['memory'] = [] + super().load_state_dict(filtered) + + # Restore session for multi-turn + for k, v in state_dict.items(): + if k.endswith('claude_session_id'): + self._session_id = v + if k.endswith('sdk_trace'): + self._sdk_trace = v or [] diff --git a/lagent/adapters/cli_adapter.py b/lagent/adapters/cli_adapter.py new file mode 100644 index 00000000..2b4355d1 --- /dev/null +++ b/lagent/adapters/cli_adapter.py @@ -0,0 +1,118 @@ +"""CLI Agent Adapter — wraps external CLI agents as lagent Agents. + +Spawns external agent CLIs (claude, aider, codex, etc.) as subprocesses, +captures stdout/stderr, and converts the output into AgentMessages. + +The ``command_template`` must contain a ``{task}`` placeholder:: + + adapter = CLIAgentAdapter( + name="claude-code", + command_template="claude -p '{task}' --output-format text", + timeout=300, + ) + result = await adapter("Fix the bug in main.py") + +A custom ``parse_output`` function can be injected to extract structured +results from the raw CLI output. +""" + +import asyncio +import os +import shlex +import shutil +from typing import Callable, Dict, List, Optional + +from .base import AsyncExternalAgent + + +class CLIAgentAdapter(AsyncExternalAgent): + """Wraps an external agent accessible via CLI (subprocess). + + Args: + command_template: Format string for the CLI command. + Must contain ``{task}`` placeholder. + Example: ``"claude -p '{task}' --output-format text"`` + parse_output: Callable ``(stdout, stderr) -> str`` that extracts + the desired result from raw output. Default: return stdout. + shell: Whether to use shell execution. Default: True. + extra_args: Additional CLI arguments appended to the command. + max_output_chars: Truncate output beyond this limit. + **kwargs: Passed to AsyncExternalAgent (name, working_dir, + env_vars, timeout, proxy, hooks). + """ + + def __init__( + self, + command_template: str, + parse_output: Optional[Callable[[str, str], str]] = None, + shell: bool = True, + extra_args: Optional[List[str]] = None, + max_output_chars: int = 50000, + **kwargs, + ): + super().__init__(**kwargs) + self.command_template = command_template + self.parse_output = parse_output or self._default_parse + self.shell = shell + self.extra_args = extra_args or [] + self.max_output_chars = max_output_chars + + def setup(self) -> None: + """Verify the CLI binary exists on PATH.""" + binary = self.command_template.split()[0] + if not shutil.which(binary): + raise RuntimeError( + f"CLI binary '{binary}' not found on PATH. " + f"Ensure the external agent is installed." + ) + + async def run_external_async(self, task: str, **kwargs) -> str: + """Spawn subprocess, capture output, parse, return.""" + escaped_task = shlex.quote(task) + cmd = self.command_template.format(task=task) + if self.extra_args: + cmd += ' ' + ' '.join(self.extra_args) + + env = self._build_env() + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self.working_dir, + env=env, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + process.communicate(), + timeout=self.timeout, + ) + except asyncio.TimeoutError: + process.kill() + await process.wait() + raise TimeoutError( + f"CLI agent timed out after {self.timeout}s" + ) + + stdout = stdout_bytes.decode('utf-8', errors='replace') + stderr = stderr_bytes.decode('utf-8', errors='replace') + + if process.returncode != 0: + raise RuntimeError( + f"CLI agent exited with code {process.returncode}.\n" + f"stderr: {stderr[:2000]}" + ) + + if len(stdout) > self.max_output_chars: + stdout = stdout[:self.max_output_chars] + '\n...(truncated)' + + return self.parse_output(stdout, stderr) + + @staticmethod + def _default_parse(stdout: str, stderr: str) -> str: + """Default parser: return stdout, append stderr if non-empty.""" + result = stdout.strip() + if stderr.strip(): + result += f'\n[stderr]: {stderr.strip()}' + return result diff --git a/lagent/adapters/openai_agents.py b/lagent/adapters/openai_agents.py new file mode 100644 index 00000000..9c3e92c1 --- /dev/null +++ b/lagent/adapters/openai_agents.py @@ -0,0 +1,129 @@ +"""OpenAI Agents SDK adapter — wraps openai-agents as a lagent Agent. + +Demonstrates the SDKAgentAdapter pattern with a real framework. +Supports real multi-turn via RunResult chaining, and Proxy for +LLM trace capture. + +Usage:: + + from lagent.adapters.openai_agents import OpenAIAgentsAdapter + + agent = OpenAIAgentsAdapter( + model="gpt-4o-mini", + instructions="You are a helpful assistant.", + max_turns=5, + ) + r1 = await agent("What is 2+2?") + r2 = await agent("Now multiply that by 3") # real multi-turn +""" + +import asyncio +import os +from typing import Any, Dict, List, Optional + +from .sdk_adapter import SDKAgentAdapter + + +class OpenAIAgentsAdapter(SDKAgentAdapter): + """Wraps OpenAI Agents SDK as a lagent Agent. + + Real multi-turn is achieved by passing the previous ``RunResult`` + as input to the next ``Runner.run()`` call, so the agent sees + the full conversation history. + + Args: + model: Model name (e.g. "gpt-4o-mini"). Required. + instructions: System instructions for the agent. + max_turns: Max turns per call. Default: 10. + agent_name: Name for the OpenAI Agent. Default: "assistant". + **kwargs: Passed to SDKAgentAdapter. + """ + + def __init__( + self, + model: str = 'gpt-4o-mini', + instructions: Optional[str] = None, + max_turns: int = 10, + agent_name: str = 'assistant', + api_key: Optional[str] = None, + base_url: Optional[str] = None, + http_proxy: Optional[str] = None, + **kwargs, + ): + kwargs.setdefault('name', 'openai-agents') + kwargs.setdefault('description', f'OpenAI Agents SDK ({model})') + kwargs.setdefault('sdk_module', 'agents') + kwargs.setdefault('sdk_config', { + 'model': model, + 'instructions': instructions or 'You are a helpful assistant.', + 'max_turns': max_turns, + 'agent_name': agent_name, + 'api_key': api_key or os.environ.get('OPENAI_API_KEY', ''), + 'base_url': base_url or os.environ.get('OPENAI_BASE_URL', ''), + 'http_proxy': http_proxy, + }) + super().__init__(**kwargs) + self._last_result: Any = None + + def create_sdk_agent(self, config: Dict[str, Any]) -> Any: + from agents import Agent + return Agent( + name=config.get('agent_name', 'assistant'), + model=config['model'], + instructions=config.get('instructions', 'You are a helpful assistant.'), + ) + + def invoke_sdk_agent(self, agent: Any, task: str, **kwargs) -> str: + # Sync fallback — use async path instead + raise NotImplementedError("Use invoke_sdk_agent_async") + + async def invoke_sdk_agent_async(self, agent: Any, task: str, **kwargs) -> str: + from agents import Runner, RunConfig + from agents.models.openai_provider import OpenAIProvider + from openai import AsyncOpenAI + import httpx + + config = self.sdk_config + + # Build OpenAI client with optional proxy and base_url + client_kwargs = {} + if config.get('api_key'): + client_kwargs['api_key'] = config['api_key'] + if config.get('base_url'): + client_kwargs['base_url'] = config['base_url'] + if config.get('http_proxy'): + client_kwargs['http_client'] = httpx.AsyncClient( + proxy=config['http_proxy'] + ) + + # Override with LLM proxy if present + if self.proxy: + session_key = f"sk-proxy-{self.session_id}" + client_kwargs['api_key'] = session_key + client_kwargs['base_url'] = self.proxy.url + '/v1' + # No http_proxy needed for local proxy + client_kwargs.pop('http_client', None) + + client = AsyncOpenAI(**client_kwargs) + run_config = RunConfig( + model=config.get('model'), + model_provider=OpenAIProvider(openai_client=client), + ) + + # Multi-turn: pass previous result as input + if self._last_result is not None: + input_data = self._last_result.to_input_list() + [ + {"role": "user", "content": task} + ] + else: + input_data = task + + result = await Runner.run( + starting_agent=agent, + input=input_data, + max_turns=config.get('max_turns', 10), + run_config=run_config, + ) + + self._last_result = result + return result.final_output diff --git a/lagent/adapters/openai_chat.py b/lagent/adapters/openai_chat.py new file mode 100644 index 00000000..692d0900 --- /dev/null +++ b/lagent/adapters/openai_chat.py @@ -0,0 +1,128 @@ +"""OpenAI Chat adapter — wraps openai SDK as a lagent Agent. + +Uses the raw openai ChatCompletion API (no openai-agents dependency). +Supports real multi-turn by maintaining messages history internally, +and Proxy for LLM trace capture. + +Usage:: + + from lagent.adapters.openai_chat import OpenAIChatAdapter + + agent = OpenAIChatAdapter( + model="gpt-4o-mini", + api_key="sk-...", + base_url="http://...", + ) + r1 = await agent("What is 2+2?") + r2 = await agent("Now multiply that by 3") # real multi-turn +""" + +import os +from typing import Any, Dict, List, Optional + +from .sdk_adapter import SDKAgentAdapter + + +class OpenAIChatAdapter(SDKAgentAdapter): + """Wraps OpenAI ChatCompletion API as a lagent Agent. + + Real multi-turn is achieved by maintaining the messages list + internally — each call appends the new user message and assistant + response, so the full history is sent on every LLM call. + + Args: + model: Model name. Required. + api_key: API key. Falls back to OPENAI_API_KEY env var. + base_url: API base URL. Falls back to OPENAI_BASE_URL env var. + http_proxy: HTTP proxy URL for outbound connections. + system_prompt: System instructions. + temperature: Sampling temperature. + max_tokens: Max output tokens per call. + **kwargs: Passed to SDKAgentAdapter. + """ + + def __init__( + self, + model: str = 'gpt-4o-mini', + api_key: Optional[str] = None, + base_url: Optional[str] = None, + http_proxy: Optional[str] = None, + system_prompt: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs, + ): + kwargs.setdefault('name', 'openai-chat') + kwargs.setdefault('description', f'OpenAI Chat ({model})') + kwargs.setdefault('sdk_module', 'openai') + kwargs.setdefault('sdk_config', { + 'model': model, + 'api_key': api_key or os.environ.get('OPENAI_API_KEY', ''), + 'base_url': base_url or os.environ.get('OPENAI_BASE_URL', ''), + 'http_proxy': http_proxy, + 'system_prompt': system_prompt or 'You are a helpful assistant.', + 'temperature': temperature, + 'max_tokens': max_tokens, + }) + super().__init__(**kwargs) + self._messages: List[dict] = [] + + def create_sdk_agent(self, config: Dict[str, Any]) -> Any: + from openai import AsyncOpenAI + import httpx + + client_kwargs = {} + if config.get('api_key'): + client_kwargs['api_key'] = config['api_key'] + if config.get('base_url'): + client_kwargs['base_url'] = config['base_url'] + if config.get('http_proxy'): + client_kwargs['http_client'] = httpx.AsyncClient( + proxy=config['http_proxy'] + ) + return AsyncOpenAI(**client_kwargs) + + def invoke_sdk_agent(self, agent: Any, task: str, **kwargs) -> str: + raise NotImplementedError("Use invoke_sdk_agent_async") + + async def invoke_sdk_agent_async(self, agent: Any, task: str, **kwargs) -> str: + config = self.sdk_config + + # Initialize system prompt on first call + if not self._messages and config.get('system_prompt'): + self._messages.append({ + 'role': 'system', + 'content': config['system_prompt'], + }) + + # Append user message + self._messages.append({'role': 'user', 'content': task}) + + # If proxy is active, create a separate client pointing to proxy + if self.proxy: + from openai import AsyncOpenAI + session_key = f"sk-proxy-{self.session_id}" + client = AsyncOpenAI( + api_key=session_key, + base_url=self.proxy.url + '/v1', + ) + else: + client = agent # use the original client + + # Call LLM + call_kwargs = { + 'model': config['model'], + 'messages': list(self._messages), + } + if config.get('temperature') is not None: + call_kwargs['temperature'] = config['temperature'] + if config.get('max_tokens') is not None: + call_kwargs['max_tokens'] = config['max_tokens'] + + response = await client.chat.completions.create(**call_kwargs) + content = response.choices[0].message.content or '' + + # Append assistant response for multi-turn + self._messages.append({'role': 'assistant', 'content': content}) + + return content diff --git a/lagent/adapters/openclaw.py b/lagent/adapters/openclaw.py new file mode 100644 index 00000000..4ca05604 --- /dev/null +++ b/lagent/adapters/openclaw.py @@ -0,0 +1,161 @@ +"""OpenClaw CLI adapter — wraps the openclaw CLI as a lagent Agent. + +OpenClaw is a personal AI assistant with multi-channel support. +This adapter uses its ``agent --local --message`` CLI interface, +similar to ClaudeCodeAdapter's approach. + +Requires: Node 22+, ``npm install -g openclaw`` + +Usage:: + + from lagent.adapters.openclaw import OpenClawAdapter + + agent = OpenClawAdapter( + thinking='medium', + timeout=120, + ) + r1 = await agent("What is 2+2?") + r2 = await agent("Now multiply by 3") # multi-turn via --session-id +""" + +import asyncio +import json +import os +import shlex +import shutil +from typing import Any, Dict, List, Optional + +from .base import AsyncExternalAgent + + +class OpenClawAdapter(AsyncExternalAgent): + """Wraps OpenClaw CLI as a lagent Agent. + + Uses ``openclaw agent --local --message`` for execution. + Real multi-turn via ``--session-id`` (same session across calls). + + Args: + thinking: Thinking level (off/minimal/low/medium/high/xhigh). + agent_id: OpenClaw agent id. Default: None (use default agent). + json_output: Return JSON output. Default: True. + node_version: Node version to use via nvm. Default: "22". + nvm_dir: NVM directory. Default: ~/.nvm. + **kwargs: Passed to AsyncExternalAgent. + """ + + def __init__( + self, + thinking: str = 'medium', + agent_id: Optional[str] = 'main', + json_output: bool = True, + node_version: str = '22', + nvm_dir: Optional[str] = None, + **kwargs, + ): + kwargs.setdefault('name', 'openclaw') + kwargs.setdefault('description', 'OpenClaw personal AI assistant') + super().__init__(**kwargs) + + self.thinking = thinking + self.agent_id = agent_id + self.json_output = json_output + self.node_version = node_version + self.nvm_dir = nvm_dir or os.path.expanduser('~/.nvm') + self._openclaw_session_id: Optional[str] = None + self._call_count = 0 + + def setup(self) -> None: + """Verify openclaw is accessible (via nvm).""" + # We'll use nvm to access the right node version at runtime + # Just check nvm exists + nvm_sh = os.path.join(self.nvm_dir, 'nvm.sh') + if not os.path.exists(nvm_sh): + raise RuntimeError( + f"nvm not found at {nvm_sh}. " + f"Install nvm and then: nvm install {self.node_version} && " + f"npm install -g openclaw" + ) + + def _build_command(self, task: str) -> str: + """Build the openclaw CLI command.""" + # Source nvm to get the right node version + nvm_prefix = ( + f'export NVM_DIR="{self.nvm_dir}" && ' + f'[ -s "$NVM_DIR/nvm.sh" ] && . "$NVM_DIR/nvm.sh" && ' + f'nvm use {self.node_version} > /dev/null 2>&1 && ' + ) + + cmd_parts = ['openclaw', 'agent', '--local'] + cmd_parts.extend(['--message', shlex.quote(task)]) + cmd_parts.extend(['--thinking', self.thinking]) + + if self.json_output: + cmd_parts.append('--json') + if self.agent_id: + cmd_parts.extend(['--agent', self.agent_id]) + + # Multi-turn: use same session-id across calls + if self._openclaw_session_id: + cmd_parts.extend(['--session-id', self._openclaw_session_id]) + + return nvm_prefix + ' '.join(cmd_parts) + + async def run_external_async(self, task: str, **kwargs) -> str: + cmd = self._build_command(task) + env = self._build_env() + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self.working_dir, + env=env, + ) + + try: + stdout_bytes, stderr_bytes = await asyncio.wait_for( + process.communicate(), + timeout=self.timeout, + ) + except asyncio.TimeoutError: + process.kill() + await process.wait() + raise TimeoutError( + f"OpenClaw timed out after {self.timeout}s" + ) + + stdout = stdout_bytes.decode('utf-8', errors='replace') + stderr = stderr_bytes.decode('utf-8', errors='replace') + + if process.returncode != 0: + raise RuntimeError( + f"OpenClaw exited with code {process.returncode}.\n" + f"stderr: {stderr[:2000]}" + ) + + self._call_count += 1 + + # Parse JSON output to extract session-id and content + if self.json_output: + return self._parse_json_output(stdout) + return stdout.strip() + + def _parse_json_output(self, stdout: str) -> str: + """Parse openclaw --json output.""" + try: + data = json.loads(stdout.strip()) + except json.JSONDecodeError: + return stdout.strip() + + # Capture session-id for multi-turn + if isinstance(data, dict): + sid = data.get('sessionId') or data.get('session_id') + if sid: + self._openclaw_session_id = sid + # Extract the reply text + reply = data.get('reply') or data.get('message') or data.get('content', '') + if isinstance(reply, str): + return reply + return json.dumps(data) + + return stdout.strip() diff --git a/lagent/adapters/proxy.py b/lagent/adapters/proxy.py new file mode 100644 index 00000000..4e6ff0d4 --- /dev/null +++ b/lagent/adapters/proxy.py @@ -0,0 +1,716 @@ +"""LLM Proxy Recorder — lightweight HTTP proxy that records LLM +request/response pairs for trajectory capture. + +The proxy intercepts all LLM API calls from external agents, records the +full request (including messages history) and response (including usage +and logprobs), then forwards the response unchanged. + +Session routing is done via the API key: external agents receive a +synthetic key ``sk-proxy-{session_id}`` which the proxy uses to tag +records, then replaces with the real API key before forwarding. + +Usage:: + + proxy = LLMProxyRecorder( + real_api_key="sk-ant-...", + real_base_url="https://api.anthropic.com", + ) + await proxy.start() + # set env for external agent: + # OPENAI_BASE_URL=http://localhost:{proxy.port}/v1 + # OPENAI_API_KEY=sk-proxy-{session_id} + records = proxy.get_records(session_id) + await proxy.stop() +""" + +import asyncio +import json +import logging +import re +from collections import defaultdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +try: + from aiohttp import ClientSession, web +except ImportError: + ClientSession = None + web = None + +logger = logging.getLogger(__name__) + +SESSION_KEY_PATTERN = re.compile(r'^sk-proxy-(.+)$') + + +class LLMProxyRecorder: + """Lightweight HTTP proxy that records LLM request/response pairs. + + Args: + real_api_key: The actual API key to use when forwarding requests. + real_base_url: The actual LLM API base URL to forward to. + port: Port to listen on. 0 means auto-assign. + """ + + def __init__( + self, + real_api_key: str, + real_base_url: str, + port: int = 0, + http_proxy: Optional[str] = None, + ): + if web is None: + raise ImportError( + "aiohttp is required for LLMProxyRecorder. " + "Install it with: pip install aiohttp" + ) + self.real_api_key = real_api_key + self.real_base_url = real_base_url.rstrip('/') + self.port = port + self.http_proxy = http_proxy + self._records: Dict[str, List[dict]] = defaultdict(list) + self._app: Optional[web.Application] = None + self._runner: Optional[web.AppRunner] = None + self._site: Optional[web.TCPSite] = None + + @property + def url(self) -> str: + """The proxy base URL that external agents should use.""" + return f"http://127.0.0.1:{self.port}" + + @property + def is_running(self) -> bool: + return self._site is not None + + async def start(self): + """Start the proxy HTTP server.""" + if self.is_running: + return + self._app = web.Application() + # Catch-all route to proxy any path + self._app.router.add_route('*', '/{path:.*}', self._handle_request) + self._runner = web.AppRunner(self._app) + await self._runner.setup() + self._site = web.TCPSite(self._runner, '127.0.0.1', self.port) + await self._site.start() + # Update port if auto-assigned + if self.port == 0: + self.port = self._site._server.sockets[0].getsockname()[1] + logger.info(f"LLMProxyRecorder started on port {self.port}") + + async def stop(self): + """Stop the proxy HTTP server.""" + if self._runner: + await self._runner.cleanup() + self._site = None + self._runner = None + self._app = None + logger.info("LLMProxyRecorder stopped") + + def _parse_session_id(self, auth_header: str) -> Optional[str]: + """Extract session_id from Authorization header. + + Expects format: ``Bearer sk-proxy-{session_id}`` + """ + if not auth_header: + return None + token = auth_header.removeprefix('Bearer ').strip() + # Also handle x-api-key style (Anthropic) + match = SESSION_KEY_PATTERN.match(token) + return match.group(1) if match else None + + async def _handle_request(self, request: web.Request) -> web.Response: + """Proxy handler: extract session, forward, record, return.""" + # 1. Extract session from auth header + auth = request.headers.get('Authorization', '') + api_key = request.headers.get('x-api-key', '') + session_id = self._parse_session_id(auth) or self._parse_session_id(api_key) + + # 2. Read request body + request_body = await request.read() + request_data = None + try: + request_data = json.loads(request_body) if request_body else None + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + # 3. Build forwarding headers — replace auth with real key + forward_headers = dict(request.headers) + forward_headers.pop('Host', None) + forward_headers.pop('host', None) + if 'Authorization' in forward_headers: + forward_headers['Authorization'] = f'Bearer {self.real_api_key}' + if 'x-api-key' in forward_headers: + forward_headers['x-api-key'] = self.real_api_key + + # 4. Forward to real LLM + # Build target URL, avoiding path duplication + # e.g. real_base_url="http://api.com/v1", path="/v1/chat/completions" + # should produce "http://api.com/v1/chat/completions" not "http://api.com/v1/v1/..." + req_path = request.match_info['path'] + from urllib.parse import urlparse + base_parsed = urlparse(self.real_base_url) + base_path = base_parsed.path.rstrip('/') + if req_path.startswith(base_path.lstrip('/')): + # Path already includes the base path prefix, use as-is + target_url = f"{base_parsed.scheme}://{base_parsed.netloc}/{req_path}" + else: + target_url = f"{self.real_base_url}/{req_path}" + if request.query_string: + target_url += f"?{request.query_string}" + + is_stream = request_data.get('stream', False) if request_data else False + + async with ClientSession() as client: + async with client.request( + method=request.method, + url=target_url, + headers=forward_headers, + data=request_body, + proxy=self.http_proxy, + ) as resp: + if is_stream: + # Stream response: collect chunks, forward as-is + response_chunks = [] + response = web.StreamResponse( + status=resp.status, + headers={ + k: v for k, v in resp.headers.items() + if k.lower() not in ('transfer-encoding', 'content-length', + 'content-encoding') + }, + ) + await response.prepare(request) + async for chunk in resp.content.iter_any(): + response_chunks.append(chunk) + await response.write(chunk) + await response.write_eof() + raw_response = b''.join(response_chunks) + else: + raw_response = await resp.read() + response = web.Response( + status=resp.status, + headers={ + k: v for k, v in resp.headers.items() + if k.lower() not in ('transfer-encoding', 'content-length', + 'content-encoding') + }, + body=raw_response, + ) + + # 5. Parse response for recording + response_data = None + if is_stream: + # Parse SSE stream to extract final data + response_data = self._parse_stream_response(raw_response) + else: + try: + response_data = json.loads(raw_response) + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + # 6. Record + if session_id and request_data: + record = { + 'timestamp': datetime.now().isoformat(), + 'request': request_data, + 'response': response_data, + 'path': request.path, + 'method': request.method, + 'stream': is_stream, + } + self._records[session_id].append(record) + logger.debug( + f"Recorded LLM call for session {session_id}: " + f"{request.path} ({len(self._records[session_id])} total)" + ) + + return response + + @staticmethod + def _parse_stream_response(raw: bytes) -> Optional[dict]: + """Parse SSE stream response to reconstruct the complete message. + + Supports both Anthropic and OpenAI streaming formats. + """ + text = raw.decode('utf-8', errors='replace') + events = [] + for line in text.split('\n'): + line = line.strip() + if line.startswith('data: ') and line != 'data: [DONE]': + try: + events.append(json.loads(line[6:])) + except json.JSONDecodeError: + pass + + if not events: + return None + + # Detect format: OpenAI has "choices", Anthropic has "type" + first = events[0] + if 'choices' in first or first.get('object') == 'chat.completion.chunk': + return LLMProxyRecorder._parse_openai_stream(events) + else: + return LLMProxyRecorder._parse_anthropic_stream(events) + + @staticmethod + def _parse_openai_stream(events: list) -> Optional[dict]: + """Reconstruct OpenAI chat completion from stream chunks.""" + message = { + 'choices': [{'message': {'role': 'assistant', 'content': ''}}], + } + content_parts = [] + tool_calls_map: Dict[int, dict] = {} # index → {id, type, function} + usage = {} + + for event in events: + if event.get('id') and 'id' not in message: + message['id'] = event['id'] + if event.get('model'): + message['model'] = event['model'] + + choices = event.get('choices', []) + for choice in choices: + delta = choice.get('delta', {}) + + # Content + if delta.get('content'): + content_parts.append(delta['content']) + + # Tool calls + for tc_delta in delta.get('tool_calls', []): + idx = tc_delta.get('index', 0) + if idx not in tool_calls_map: + tool_calls_map[idx] = { + 'id': tc_delta.get('id', ''), + 'type': tc_delta.get('type', 'function'), + 'function': { + 'name': '', + 'arguments': '', + }, + } + tc = tool_calls_map[idx] + fn = tc_delta.get('function', {}) + if fn.get('name'): + tc['function']['name'] += fn['name'] + if fn.get('arguments'): + tc['function']['arguments'] += fn['arguments'] + if tc_delta.get('id'): + tc['id'] = tc_delta['id'] + + if choice.get('finish_reason'): + message['choices'][0]['finish_reason'] = choice['finish_reason'] + + if event.get('usage'): + usage = event['usage'] + + msg = message['choices'][0]['message'] + msg['content'] = ''.join(content_parts) + if tool_calls_map: + msg['tool_calls'] = [ + tool_calls_map[i] for i in sorted(tool_calls_map) + ] + if usage: + message['usage'] = usage + return message + + @staticmethod + def _parse_anthropic_stream(events: list) -> Optional[dict]: + """Reconstruct Anthropic message from stream events.""" + message = {} + content_blocks = [] + current_block = {} + + for event in events: + event_type = event.get('type', '') + + if event_type == 'message_start': + # Initial message metadata + msg = event.get('message', {}) + message = { + 'id': msg.get('id'), + 'type': 'message', + 'role': msg.get('role'), + 'model': msg.get('model'), + 'usage': msg.get('usage', {}), + 'content': [], + } + + elif event_type == 'content_block_start': + # New content block + current_block = dict(event.get('content_block', {})) + + elif event_type == 'content_block_delta': + # Incremental content + delta = event.get('delta', {}) + delta_type = delta.get('type', '') + if delta_type == 'text_delta': + current_block.setdefault('text', '') + current_block['text'] += delta.get('text', '') + elif delta_type == 'thinking_delta': + current_block.setdefault('thinking', '') + current_block['thinking'] += delta.get('thinking', '') + elif delta_type == 'input_json_delta': + current_block.setdefault('partial_json', '') + current_block['partial_json'] += delta.get('partial_json', '') + + elif event_type == 'content_block_stop': + # Finalize current block + if current_block: + # Parse partial_json into input for tool_use blocks + if 'partial_json' in current_block: + try: + current_block['input'] = json.loads( + current_block.pop('partial_json') + ) + except json.JSONDecodeError: + current_block['input'] = current_block.pop('partial_json') + content_blocks.append(current_block) + current_block = {} + + elif event_type == 'message_delta': + # Final metadata (stop_reason, usage delta) + delta = event.get('delta', {}) + message['stop_reason'] = delta.get('stop_reason') + # Merge usage delta + usage_delta = event.get('usage', {}) + if usage_delta: + for k, v in usage_delta.items(): + if isinstance(v, (int, float)): + message['usage'][k] = message['usage'].get(k, 0) + v + else: + message['usage'][k] = v + + # Assemble final message + message['content'] = content_blocks + return message + + def get_records(self, session_id: str) -> List[dict]: + """Get all recorded LLM call records for a session. + + Each record contains: + - timestamp: ISO 8601 timestamp + - request: Full request body (messages, tools, etc.) + - response: Full response body (choices, usage, etc.) + - path: API path + - method: HTTP method + - stream: Whether streaming was used + + Args: + session_id: The session identifier. + + Returns: + List of record dicts, ordered by timestamp. + """ + return list(self._records.get(session_id, [])) + + # ── Standardization ──────────────────────────────────────── + + @staticmethod + def normalize_record(record: dict) -> dict: + """Normalize a raw proxy record into a standard training format. + + Handles both Anthropic and OpenAI API formats, strips billing + headers and other noise, and produces a uniform structure:: + + { + "messages": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "...", + "reasoning_content": "...", "extra_info": {...}}, + ... + ], + "tools": [...], + "meta": { + "model": "...", + "usage": {...}, + "stop_reason": "...", + "timestamp": "...", + }, + "response": { + "role": "assistant", + "content": "...", + "reasoning_content": "...", + "extra_info": {"usage": {...}, "model": "...", ...}, + }, + } + """ + req = record.get('request', {}) + resp = record.get('response') or {} + + # ── Normalize system prompt ── + system = req.get('system') + system_text = None + if system: + if isinstance(system, list): + # Anthropic: list of content blocks, skip billing headers + parts = [] + for block in system: + if not isinstance(block, dict): + continue + text = block.get('text', '') + # Skip billing/tracking headers + if text.startswith('x-anthropic-billing-header'): + continue + parts.append(text) + system_text = '\n'.join(parts) if parts else None + elif isinstance(system, str): + system_text = system + + # ── Normalize messages ── + messages = [] + if system_text: + messages.append({'role': 'system', 'content': system_text}) + + for msg in req.get('messages', []): + role = msg.get('role', 'user') + content = msg.get('content', '') + + # Flatten content blocks to text + if isinstance(content, list): + text_parts = [] + reasoning_parts = [] + for block in content: + if not isinstance(block, dict): + text_parts.append(str(block)) + continue + btype = block.get('type', '') + if btype == 'text': + text_parts.append(block.get('text', '')) + elif btype in ('thinking', 'reasoning'): + reasoning_parts.append(block.get('thinking', block.get('text', ''))) + elif btype == 'tool_use': + text_parts.append(f"[tool_use: {block.get('name', '')}]") + elif btype == 'tool_result': + text_parts.append(f"[tool_result: {str(block.get('content', ''))[:200]}]") + else: + text_parts.append(block.get('text', str(block))) + + norm_msg = {'role': role, 'content': '\n'.join(text_parts)} + if reasoning_parts: + norm_msg['reasoning_content'] = '\n'.join(reasoning_parts) + else: + norm_msg = {'role': role, 'content': str(content)} + + messages.append(norm_msg) + + # ── Normalize tools ── + tools = req.get('tools') + + # ── Normalize response ── + resp_content = '' + resp_reasoning = '' + resp_extra = {} + + # Anthropic format + resp_blocks = resp.get('content', []) + if isinstance(resp_blocks, list): + text_parts = [] + reasoning_parts = [] + for block in resp_blocks: + if not isinstance(block, dict): + continue + btype = block.get('type', '') + if btype == 'text': + text_parts.append(block.get('text', '')) + elif btype in ('thinking', 'reasoning'): + reasoning_parts.append(block.get('thinking', block.get('text', ''))) + elif btype == 'tool_use': + text_parts.append(f"[tool_use: {block.get('name', '')}]") + resp_content = '\n'.join(text_parts) + resp_reasoning = '\n'.join(reasoning_parts) + + # OpenAI format + choices = resp.get('choices', []) + if choices and not resp_content: + choice = choices[0] + msg = choice.get('message', {}) + resp_content = msg.get('content', '') or '' + # Handle tool_calls + tool_calls = msg.get('tool_calls', []) + if tool_calls: + tc_strs = [] + for tc in tool_calls: + fn = tc.get('function', {}) + tc_strs.append(f"[tool_call: {fn.get('name', '')}({fn.get('arguments', '')[:100]})]") + if not resp_content: + resp_content = '\n'.join(tc_strs) + resp_extra['tool_calls'] = tool_calls + resp_extra['finish_reason'] = choice.get('finish_reason') + + # Usage (both formats) + usage = resp.get('usage', {}) + if usage: + resp_extra['usage'] = usage + if resp.get('model'): + resp_extra['model'] = resp['model'] + if resp.get('stop_reason'): + resp_extra['stop_reason'] = resp['stop_reason'] + + # Meta + model = req.get('model') or resp.get('model') + meta = { + 'model': model, + 'usage': usage, + 'timestamp': record.get('timestamp'), + } + if resp.get('stop_reason'): + meta['stop_reason'] = resp['stop_reason'] + if choices and choices[0].get('finish_reason'): + meta['stop_reason'] = choices[0]['finish_reason'] + + response_msg = {'role': 'assistant', 'content': resp_content} + if resp_reasoning: + response_msg['reasoning_content'] = resp_reasoning + if resp_extra: + response_msg['extra_info'] = resp_extra + + return { + 'messages': messages, + 'tools': tools, + 'meta': meta, + 'response': response_msg, + } + + def get_normalized_records(self, session_id: str) -> List[dict]: + """Get all records in normalized format.""" + return [self.normalize_record(r) for r in self.get_records(session_id)] + + # ── Chain Rebuilding ────────────────────────────────────── + + def rebuild_chains(self, session_id: str) -> List[List[dict]]: + """Rebuild conversation chains from normalized records. + + Two consecutive records belong to the same chain if: + 1. Messages count grew (history is appending) + 2. The previous response text appears in current messages + + Args: + session_id: The session identifier. + + Returns: + List of chains. Each chain is a list of normalized records. + """ + records = self.get_normalized_records(session_id) + if not records: + return [] + + chains: List[List[dict]] = [] + current_chain: List[dict] = [records[0]] + + for prev, curr in zip(records, records[1:]): + prev_msgs = prev['messages'] + curr_msgs = curr['messages'] + + msgs_grew = len(curr_msgs) > len(prev_msgs) + + prev_response_text = prev['response'].get('content', '')[:200] + has_prev_response = ( + prev_response_text + and any( + m.get('role') == 'assistant' + and prev_response_text in m.get('content', '') + for m in curr_msgs + ) + ) + + if msgs_grew and has_prev_response: + current_chain.append(curr) + else: + chains.append(current_chain) + current_chain = [curr] + + chains.append(current_chain) + return chains + + # ── Training Sample Export ──────────────────────────────── + + def to_training_samples(self, session_id: str) -> List[dict]: + """Convert recorded LLM calls into SFT/RL training samples. + + Each chain produces one sample:: + + { + "messages": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "...", + "reasoning_content": "...", + "extra_info": {"usage": {...}, "model": "..."}}, + ... + ], + "tools": [...], + "meta": { + "num_calls": 3, + "model": "...", + "total_usage": {...}, + }, + } + """ + chains = self.rebuild_chains(session_id) + samples = [] + + for chain in chains: + if not chain: + continue + + last = chain[-1] + + # Take the last record's messages (most complete history) + # + append the last response + messages = list(last['messages']) + + # Attach extra_info to assistant messages by matching response text + response_extra_map = {} + for rec in chain: + resp = rec['response'] + text = resp.get('content', '')[:200] + if text: + extra = dict(resp.get('extra_info', {})) + response_extra_map[text] = extra + + for msg in messages: + if msg.get('role') == 'assistant': + text = msg.get('content', '')[:200] + if text in response_extra_map: + msg['extra_info'] = response_extra_map[text] + + # Append final response as the last assistant message + last_resp = dict(last['response']) + messages.append(last_resp) + + # Aggregate usage + total_usage = { + 'total_input_tokens': 0, + 'total_output_tokens': 0, + } + for rec in chain: + u = rec['meta'].get('usage', {}) + total_usage['total_input_tokens'] += u.get( + 'input_tokens', u.get('prompt_tokens', 0)) + total_usage['total_output_tokens'] += u.get( + 'output_tokens', u.get('completion_tokens', 0)) + + sample = { + 'messages': messages, + 'tools': last.get('tools'), + 'meta': { + 'num_calls': len(chain), + 'model': last['meta'].get('model'), + 'total_usage': total_usage, + }, + } + samples.append(sample) + + return samples + + def clear(self, session_id: Optional[str] = None): + """Clear recorded data. + + Args: + session_id: Clear only this session. If None, clear all. + """ + if session_id: + self._records.pop(session_id, None) + else: + self._records.clear() diff --git a/lagent/adapters/sdk_adapter.py b/lagent/adapters/sdk_adapter.py new file mode 100644 index 00000000..bc31feae --- /dev/null +++ b/lagent/adapters/sdk_adapter.py @@ -0,0 +1,116 @@ +"""SDK Agent Adapter — wraps external Python SDK agents as lagent Agents. + +For frameworks that expose a Python API (LangChain, CrewAI, OpenAI Agents, +etc.), subclass ``SDKAgentAdapter`` and implement two methods: +``create_sdk_agent()`` and ``invoke_sdk_agent()``. + +Example:: + + class LangchainAdapter(SDKAgentAdapter): + def create_sdk_agent(self, config): + from langchain.agents import AgentExecutor, create_react_agent + ... + return AgentExecutor(agent=agent, tools=tools) + + def invoke_sdk_agent(self, agent, task, **kwargs): + result = agent.invoke({"input": task}) + return result["output"] + + async def invoke_sdk_agent_async(self, agent, task, **kwargs): + result = await agent.ainvoke({"input": task}) + return result["output"] + + adapter = LangchainAdapter( + name="langchain-react", + sdk_module="langchain.agents", + sdk_config={"model_name": "gpt-4", "tools": ["search"]}, + ) + result = await adapter("Research quantum computing trends") +""" + +import asyncio +import importlib +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional + +from .base import AsyncExternalAgent + + +class SDKAgentAdapter(AsyncExternalAgent): + """Wraps an external agent accessible via Python SDK. + + Subclasses must implement ``create_sdk_agent()`` and + ``invoke_sdk_agent()``. Optionally override + ``invoke_sdk_agent_async()`` for native async support. + + Args: + sdk_module: Dotted module path to verify during setup. + Example: ``"langchain.agents"`` + sdk_config: Configuration dict passed to ``create_sdk_agent()``. + **kwargs: Passed to AsyncExternalAgent (name, working_dir, + env_vars, timeout, proxy, hooks). + """ + + def __init__( + self, + sdk_module: Optional[str] = None, + sdk_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.sdk_module = sdk_module + self.sdk_config = sdk_config or {} + self._sdk_agent: Any = None + + def setup(self) -> None: + """Verify SDK is importable, then create the agent instance.""" + if self.sdk_module: + try: + importlib.import_module(self.sdk_module) + except ImportError as exc: + raise RuntimeError( + f"SDK module '{self.sdk_module}' not importable: {exc}. " + f"Install the required package." + ) from exc + if self._sdk_agent is None: + self._sdk_agent = self.create_sdk_agent(self.sdk_config) + + @abstractmethod + def create_sdk_agent(self, config: Dict[str, Any]) -> Any: + """Instantiate the external framework's agent object. + + Args: + config: The ``sdk_config`` dict from ``__init__``. + + Returns: + The external agent object (framework-specific type). + """ + + @abstractmethod + def invoke_sdk_agent(self, agent: Any, task: str, **kwargs) -> str: + """Run the external agent synchronously. + + Args: + agent: Object returned by ``create_sdk_agent()``. + task: The task/prompt string. + + Returns: + Textual result from the external agent. + """ + + async def invoke_sdk_agent_async( + self, agent: Any, task: str, **kwargs + ) -> str: + """Async version. Default: runs sync invoke in a thread executor. + + Override this if the SDK provides native async support. + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, lambda: self.invoke_sdk_agent(agent, task, **kwargs) + ) + + async def run_external_async(self, task: str, **kwargs) -> str: + return await self.invoke_sdk_agent_async( + self._sdk_agent, task, **kwargs + ) diff --git a/lagent/agents/__init__.py b/lagent/agents/__init__.py index 0a995d28..1a200725 100644 --- a/lagent/agents/__init__.py +++ b/lagent/agents/__init__.py @@ -10,24 +10,27 @@ StreamingAgent, StreamingSequential, ) -from .react import AsyncReAct, ReAct -from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder +from .compact_agent import AsyncCompactAgent, estimate_token_count +from .internclaw_agent import ( + AsyncEnvAgent, + AsyncPolicyAgent, + InternClawAgent, +) __all__ = [ 'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', - 'AgentForInternLM', - 'AsyncAgentForInternLM', - 'MathCoder', - 'AsyncMathCoder', - 'ReAct', - 'AsyncReAct', 'Sequential', 'AsyncSequential', 'StreamingAgent', 'StreamingSequential', 'AsyncStreamingAgent', 'AsyncStreamingSequential', + 'AsyncCompactAgent', + 'estimate_token_count', + 'AsyncEnvAgent', + 'AsyncPolicyAgent', + 'InternClawAgent', ] diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index ae99d6d9..e5ac8270 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -8,7 +8,7 @@ from lagent.agents.aggregator import DefaultAggregator from lagent.hooks import Hook, RemovableHandle from lagent.llms import BaseLLM -from lagent.memory import Memory, MemoryManager +from lagent.memory import Memory from lagent.prompts.parsers import StrParser from lagent.prompts.prompt_template import PromptTemplate from lagent.schema import AgentMessage, ModelStatusCode @@ -48,7 +48,7 @@ def __init__( ): self.name = name or self.__class__.__name__ self.llm: BaseLLM = create_object(llm) - self.memory: MemoryManager = MemoryManager(memory) if memory else None + self.memory: Memory = create_object(memory) self.output_format: StrParser = create_object(output_format) self.template = template self.description = description @@ -59,34 +59,31 @@ def __init__( hook = create_object(hook) self.register_hook(hook) - def update_memory(self, message, session_id=0): - if self.memory: - self.memory.add(message, session_id=session_id) - def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: + def __call__(self, *message: AgentMessage, **kwargs) -> AgentMessage: # message.receiver = self.name message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] for hook in self._hooks.values(): - result = hook.before_agent(self, message, session_id) + result = hook.before_agent(self, message) if result: message = result - self.update_memory(message, session_id=session_id) - response_message = self.forward(*message, session_id=session_id, **kwargs) + self.memory.add(message) + response_message = self.forward(*message, **kwargs) if not isinstance(response_message, AgentMessage): response_message = AgentMessage(sender=self.name, content=response_message) - self.update_memory(response_message, session_id=session_id) + self.memory.add(response_message) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): - result = hook.after_agent(self, response_message, session_id) + result = hook.after_agent(self, response_message) if result: response_message = result return response_message - def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: - formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template + def forward(self, *message: AgentMessage, **kwargs) -> Union[AgentMessage, str]: + formatted_messages, tools = self.aggregator.aggregate( + self.memory, self.name, self.output_format, self.template ) - llm_response = self.llm.chat(formatted_messages, **kwargs) + llm_response = self.llm.chat(formatted_messages, tools=tools, **kwargs) if self.output_format: formatted_messages = self.output_format.parse_response(llm_response) return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages) @@ -99,21 +96,18 @@ def __setattr__(self, __name: str, __value: Any) -> None: super().__setattr__('_agents', _agents) super().__setattr__(__name, __value) - def state_dict(self, session_id=None, prefix='', destination=None) -> Dict: + def state_dict(self, prefix='', destination=None) -> Dict: if destination is None: destination = {} if self.memory is not None: - if session_id not in self.memory.memory_map: - warnings.warn(f'No session id {session_id} in {prefix}memory') - memory = self.memory.get(session_id) - saved_memory = memory and memory.save() or [] + saved_memory = self.memory and self.memory.save() or [] destination.update({prefix + 'memory': saved_memory}) for name, agent in getattr(self, '_agents', {}).items(): if isinstance(agent, Agent): - agent.state_dict(destination=destination, prefix=prefix + name + ".", session_id=session_id) + agent.state_dict(destination=destination, prefix=prefix + name + ".") return destination - def load_state_dict(self, state_dict: Dict, session_id=0): + def load_state_dict(self, state_dict: Dict): _state_dict = self.state_dict() missing_keys = set(_state_dict) - set(state_dict) if missing_keys: @@ -132,16 +126,14 @@ def load_state_dict(self, state_dict: Dict, session_id=0): else: obj = getattr(obj, attr) if obj.memory is not None: - if session_id not in obj.memory.memory_map: - obj.memory.create_instance(session_id) - obj.memory.memory_map[session_id].load(state_dict[key] or []) + obj.memory.load(state_dict[key] or []) def register_hook(self, hook: Callable): handle = RemovableHandle(self._hooks) self._hooks[handle.id] = hook return handle - def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = False): + def reset(self, keypath: Optional[str] = None, recursive: bool = False): assert not (keypath and recursive), 'keypath and recursive can\'t be used together' if keypath: keys, agent = keypath.split('.'), self @@ -150,13 +142,13 @@ def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = F if key not in agents: raise KeyError(f'No sub-agent named {key} in {agent}') agent = agents[key] - agent.reset(session_id, recursive=False) + agent.reset(recursive=False) else: if self.memory: - self.memory.reset(session_id=session_id) + self.memory.reset() if recursive: for agent in getattr(self, '_agents', {}).values(): - agent.reset(session_id, recursive=True) + agent.reset(recursive=True) def __repr__(self): @@ -177,29 +169,29 @@ def _rcsv_repr(agent, n_indent=1): class AsyncAgentMixin: - async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: + async def __call__(self, *message: AgentMessage, **kwargs) -> AgentMessage: message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] for hook in self._hooks.values(): - result = hook.before_agent(self, message, session_id) + result = hook.before_agent(self, message) if result: message = result - self.update_memory(message, session_id=session_id) - response_message = await self.forward(*message, session_id=session_id, **kwargs) + self.memory.add(message) + response_message = await self.forward(*message, **kwargs) if not isinstance(response_message, AgentMessage): response_message = AgentMessage(sender=self.name, content=response_message) - self.update_memory(response_message, session_id=session_id) + self.memory.add(response_message) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): - result = hook.after_agent(self, response_message, session_id) + result = hook.after_agent(self, response_message) if result: response_message = result return response_message - async def forward(self, *message: AgentMessage, session_id=0, **kwargs) -> Union[AgentMessage, str]: - formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template + async def forward(self, *message: AgentMessage, **kwargs) -> Union[AgentMessage, str]: + formatted_messages, tools = self.aggregator.aggregate( + self.memory, self.name, self.output_format, self.template ) - llm_response = await self.llm.chat(formatted_messages, session_id, **kwargs) + llm_response = await self.llm.chat(formatted_messages, tools=tools, **kwargs) if self.output_format: formatted_messages = self.output_format.parse_response(llm_response) return AgentMessage(sender=self.name, content=llm_response, formatted=formatted_messages) @@ -215,34 +207,34 @@ class AsyncAgent(AsyncAgentMixin, Agent): class StreamingAgentMixin: """Component that makes agent calling output a streaming response.""" - def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Generator[AgentMessage, None, None]: + def __call__(self, *message: AgentMessage, **kwargs) -> Generator[AgentMessage, None, None]: message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] for hook in self._hooks.values(): - result = hook.before_agent(self, message, session_id) + result = hook.before_agent(self, message) if result: message = result - self.update_memory(message, session_id=session_id) + self.memory.add(message) response_message = AgentMessage(sender=self.name, content="") - for response_message in self.forward(*message, session_id=session_id, **kwargs): + for response_message in self.forward(*message, **kwargs): if not isinstance(response_message, AgentMessage): model_state, response = response_message response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state) yield response_message.model_copy() - self.update_memory(response_message, session_id=session_id) + self.memory.add(response_message) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): - result = hook.after_agent(self, response_message, session_id) + result = hook.after_agent(self, response_message) if result: response_message = result yield response_message def forward( - self, *message: AgentMessage, session_id=0, **kwargs + self, *message: AgentMessage, **kwargs ) -> Generator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None, None]: formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template + self.memory, self.name, self.output_format, self.template ) - for model_state, response, *_ in self.llm.stream_chat(formatted_messages, session_id=session_id, **kwargs): + for model_state, response, *_ in self.llm.stream_chat(formatted_messages, **kwargs): yield ( AgentMessage( sender=self.name, @@ -258,35 +250,35 @@ def forward( class AsyncStreamingAgentMixin: """Component that makes asynchronous agent calling output a streaming response.""" - async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AsyncGenerator[AgentMessage, None]: + async def __call__(self, *message: AgentMessage, **kwargs) -> AsyncGenerator[AgentMessage, None]: message = [AgentMessage(sender='user', content=m) if isinstance(m, str) else copy.deepcopy(m) for m in message] for hook in self._hooks.values(): - result = hook.before_agent(self, message, session_id) + result = hook.before_agent(self, message) if result: message = result - self.update_memory(message, session_id=session_id) + self.memory.add(message) response_message = AgentMessage(sender=self.name, content="") - async for response_message in self.forward(*message, session_id=session_id, **kwargs): + async for response_message in self.forward(*message, **kwargs): if not isinstance(response_message, AgentMessage): model_state, response = response_message response_message = AgentMessage(sender=self.name, content=response, stream_state=model_state) yield response_message.model_copy() - self.update_memory(response_message, session_id=session_id) + self.memory.add(response_message) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): - result = hook.after_agent(self, response_message, session_id) + result = hook.after_agent(self, response_message) if result: response_message = result yield response_message async def forward( - self, *message: AgentMessage, session_id=0, **kwargs + self, *message: AgentMessage, **kwargs ) -> AsyncGenerator[Union[AgentMessage, Tuple[ModelStatusCode, str]], None]: formatted_messages = self.aggregator.aggregate( - self.memory.get(session_id), self.name, self.output_format, self.template + self.memory, self.name, self.output_format, self.template ) async for model_state, response, *_ in self.llm.stream_chat( - formatted_messages, session_id=session_id, **kwargs + formatted_messages, **kwargs ): yield ( AgentMessage( @@ -336,7 +328,7 @@ def add_agent(self, name: str, agent: Agent): assert isinstance(agent, Agent), f'{type(agent)} is not an Agent subclass' self._agents[str(name)] = agent - def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs) -> AgentMessage: + def forward(self, *message: AgentMessage, exit_at: Optional[int] = None, **kwargs) -> AgentMessage: assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: exit_at = len(self) - 1 @@ -345,7 +337,7 @@ def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = agent = next(iterator) if isinstance(message, AgentMessage): message = (message,) - message = agent(*message, session_id=session_id, **kwargs) + message = agent(*message, **kwargs) return message def __getitem__(self, key): @@ -361,7 +353,7 @@ def __len__(self): class AsyncSequential(AsyncAgentMixin, Sequential): async def forward( - self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs + self, *message: AgentMessage, exit_at: Optional[int] = None, **kwargs ) -> AgentMessage: assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: @@ -371,14 +363,14 @@ async def forward( agent = next(iterator) if isinstance(message, AgentMessage): message = (message,) - message = await agent(*message, session_id=session_id, **kwargs) + message = await agent(*message, **kwargs) return message class StreamingSequential(StreamingAgentMixin, Sequential): """Streaming variant of the Sequential class""" - def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs): + def forward(self, *message: AgentMessage, exit_at: Optional[int] = None, **kwargs): assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: exit_at = len(self) - 1 @@ -387,14 +379,14 @@ def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = agent = next(iterator) if isinstance(message, AgentMessage): message = (message,) - for message in agent(*message, session_id=session_id, **kwargs): + for message in agent(*message, **kwargs): yield message class AsyncStreamingSequential(AsyncStreamingAgentMixin, Sequential): """Streaming variant of the AsyncSequential class""" - async def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[int] = None, **kwargs): + async def forward(self, *message: AgentMessage, exit_at: Optional[int] = None, **kwargs): assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0' if exit_at is None: exit_at = len(self) - 1 @@ -403,7 +395,7 @@ async def forward(self, *message: AgentMessage, session_id=0, exit_at: Optional[ agent = next(iterator) if isinstance(message, AgentMessage): message = (message,) - async for message in agent(*message, session_id=session_id, **kwargs): + async for message in agent(*message, **kwargs): yield message diff --git a/lagent/agents/aggregator/__init__.py b/lagent/agents/aggregator/__init__.py index 5d9065cc..89cda70c 100644 --- a/lagent/agents/aggregator/__init__.py +++ b/lagent/agents/aggregator/__init__.py @@ -1,4 +1,5 @@ from .default_aggregator import DefaultAggregator +from .context import InternClawContextBuilder from .tool_aggregator import InternLMToolAggregator -__all__ = ['DefaultAggregator', 'InternLMToolAggregator'] +__all__ = ['DefaultAggregator', 'InternClawContextBuilder', 'InternLMToolAggregator'] diff --git a/lagent/agents/aggregator/compact_aggregator.py b/lagent/agents/aggregator/compact_aggregator.py new file mode 100644 index 00000000..1e58ea9f --- /dev/null +++ b/lagent/agents/aggregator/compact_aggregator.py @@ -0,0 +1,81 @@ +"""Aggregator for CompactAgent. + +Takes formatted_messages (list[dict]) from policy's aggregator +as the input message content, and assembles them into: + [system: COMPACT_PROMPT, user: formatted conversation text] +""" + +from typing import Dict, List, Optional, Tuple + +from lagent.memory import Memory + + +class CompactAggregator: + """Aggregator that formats list[dict] messages into readable text for compaction. + + When CompactAgent receives a message whose content is list[dict] + (the formatted_messages from policy's aggregator), this aggregator + converts it into a human-readable conversation transcript and + prepends the compact prompt as system message. + """ + + def aggregate( + self, + messages: Memory, + name: str, + parser=None, + system_instruction: str = None, + tools: List[Dict] = None, + ) -> Tuple[List[Dict[str, str]], Optional[List[Dict]]]: + _messages = [] + + # System message: the compact prompt (passed as template/system_instruction) + if system_instruction: + _messages.append(dict(role='system', content=system_instruction)) + + # Find the input message with formatted_messages as content + all_msgs = messages.get_memory() + for msg in all_msgs: + content = msg.content + if isinstance(content, list): + # list[dict] from policy's aggregator → format as readable text + formatted = self._format_messages(content) + _messages.append(dict(role='user', content=formatted)) + elif isinstance(content, str) and content: + _messages.append(dict(role='user', content=content)) + latest_env_info = None + for message in all_msgs: + if getattr(message, 'env_info', None) is not None: + latest_env_info = message.env_info + + tools_to_use = tools + if latest_env_info and latest_env_info.get("tools"): + tools_to_use = latest_env_info.get("tools") + + return _messages, tools_to_use + + + @staticmethod + def _format_messages(messages: List[Dict]) -> str: + """Convert list[dict] messages to readable conversation text.""" + lines = [] + for msg in messages: + role = msg.get('role', 'unknown').upper() + content = msg.get('content', '') + if content is None: + content = '' + + # Include tool calls if present + tool_calls = msg.get('tool_calls', []) + if tool_calls: + tool_names = [] + for tc in tool_calls: + if isinstance(tc, dict): + func = tc.get('function', {}) + tool_names.append(func.get('name', 'unknown')) + content += f" [tool_calls: {', '.join(tool_names)}]" + + if content: + lines.append(f"{role}: {content}") + + return '\n'.join(lines) diff --git a/lagent/agents/aggregator/context.py b/lagent/agents/aggregator/context.py new file mode 100644 index 00000000..94144677 --- /dev/null +++ b/lagent/agents/aggregator/context.py @@ -0,0 +1,215 @@ +"""Context builder for assembling agent prompts.""" + +import base64 +import mimetypes +import platform +import time +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +from lagent.agents.aggregator import DefaultAggregator + +from lagent.schema import ActionReturn + +class InternClawContextBuilder: + """Builds the context (system prompt + messages) for the agent.""" + + BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"] + _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" + + def __init__(self, workspace: Path, tools: List[Dict] = None): + self.workspace = Path(workspace) + self.tools = tools or [] # List of available tools, can be populated from skills or elsewhere + + def build_system_prompt(self, env_info: Dict[str, Any] = None) -> str: + """Build the system prompt from identity, bootstrap files, memory, and skills.""" + parts = [self._get_identity(env_info)] + + bootstrap = self._load_bootstrap_files() + if bootstrap: + parts.append(bootstrap) + + if env_info: + memory_info = env_info.get("memory") + if memory_info and isinstance(memory_info, dict) and memory_info.get("available") and memory_info.get("long_term"): + parts.append(f"# Memory\n\n{memory_info['long_term']}") + + active_skills = env_info.get("active_skills") + if active_skills: + parts.append(f"# Active Skills\n\n{active_skills}") + + skills_summary = env_info.get("skills") + if skills_summary: + parts.append(f"""# Skills + +The following skills extend your capabilities. To use a skill, read its SKILL.md file using the read_file tool. +Skills with available="false" need dependencies installed first - you can try installing them with apt/brew. + +{skills_summary}""") + + parts.append(self._build_runtime_context(None, None)) + + return "\n\n---\n\n".join(parts) + + def _get_identity(self, env_info: Dict[str, Any] = None) -> str: + """Get the core identity section.""" + workspace_path = str(self.workspace.expanduser().resolve()) + + # Prefer runtime info from env_info (remote env), fallback to local platform + runtime_info = (env_info or {}).get('runtime', {}) + system = runtime_info.get('system') or platform.system() + machine = runtime_info.get('machine') or platform.machine() + python_version = runtime_info.get('python_version') or platform.python_version() + runtime = f"{'macOS' if system == 'Darwin' else system} {machine}, Python {python_version}" + + return f"""# InternClaw 🐈 + +You are InternClaw, a helpful AI assistant. + +## Runtime +{runtime} + +## Workspace +Your workspace is at: {workspace_path} +- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here) +- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. +- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md + +## InternClaw Guidelines +- State intent before tool calls, but NEVER predict or claim results before receiving them. +- Before modifying a file, read it first. Do not assume files or directories exist. +- After writing or editing a file, re-read it if accuracy matters. +- If a tool call fails, analyze the error before retrying with a different approach. +- Ask for clarification when the request is ambiguous. + +Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel.""" + + @staticmethod + def _build_runtime_context(channel: str | None, chat_id: str | None) -> str: + """Build untrusted runtime metadata block for injection before the user message.""" + now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") + tz = time.strftime("%Z") or "UTC" + lines = [f"Current Time: {now} ({tz})"] + if channel and chat_id: + lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] + return InternClawContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + + def _load_bootstrap_files(self) -> str: + """Load all bootstrap files from workspace.""" + parts = [] + + for filename in self.BOOTSTRAP_FILES: + file_path = self.workspace / filename + if file_path.exists(): + content = file_path.read_text(encoding="utf-8") + parts.append(f"## {filename}\n\n{content}") + + return "\n\n".join(parts) if parts else "" + + def aggregate(self, + messages, + name: str, + parser = None, + system_instruction: str = None, + tools: List[Dict]= None + ) -> List[Dict[str, str]]: + """Aggregate messages into a format suitable for the agent.""" + messages_list = messages.get_memory() + + # Find the latest env_info + latest_env_info = None + for message in messages_list: + if getattr(message, 'env_info', None) is not None: + latest_env_info = message.env_info + + _message = [dict(role='system', content=self.build_system_prompt(env_info=latest_env_info))] + + # ── Handle conversation summary (compact) ──────────────── + # If env_info contains conversation_summary + compact_boundary, + # inject summary as a user message and skip messages before boundary. + compact_summary = None + compact_boundary = None + if latest_env_info: + compact_summary = latest_env_info.get("conversation_summary") + compact_boundary = latest_env_info.get("compact_boundary") + + if compact_summary and compact_boundary is not None: + _message.append(dict( + role='user', + content=( + f"[Conversation Summary — the following is a summary " + f"of the conversation up to this point]\n\n" + f"{compact_summary}" + ), + )) + # Only process messages AFTER the boundary index + messages_to_process = messages_list[compact_boundary:] + else: + messages_to_process = messages_list + + for message in messages_to_process: + if message.sender == name: + msg = {'role': 'assistant', 'content': message.content or ''} + if message.tool_calls: + msg['tool_calls'] = message.tool_calls + # When tool_calls are present, content should be None or empty for some APIs + if not message.content: + msg['content'] = None + if message.reasoning_content: + msg['reasoning_content'] = message.reasoning_content + _message.append(msg) + else: + user_message = message.content + if isinstance(user_message, list): + for m in user_message: + if isinstance(m, dict): + m = dict(m) # shallow copy to avoid mutating memory + tool_call_id = m.pop('tool_call_id', '') + m = ActionReturn(**m) + else: + tool_call_id = '' + assert isinstance(m, ActionReturn), f"Expected m to be ActionReturn, but got {type(m)}" + tool_msg = dict(role='tool', content=m.format_result(), name=m.type) + if tool_call_id: + tool_msg['tool_call_id'] = tool_call_id + _message.append(tool_msg) + else: + if len(_message) > 0 and _message[-1]['role'] == 'user': + _message[-1]['content'] += user_message + else: + _message.append(dict(role='user', content=user_message)) + + tools_to_use = tools or self.tools + if latest_env_info and latest_env_info.get("tools"): + tools_to_use = latest_env_info.get("tools") + + return _message, tools_to_use + + +if __name__ == "__main__": + # Example usage + from lagent.memory import Memory + from lagent.schema import AgentMessage + builder = InternClawContextBuilder(Path("/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent/workspace")) + env_info = { + "skills": "weather", + "active_skills": "weather skill content", + "memory": {"available": True, "long_term": "It's always sunny in Philadelphia."} + } + system_prompt = builder.build_system_prompt(env_info=env_info) + print(system_prompt) + session = Memory() + session.add( + [ + AgentMessage(sender="user", content="What is the weather today?", role="user", env_info=env_info), + AgentMessage(sender="agent", content="The weather is sunny.", role="assistant"), + AgentMessage(sender="user", content="What about tomorrow?", role="user"), + AgentMessage(sender="agent", content="Tomorrow will be cloudy.", role="assistant"), + AgentMessage(sender="user", content="Any plans for the weekend?", role="user"), + AgentMessage(sender="agent", content="I am planning to go hiking.", role="assistant"), + ] + ) + msgs, tools = builder.aggregate(session, name="agent") + for msg in msgs: + print(f"{msg['role']}: {msg['content']}\n") diff --git a/lagent/agents/aggregator/default_aggregator.py b/lagent/agents/aggregator/default_aggregator.py index 0888aef5..669c5613 100644 --- a/lagent/agents/aggregator/default_aggregator.py +++ b/lagent/agents/aggregator/default_aggregator.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional, Tuple from lagent.memory import Memory from lagent.prompts import StrParser @@ -10,7 +10,9 @@ def aggregate(self, messages: Memory, name: str, parser: StrParser = None, - system_instruction: str = None) -> List[Dict[str, str]]: + system_instruction: str = None, + tools: List[Dict] = None, + ) -> Tuple[List[Dict[str, str]], Optional[List[Dict]]]: _message = [] messages = messages.get_memory() if system_instruction: @@ -26,7 +28,16 @@ def aggregate(self, _message[-1]['content'] += user_message else: _message.append(dict(role='user', content=user_message)) - return _message + latest_env_info = None + for message in messages: + if getattr(message, 'env_info', None) is not None: + latest_env_info = message.env_info + + tools_to_use = tools + if latest_env_info and latest_env_info.get("tools"): + tools_to_use = latest_env_info.get("tools") + + return _message, tools_to_use @staticmethod def aggregate_system_intruction(system_intruction) -> List[dict]: diff --git a/lagent/agents/compact_agent.py b/lagent/agents/compact_agent.py new file mode 100644 index 00000000..df2831c3 --- /dev/null +++ b/lagent/agents/compact_agent.py @@ -0,0 +1,161 @@ +"""CompactAgent — context compression as a first-class Agent. + +A standard AsyncAgent with a compact-specific template (COMPACT_PROMPT). +The orchestrator passes the policy_message to ``should_compact()`` to +check, then calls the agent normally. CompactAgent's own ContextBuilder +assembles the summary request. + +Usage (in orchestrator):: + + if compact_agent.should_compact(policy_message): + summary_msg = await compact_agent(policy_message) + # inject summary into env_info for next turn +""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from lagent.agents.agent import AsyncAgent +from lagent.schema import AgentMessage + +logger = logging.getLogger("lagent.agents.compact_agent") + + +# ── Token estimation ────────────────────────────────────────────── + +def estimate_token_count(messages: list, tools: Optional[list] = None) -> int: + """Rough token count estimation (chars / 4). + + Good enough for threshold decisions — we don't need exact counts. + """ + total_chars = 0 + for msg in messages: + content = msg.get("content") or "" + if isinstance(content, str): + total_chars += len(content) + elif isinstance(content, list): + for part in content: + total_chars += len(str(part)) + total_chars += len(msg.get("reasoning_content") or "") + + if tools: + total_chars += len(json.dumps(tools, ensure_ascii=False)) + + return total_chars // 4 + + +# ── Compact prompt ──────────────────────────────────────────────── + +COMPACT_PROMPT = """\ +CRITICAL: Respond with TEXT ONLY. Do NOT call any tools. \ +Tool calls will be REJECTED and will waste your only turn. + +Your task is to create a detailed summary of the conversation so far. \ +This summary will REPLACE the conversation history, so it must contain \ +ALL information needed to continue working effectively. + +Please provide a thorough summary covering these sections: + +## 1. Primary Request and Intent +What is the user's overall goal? What specific outcomes do they want? + +## 2. Key Technical Concepts +What technologies, frameworks, APIs, or patterns are involved? + +## 3. Files and Code Sections +List ALL files that were read, created, or modified, with a brief note \ +on what was done to each. Include key code snippets that would be \ +needed to continue the work. + +## 4. Errors and Fixes +Document any errors encountered and how they were (or weren't) resolved. + +## 5. Problem-Solving Approach +What strategies were tried? What worked, what didn't? + +## 6. All User Messages (Chronological) +Reproduce the ESSENCE of every user message in order. Do not skip any. + +## 7. Pending Tasks +What remains to be done? What was the user's last request? + +## 8. Current Work State +What is the current state of the work? What file is being edited? \ +What was the last action taken? + +## 9. Optional Next Step +If there's an obvious next action, state it briefly. + +Be thorough and specific. The model reading this summary will have \ +NO access to the original conversation. +""" + + +# ── CompactAgent ────────────────────────────────────────────────── + +class AsyncCompactAgent(AsyncAgent): + """Context compression Agent. Peer of PolicyAgent / EnvAgent. + + Has its own aggregator/contextbuilder that assembles + COMPACT_PROMPT + the input message into a summarisation request. + + ``should_compact(message)`` inspects the message to decide + whether compaction is needed — the orchestrator doesn't need to + know what fields to check. + + Parameters + ---------- + llm : BaseLLM or dict + LLM provider for summarisation (can share with policy). + max_context_tokens : int + Maximum context window size in tokens. + threshold_ratio : float + Trigger compact when usage exceeds this ratio of max tokens. + """ + + def __init__( + self, + max_context_tokens: int = 128_000, + threshold_ratio: float = 0.85, + **kwargs, + ): + # Default template to COMPACT_PROMPT if not provided + kwargs.setdefault('template', COMPACT_PROMPT) + # Default aggregator to CompactAggregator + if 'aggregator' not in kwargs: + from lagent.agents.aggregator.compact_aggregator import CompactAggregator + kwargs['aggregator'] = CompactAggregator() + super().__init__(**kwargs) + self._max_context_tokens = max_context_tokens + self._threshold_ratio = threshold_ratio + self._consecutive_failures = 0 + self._max_consecutive_failures = 3 + + @property + def threshold_tokens(self) -> int: + return int(self._max_context_tokens * self._threshold_ratio) + + def should_compact(self, message: AgentMessage) -> bool: + """Check whether compaction should be triggered. + + Inspects the message's extra_info for context_tokens, + but subclasses can override with different strategies + (e.g. message count, content length, etc.). + + Parameters + ---------- + message : AgentMessage + Typically the policy_message with extra_info. + """ + if self._consecutive_failures >= self._max_consecutive_failures: + logger.warning( + "Compact circuit breaker open: %d consecutive failures", + self._consecutive_failures, + ) + return False + context_tokens = (message.extra_info or {}).get('context_tokens', 0) + return context_tokens > self.threshold_tokens + diff --git a/lagent/agents/internclaw_agent.py b/lagent/agents/internclaw_agent.py new file mode 100644 index 00000000..1419e24d --- /dev/null +++ b/lagent/agents/internclaw_agent.py @@ -0,0 +1,488 @@ +import asyncio +import base64 +import json +import mimetypes +import platform +import time +from copy import deepcopy +from dataclasses import asdict +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Optional + +from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed + +from lagent.actions import AsyncActionExecutor, BaseAction, ActionExecutor +from lagent.agents.agent import Agent, AsyncAgent +from lagent.skills.skills import BaseSkillsBackend, SkillsLoader +from lagent.schema import ( + ActionReturn, + ActionStatusCode, + ActionValidCode, + AgentMessage, + AgentStatusCode, +) +from lagent.utils import create_object + + +def get_tool_prompt(actions: list, exclude_arguments: list = None) -> str: + exclude_arguments = exclude_arguments or [] + + def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') -> dict: + action_description = deepcopy(action_description) + properties = {} + required = list(action_description.get('required', [])) + for param in action_description['parameters']: + param = deepcopy(param) + param_name, param_type = param.pop('name'), param.pop('type') + if param_name in exclude_arguments: + if param_name in required: + required.remove(param_name) + continue + param_type = [t.lower() for t in param_type] if isinstance(param_type, list) else param_type.lower() + properties[param_name] = {'type': param_type, **param} + return { + 'type': 'function', + 'function': { + 'name': name_pattern.format(action_description['name']), + 'description': action_description['description'], + 'parameters': {'type': 'object', 'properties': properties, 'required': required}, + }, + } + + tools = [] + for action in actions if isinstance(actions, list) else [actions]: + action = create_object(action) + action_desc = action.description + if action.is_toolkit: + for api in action_desc['api_list']: + tools.append(_convert_tool_schema(api, f"{action.name}.{{}}")) + else: + tools.append(_convert_tool_schema(action_desc)) + return tools + +class AsyncPolicyAgent(AsyncAgent): + + async def forward(self, *message, **kwargs): + formatted_messages, tools = self.aggregator.aggregate( + self.memory, self.name, self.output_format, self.template + ) + llm_response = await self.llm.chat(formatted_messages, tools=tools, **kwargs) + message = AgentMessage( + sender=self.name, + content=llm_response.get('content') or '', + tool_calls=llm_response.get('tool_calls') or [], + reasoning_content=llm_response.get('reasoning_content'), + ) + return message + + +class AsyncEnvAgent(AsyncAgent): + def __init__(self, + actions, + skills: SkillsLoader=None, + long_term_memory=None, + **kwargs): + super().__init__(**kwargs) + if isinstance(actions, ActionExecutor) or hasattr(actions, 'forward'): + self.actions = actions + else: + self.actions = AsyncActionExecutor(actions) + self.skills = create_object(skills) + self.long_term_memory = create_object(long_term_memory) + + async def get_env_info(self) -> Dict[str, Any]: + env_info: Dict[str, Any] = { + 'skills': '', + 'active_skills': '', + 'memory': '', + 'tools': [], + 'runtime': {} + } + + if self.skills is not None: + env_info['skills'] = await self.skills.build_skills_summary() + always_skills = await self.skills.get_always_skills() + if always_skills: + env_info['active_skills'] = await self.skills.load_skills_for_context(always_skills) + + if self.long_term_memory is not None: + env_info['memory'] = await self.long_term_memory.get_info() + if self.actions: + env_info['tools'] = get_tool_prompt(list(self.actions.actions.values())) + + env_info['runtime'] = { + 'system': platform.system(), + 'machine': platform.machine(), + 'python_version': platform.python_version(), + } + + return env_info + + async def forward(self, message, **kwargs): + if isinstance(message, str): + return AgentMessage(sender=self.name, content=message, env_info=await self.get_env_info()) + + if not message.tool_calls: + return AgentMessage( + sender=self.name, + content=message.content, + env_info=await self.get_env_info(), + tool_calls=message.tool_calls, + ) + + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(2), + retry=retry_if_result( + lambda r: r.valid == ActionValidCode.OPEN + and r.state not in [ActionStatusCode.SUCCESS, ActionStatusCode.ARGS_ERROR] + ), + retry_error_callback=lambda retry_state: retry_state.outcome.result(), + ) + async def _inner_func(tool_call): + tool_call = deepcopy(tool_call) + try: + if tool_call['function']['name'].split('.', 1)[0] not in self.actions: + return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f"Tool {tool_call['function']['name']} Not Found") + if isinstance(tool_call['function']['arguments'], str): + tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) + except Exception as e: + return ActionReturn(valid=ActionValidCode.INVALID, errmsg=str(e)) + tool_response: ActionReturn = ( + await self.actions( + AgentMessage( + sender='assistant', content=dict(name=tool_call['function']['name'], parameters=tool_call['function']['arguments']) + ), + ) + ).content + return tool_response + + tasks = [_inner_func(tool_call) for tool_call in message.tool_calls] + responses = await asyncio.gather(*tasks) + for i, resp in enumerate(responses): + if resp.valid != ActionValidCode.OPEN: + return AgentMessage( + sender=self.name, + content=f'Tool Call Error: {resp.errmsg} in tool call ' + f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}', + ) + if resp.state != ActionStatusCode.SUCCESS: + return AgentMessage( + sender=self.name, + content=f'Tool Call Error: {resp.errmsg} in tool call ' + f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}', + reward=-1 if resp.state == ActionStatusCode.ARGS_ERROR else 0, + ) + # Pair each ActionReturn with its tool_call_id for proper LLM API formatting + tool_results = [] + for tc, r in zip(message.tool_calls, responses): + result_dict = asdict(r) + result_dict['tool_call_id'] = tc.get('id', '') + tool_results.append(result_dict) + return_message = AgentMessage( + sender=self.name, + content=tool_results, + reward=0.0, + env_info=await self.get_env_info(), + ) + + return return_message + + +class InternClawAgent(AsyncAgent): + def __init__(self, + policy_agent: Dict, + env_agent: Dict, + compact_agent: Dict = None, + consolidate_agent: Dict = None, + max_turn: int = 500, + finish_condition: Optional[callable] = lambda m, _: m and not m.tool_calls, + **kwargs): + super().__init__(**kwargs) + self.policy_agent = create_object(policy_agent) + self.env_agent = create_object(env_agent) + self.compact_agent = create_object(compact_agent) + self.consolidate_agent = create_object(consolidate_agent) + self.max_turn = max_turn + self.finish_condition = finish_condition + + async def forward(self, env_message, **kwargs): + selection_message: AgentMessage = None + current_turn = 0 + env_message = await self.env_agent(env_message, **kwargs) + + while not ( + self.finish_condition is not None + and self.finish_condition(selection_message, env_message) + ) and (self.max_turn is None or current_turn < self.max_turn): + selection_message = await self.policy_agent(env_message, **kwargs) + + # ── Orchestrator manages memory ── + await self._maybe_manage_memory(selection_message, env_message) + + env_message = await self.env_agent(selection_message) + current_turn += 1 + if selection_message is not None: + return AgentMessage(sender=self.name, content=selection_message.content, finish_reason='stop') + return AgentMessage(sender=self.name, content="Finished", finish_reason='stop') + + async def _maybe_manage_memory( + self, policy_message: AgentMessage, env_message: AgentMessage, + ) -> None: + """Orchestrate compact and consolidate. + + Orchestrator calls policy's aggregator to get formatted_messages, + checks should_compact, and if triggered: + 1. Runs consolidate_agent (optional) + 2. Runs compact_agent to produce summary + 3. Injects summary + boundary into env_message + ContextBuilder reads these on the next turn. + """ + if not self.compact_agent: + return + + # Get formatted context from policy's aggregator + formatted_messages, tools = self.policy_agent.aggregator.aggregate( + self.policy_agent.memory, self.policy_agent.name, + self.policy_agent.output_format, self.policy_agent.template, + ) + + from lagent.agents.compact_agent import estimate_token_count + compact_input = AgentMessage( + sender=self.name, + content=formatted_messages, + extra_info={'context_tokens': estimate_token_count(formatted_messages, tools)}, + ) + + if not self.compact_agent.should_compact(compact_input): + return + + import logging + _logger = logging.getLogger("lagent.agents.internclaw") + + # 1. Consolidate first (preserve info before compacting) + if self.consolidate_agent: + try: + await self.consolidate_agent(compact_input) + self.consolidate_agent.reset(recursive=True) + _logger.info("Consolidation completed") + except Exception: + _logger.exception("Consolidation failed, continuing with compact") + + # 2. Compact — inject summary + boundary into env_message + try: + summary_msg = await self.compact_agent(compact_input) + self.compact_agent.reset(recursive=True) + if summary_msg and summary_msg.content: + if env_message.env_info is None: + env_message.env_info = {} + env_message.env_info['conversation_summary'] = summary_msg.content + env_message.env_info['compact_boundary'] = len( + self.policy_agent.memory.memory + ) + _logger.info("Compact summary injected (%d chars)", len(summary_msg.content)) + except Exception: + _logger.exception("Compact failed") + +if __name__ == "__main__": + import asyncio + import os + from pathlib import Path + + from lagent.agents.aggregator.context import InternClawContextBuilder + from lagent.agents.compact_agent import AsyncCompactAgent + from lagent.actions.filesystem import ReadFileAction, WriteFileAction, EditFileAction + from lagent.actions.shell import ShellAction + from lagent.actions.save_memory import AsyncSaveMemoryAction + from lagent.memory.openclaw_provider import OpenClawMemoryProvider + from lagent.hooks.logger import MessageLogger + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + from lagent.agents.aggregator.compact_aggregator import CompactAggregator + # ── Model config ── + model_name = "Pro/moonshotai/Kimi-K2.5" + api_base = "http://35.220.164.252:3888/v1" + api_key = "" + proxy = "http://100.100.72.89:8899" + + model_name = "/mnt/shared-storage-user/llmit1/user/liujiangning/exp/s2_preview/agent_rl/s2-preview-thinker_sft_0228b_rl0312rc1_fix_klmismatch/20260331212858/hf-15" + api_base = "http://10.102.252.171:23333/v1" + api_key = "sk-admin" + proxy = None + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=api_base, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=500, + sleep_interval=5, + extra_body=dict(spaces_between_special_tokens=False) + ) + + workspace = Path("/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent/workspace") + + CONSOLIDATION_PROMPT = ( + "You are a memory consolidation agent. Review the conversation " + "and call the save_memory tool to persist important information.\n\n" + "Extract key facts, decisions, user preferences, and project context. " + "Merge with existing long-term memory. For history_entry, write a " + "grep-searchable summary starting with [YYYY-MM-DD HH:MM]." + ) + + async def main(): + # ── 1. Actions ── + base_actions = [ + ReadFileAction(workspace=str(workspace)), + WriteFileAction(workspace=str(workspace)), + EditFileAction(workspace=str(workspace)), + ShellAction(working_dir=str(workspace)), + ] + + # ── 2. Memory provider (read) + action (write) ── + memory_provider = OpenClawMemoryProvider(workspace) + save_action = AsyncSaveMemoryAction(workspace) + + # ── 3. Hooks ── + logger_hook = MessageLogger() + + # ── 4. Policy agent ── + aggregator = InternClawContextBuilder(workspace, tools=None) + policy = AsyncPolicyAgent( + llm=model, + aggregator=aggregator, + hooks=[logger_hook], + ) + + # ── 5. Env agent ── + env = AsyncEnvAgent( + actions=base_actions + [save_action], + skills=SkillsLoader(workspace), + long_term_memory=memory_provider, + ) + + # ── 6. Compact agent ── + compact = AsyncCompactAgent( + name='compact', + llm=model, + max_context_tokens=65536, + threshold_ratio=0.5, + ) + + # ── 7. Consolidate agent (standard InternClawAgent) ── + consolidate_policy = AsyncPolicyAgent( + name='consolidate_policy', + llm=model, + template=CONSOLIDATION_PROMPT, + hooks=[logger_hook], + aggregator=CompactAggregator() + ) + consolidate_env = AsyncEnvAgent( + actions=[AsyncSaveMemoryAction(workspace)], + ) + consolidate = InternClawAgent( + policy_agent=consolidate_policy, + env_agent=consolidate_env, + max_turn=1, + finish_condition=None, + ) + + # ── 8. Orchestrator ── + agent = InternClawAgent( + policy_agent=policy, + env_agent=env, + compact_agent=compact, + consolidate_agent=consolidate, + ) + + # ── Interactive loop ── + print("=" * 60) + print(" InternClaw Agent (with Memory System)") + print(" Commands:") + print(" quit/exit — stop") + print(" memory — check MEMORY.md") + print(" history — check HISTORY.md") + print(" compact — force compact (consolidate + compress)") + print(" consolidate — force consolidate only") + print("=" * 60) + + while True: + try: + user_input = input("\n[You] > ").strip() + except (EOFError, KeyboardInterrupt): + print("\nBye!") + break + + if not user_input: + continue + if user_input.lower() in ('quit', 'exit'): + print("Bye!") + break + if user_input.lower() == 'memory': + info = await memory_provider.get_info() + print("\n--- MEMORY.md ---") + print(info.get('long_term', '(empty)')) + print("--- end ---") + continue + if user_input.lower() == 'history': + history_file = workspace / "memory" / "HISTORY.md" + if history_file.exists(): + print("\n--- HISTORY.md ---") + print(history_file.read_text()) + print("--- end ---") + else: + print("(no history yet)") + continue + if user_input.lower() in ('compact', 'consolidate'): + from lagent.agents.compact_agent import estimate_token_count + # Get formatted context from policy + formatted_messages, tools = policy.aggregator.aggregate( + policy.memory, policy.name, policy.output_format, policy.template, + ) + token_count = estimate_token_count(formatted_messages, tools) + print(f"\n Session: {len(policy.memory.memory)} messages, ~{token_count} tokens") + + if user_input.lower() in ('compact', 'consolidate'): + # Force consolidate + print(" Running consolidation...") + try: + compact_input = AgentMessage( + sender='user', + content=formatted_messages, + extra_info={'context_tokens': token_count}, + ) + await consolidate(compact_input) + print(" Consolidation done. Check 'memory' and 'history'.") + except Exception as e: + print(f" Consolidation failed: {e}") + + if user_input.lower() == 'compact': + # Also run compact + print(" Running compact...") + try: + compact_input = AgentMessage( + sender='user', + content=formatted_messages, + extra_info={'context_tokens': token_count}, + ) + summary_msg = await compact(compact_input) + summary_content = summary_msg.content + if isinstance(summary_content, dict): + summary_content = summary_content.get('content', '') + if summary_content: + # Inject into last env message in policy memory + for msg in reversed(policy.memory.memory): + if msg.env_info is not None: + msg.env_info['conversation_summary'] = summary_content + msg.env_info['compact_boundary'] = len(policy.memory.memory) + break + print(f" Compact done. Summary: {len(summary_content)} chars") + print(f" First 200 chars: {summary_content[:200]}...") + else: + print(" Compact returned empty summary.") + except Exception as e: + print(f" Compact failed: {e}") + continue + + response = await agent(user_input) + print(f"\n[Agent] {response.content}") + + asyncio.run(main()) \ No newline at end of file diff --git a/lagent/agents/react.py b/lagent/agents/react.py deleted file mode 100644 index 4a942a06..00000000 --- a/lagent/agents/react.py +++ /dev/null @@ -1,164 +0,0 @@ -import json -from typing import Callable, Dict, List, Union - -from pydantic import BaseModel, Field - -from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction -from lagent.agents.agent import Agent, AsyncAgent -from lagent.agents.aggregator import DefaultAggregator -from lagent.hooks import ActionPreprocessor -from lagent.llms import BaseLLM -from lagent.memory import Memory -from lagent.prompts.parsers.json_parser import JSONParser -from lagent.prompts.prompt_template import PromptTemplate -from lagent.schema import AgentMessage - -select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括: -{action_info} -{output_format} -开始!""" - -output_format_template = """如果使用工具请遵循以下格式回复: -{function_format} - -如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复 -{finish_format}""" - - -class ReAct(Agent): - - def __init__( - self, - llm: Union[BaseLLM, Dict], - actions: Union[BaseAction, List[BaseAction]], - template: Union[PromptTemplate, str] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict(type=JSONParser), - aggregator: Dict = dict(type=DefaultAggregator), - hooks: List = [dict(type=ActionPreprocessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content - or 'conclusion' in m.formatted, - max_turn: int = 5, - **kwargs - ): - self.max_turn = max_turn - self.finish_condition = finish_condition - self.actions = ActionExecutor(actions=actions, hooks=hooks) - self.select_agent = Agent( - llm=llm, - template=template.format( - action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() - ), - output_format=output_format, - memory=memory, - aggregator=aggregator, - hooks=hooks, - ) - super().__init__(**kwargs) - - def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: - for _ in range(self.max_turn): - message = self.select_agent(message, session_id=session_id, **kwargs) - if self.finish_condition(message): - return message - message = self.actions(message, session_id=session_id) - return message - - -class AsyncReAct(AsyncAgent): - - def __init__( - self, - llm: Union[BaseLLM, Dict], - actions: Union[BaseAction, List[BaseAction]], - template: Union[PromptTemplate, str] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict(type=JSONParser), - aggregator: Dict = dict(type=DefaultAggregator), - hooks: List = [dict(type=ActionPreprocessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content - or 'conclusion' in m.formatted, - max_turn: int = 5, - **kwargs - ): - self.max_turn = max_turn - self.finish_condition = finish_condition - self.actions = AsyncActionExecutor(actions=actions, hooks=hooks) - self.select_agent = AsyncAgent( - llm=llm, - template=template.format( - action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() - ), - output_format=output_format, - memory=memory, - aggregator=aggregator, - hooks=hooks, - ) - super().__init__(**kwargs) - - async def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: - for _ in range(self.max_turn): - message = await self.select_agent(message, session_id=session_id, **kwargs) - if self.finish_condition(message): - return message - message = await self.actions(message, session_id=session_id) - return message - - -if __name__ == '__main__': - import asyncio - - from lagent.llms import GPTAPI, AsyncGPTAPI - - class ActionCall(BaseModel): - name: str = Field(description='调用的函数名称') - parameters: Dict = Field(description='调用函数的参数') - - class ActionFormat(BaseModel): - thought_process: str = Field( - description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' - ) - action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。') - - class FinishFormat(BaseModel): - thought_process: str = Field( - description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' - ) - conclusion: str = Field(description='总结当前的搜索结果,回答问题。') - - prompt_template = PromptTemplate(select_action_template) - output_format = JSONParser(output_format_template, function_format=ActionFormat, finish_format=FinishFormat) - - agent = ReAct( - llm=dict( - type=GPTAPI, - model_type='gpt-4o-2024-05-13', - max_new_tokens=4096, - proxies=dict(), - retry=1000, - ), - template=prompt_template, - output_format=output_format, - aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), - actions=[dict(type='lagent.actions.PythonInterpreter')], - ) - response = agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')) - print(response) - response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢')) - print(response) - - async_agent = AsyncReAct( - llm=dict( - type=AsyncGPTAPI, - model_type='gpt-4o-2024-05-13', - max_new_tokens=4096, - proxies=dict(), - retry=1000, - ), - template=prompt_template, - output_format=output_format, - aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), - actions=[dict(type='lagent.actions.AsyncPythonInterpreter')], - ) - response = asyncio.run(async_agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))) - print(async_agent.state_dict()) diff --git a/lagent/agents/stream.py b/lagent/agents/stream.py deleted file mode 100644 index 9687aac1..00000000 --- a/lagent/agents/stream.py +++ /dev/null @@ -1,287 +0,0 @@ -import json -import warnings -from copy import deepcopy -from typing import Callable, Dict, List, Union - -from lagent.actions import ActionExecutor, AsyncActionExecutor, AsyncIPythonInterpreter, IPythonInteractive -from lagent.agents.agent import Agent, AsyncAgent -from lagent.agents.aggregator import InternLMToolAggregator -from lagent.hooks import InternLMActionProcessor -from lagent.llms import BaseLLM -from lagent.memory import Memory -from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode -from lagent.schema import AgentMessage -from lagent.utils import create_object - -API_PREFIX = ( - "This is the subfunction for tool '{tool_name}', you can use this tool. " - 'The description of this function is: \n{description}' -) - -META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用' - -INTERPRETER_CN = ( - '你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。' - '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。' - '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),' - '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),' - '文本处理和分析(比如文本解析和自然语言处理),' - '机器学习和数据科学(用于展示模型训练和数据可视化),' - '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。' -) - -PLUGIN_CN = ( - '你可以使用如下工具:' - '\n{prompt}\n' - '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! ' - '同时注意你可以使用的工具,不要随意捏造!' -) - - -def get_plugin_prompt(actions, api_desc_template='{description}'): - plugin_descriptions = [] - for action in actions if isinstance(actions, list) else [actions]: - action = create_object(action) - action_desc = deepcopy(action.description) - if action.is_toolkit: - for api in action_desc['api_list']: - api['name'] = f"{action.name}.{api['name']}" - api['description'] = api_desc_template.format(tool_name=action.name, description=api['description']) - plugin_descriptions.append(api) - else: - plugin_descriptions.append(action_desc) - return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4) - - -class AgentForInternLM(Agent): - - _INTERNAL_AGENT_CLS = Agent - - def __init__( - self, - llm: Union[BaseLLM, Dict], - plugins: Union[dict, List[dict]] = None, - interpreter: dict = None, - template: Union[str, dict, List[dict]] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict( - type=MixedToolParser, - template=META_CN, - parsers=[ - dict(type=PluginParser, template=PLUGIN_CN), - dict(type=InterpreterParser, template=INTERPRETER_CN), - ], - ), - aggregator: Dict = dict(type=InternLMToolAggregator), - action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, - max_turn: int = 4, - **kwargs, - ): - self.agent = self._INTERNAL_AGENT_CLS( - llm=llm, - template=template, - output_format=output_format, - memory=memory, - aggregator=aggregator, - hooks=kwargs.pop('hooks', None), - ) - self.plugin_executor = plugins and ActionExecutor(plugins, hooks=action_hooks) - self.interpreter_executor = interpreter and ActionExecutor(interpreter, hooks=action_hooks) - if not (self.plugin_executor or self.interpreter_executor): - warnings.warn( - 'Neither plugin nor interpreter executor is initialized. ' - 'An exception will be thrown when the agent call a tool.' - ) - self.finish_condition = finish_condition - self.max_turn = max_turn - super().__init__(**kwargs) - - def forward(self, message: AgentMessage, session_id=0, **kwargs): - for _ in range(self.max_turn): - message = self.agent(message, session_id=session_id, **kwargs) - assert isinstance(message.formatted, dict) - if self.finish_condition(message): - return message - if message.formatted['tool_type']: - tool_type = message.formatted["tool_type"] - executor = getattr(self, f'{tool_type}_executor', None) - if not executor: - raise RuntimeError(f'No available {tool_type} executor') - message = executor(message, session_id=session_id) - return message - - def get_steps(self, session_id=0): - steps, tool_type = [], None - for msg in self.agent.memory.get_memory(session_id): - if msg.sender == self.agent.name: - steps.append(dict(role='thought', content=msg.formatted['thought'])) - if msg.formatted['tool_type']: - tool_type = msg.formatted['tool_type'] - steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type)) - elif msg.sender != 'user': - feedback = dict(role='environment', content=msg.content) - if tool_type: - feedback['name'] = tool_type - steps.append(feedback) - return steps - - -class MathCoder(AgentForInternLM): - - def __init__( - self, - llm: Union[BaseLLM, Dict], - interpreter: dict = dict(type=IPythonInteractive, timeout=20, max_out_len=8192), - template: Union[str, dict, List[dict]] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict( - type=InterpreterParser, - template=( - 'Integrate step-by-step reasoning and Python code to solve math problems ' - 'using the following guidelines:\n' - '- Analyze the question and write jupyter code to solve the problem;\n' - r"- Present the final result in LaTeX using a '\boxed{{}}' without any " - 'units. \n' - ), - ), - aggregator: Dict = dict(type=InternLMToolAggregator), - action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, - max_turn: int = 6, - **kwargs, - ): - kwargs.pop('plugins', None) - super().__init__( - llm=llm, - interpreter=interpreter, - template=template, - memory=memory, - output_format=output_format, - aggregator=aggregator, - action_hooks=action_hooks, - finish_condition=finish_condition, - max_turn=max_turn, - **kwargs, - ) - - -class AsyncAgentForInternLM(AsyncAgent): - - _INTERNAL_AGENT_CLS = AsyncAgent - - def __init__( - self, - llm: Union[BaseLLM, Dict], - plugins: Union[dict, List[dict]] = None, - interpreter: dict = None, - template: Union[str, dict, List[dict]] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict( - type=MixedToolParser, - template=META_CN, - parsers=[ - dict(type=PluginParser, template=PLUGIN_CN), - dict(type=InterpreterParser, template=INTERPRETER_CN), - ], - ), - aggregator: Dict = dict(type=InternLMToolAggregator), - action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, - max_turn: int = 4, - **kwargs, - ): - self.agent = self._INTERNAL_AGENT_CLS( - llm=llm, - template=template, - output_format=output_format, - memory=memory, - aggregator=aggregator, - hooks=kwargs.pop('hooks', None), - ) - self.plugin_executor = plugins and AsyncActionExecutor(plugins, hooks=action_hooks) - self.interpreter_executor = interpreter and AsyncActionExecutor(interpreter, hooks=action_hooks) - if not (self.plugin_executor or self.interpreter_executor): - warnings.warn( - 'Neither plugin nor interpreter executor is initialized. ' - 'An exception will be thrown when the agent call a tool.' - ) - self.finish_condition = finish_condition - self.max_turn = max_turn - super().__init__(**kwargs) - - async def forward(self, message: AgentMessage, session_id=0, **kwargs): - for _ in range(self.max_turn): - message = await self.agent(message, session_id=session_id, **kwargs) - assert isinstance(message.formatted, dict) - if self.finish_condition(message): - return message - if message.formatted['tool_type']: - tool_type = message.formatted["tool_type"] - executor = getattr(self, f'{tool_type}_executor', None) - if not executor: - raise RuntimeError(f'No available {tool_type} executor') - message = await executor(message, session_id=session_id) - return message - - def get_steps(self, session_id=0): - steps, tool_type = [], None - for msg in self.agent.memory.get_memory(session_id): - if msg.sender == self.agent.name: - steps.append(dict(role='thought', content=msg.formatted['thought'])) - if msg.formatted['tool_type']: - tool_type = msg.formatted['tool_type'] - steps.append(dict(role='tool', content=msg.formatted['action'], name=tool_type)) - elif msg.sender != 'user': - feedback = dict(role='environment', content=msg.content) - if tool_type: - feedback['name'] = tool_type - steps.append(feedback) - return steps - - -class AsyncMathCoder(AsyncAgentForInternLM): - - def __init__( - self, - llm: Union[BaseLLM, Dict], - interpreter: dict = dict(type=AsyncIPythonInterpreter), - template: Union[str, dict, List[dict]] = None, - memory: Dict = dict(type=Memory), - output_format: Dict = dict( - type=InterpreterParser, - template=( - 'Integrate step-by-step reasoning and Python code to solve math problems ' - 'using the following guidelines:\n' - '- Analyze the question and write jupyter code to solve the problem;\n' - r"- Present the final result in LaTeX using a '\boxed{{}}' without any " - 'units. \n' - ), - ), - aggregator: Dict = dict(type=InternLMToolAggregator), - action_hooks: List = [dict(type=InternLMActionProcessor)], - finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL, - max_turn: int = 6, - **kwargs, - ): - kwargs.pop('plugins', None) - super().__init__( - llm=llm, - interpreter=interpreter, - template=template, - memory=memory, - output_format=output_format, - aggregator=aggregator, - action_hooks=action_hooks, - finish_condition=finish_condition, - max_turn=max_turn, - **kwargs, - ) - - async def forward(self, message: AgentMessage, session_id=0, **kwargs): - try: - return await super().forward(message, session_id, **kwargs) - finally: - interpreter = next(iter(self.interpreter_executor.actions.values())) - if interpreter.name == 'AsyncIPythonInterpreter': - await interpreter.close_session(session_id) diff --git a/lagent/apps/app.py b/lagent/apps/app.py new file mode 100644 index 00000000..15556272 --- /dev/null +++ b/lagent/apps/app.py @@ -0,0 +1,73 @@ +import asyncio +import logging +from typing import List + +from lagent.apps.bus import MessageBus +from lagent.apps.dispatcher import InterclawDispatcher +from lagent.apps.session import SessionManager + +logger = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, # 允许打印 INFO 及以上级别的日志 (如果你想看 debug 就改成 logging.DEBUG) + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +class InterclawApp: + """ + Interclaw OS 的核心应用容器 (Inversion of Control Container) + 负责管理总线、分发器,以及动态注册的各种通道和后台服务。 + """ + def __init__(self, agent_engine): + """ + :param agent_engine: 核心的 Agent 引擎(纯无状态,如 lagent 实例或 FakeLagent) + """ + self.bus = MessageBus() + self.session_mgr = SessionManager() + self.dispatcher = InterclawDispatcher(self.bus, self.session_mgr, agent_engine=agent_engine) + self._channels = [] + self._services = [] + + def register_channel(self, channel_class, **kwargs): + """ + 动态注册通信通道 (如 CLI, 飞书, 微信) + 自动将 bus 注入给 channel。 + """ + channel_instance = channel_class(self.bus, **kwargs) + self._channels.append(channel_instance) + logger.info(f"Registered channel: {channel_instance.__class__.__name__}") + return self + + def register_service(self, service_class, **kwargs): + """ + 动态注册后台服务 (如 Cron, Heartbeat) + 自动将 bus 注入给 service。 + """ + service_instance = service_class(self.bus, **kwargs) + self._services.append(service_instance) + logger.info(f"Registered service: {service_instance.__class__.__name__}") + return self + + async def _run_all(self): + tasks = [] + # 1. 启动核心 Dispatcher (负责消费 inbound 数据并发送给 LLM,将结果推入 outbound) + tasks.append(asyncio.create_task(self.dispatcher.start())) + + # 2. 启动所有的 Channels (负责独立监听外部输入,并阻塞等待 outbound 返回结果) + for channel in self._channels: + tasks.append(asyncio.create_task(channel.start())) + + # 3. 启动后台服务 (定时任务等) + for service in self._services: + tasks.append(asyncio.create_task(service.start())) + + # 启动并等待所有的事件循环 + await asyncio.gather(*tasks) + + def run(self): + """启动整个 Agent OS 框架的入口""" + logger.info("Starting Interclaw OS...") + try: + asyncio.run(self._run_all()) + except KeyboardInterrupt: + logger.info("Interclaw OS interrupted by user. Shutting down...") + # 这里可以补充优雅退出的逻辑 diff --git a/lagent/apps/bus.py b/lagent/apps/bus.py new file mode 100644 index 00000000..74288dbd --- /dev/null +++ b/lagent/apps/bus.py @@ -0,0 +1,52 @@ +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Dict, Any, Callable + +logger = logging.getLogger("lagent.interclaw.bus") + +@dataclass +class InboundEvent: + channel: str + session_id: str + content: str + metadata: Dict[str, Any] = field(default_factory=dict) + +@dataclass +class OutboundEvent: + channel: str + session_id: str + content: str + metadata: Dict[str, Any] = field(default_factory=dict) + +class MessageBus: + """ + 基于 Pub/Sub 模式的高可用内存总线。 + 支持一个统一的 Inbound 队列给 Dispatcher,以及多个按 Channel 路由的 Outbound 队列。 + """ + def __init__(self): + self.inbound_queue: asyncio.Queue[InboundEvent] = asyncio.Queue() + self._outbound_queues: Dict[str, asyncio.Queue[OutboundEvent]] = {} + + def subscribe_outbound(self, channel: str) -> asyncio.Queue[OutboundEvent]: + """供各种 Channel (如 cli, whatsapp) 注册并监听自己的发件箱""" + if channel not in self._outbound_queues: + self._outbound_queues[channel] = asyncio.Queue() + logger.debug(f"Channel [{channel}] subscribed to outbound bus.") + return self._outbound_queues[channel] + + async def publish_inbound(self, event: InboundEvent) -> None: + """供所有外部刺激 (通道、定时器) 推送消息给大脑""" + await self.inbound_queue.put(event) + + async def consume_inbound(self) -> InboundEvent: + """供 Dispatcher 持续消耗任务""" + return await self.inbound_queue.get() + + async def publish_outbound(self, event: OutboundEvent) -> None: + """供 Dispatcher 推送大脑的回复到指定通道""" + queue = self._outbound_queues.get(event.channel) + if queue: + await queue.put(event) + else: + logger.warning(f"No subscriber for channel: {event.channel}, message dropped.") diff --git a/lagent/apps/channels/__init__.py b/lagent/apps/channels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lagent/apps/channels/base.py b/lagent/apps/channels/base.py new file mode 100644 index 00000000..b91dd504 --- /dev/null +++ b/lagent/apps/channels/base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from lagent.apps.bus import MessageBus + +class BaseChannel(ABC): + """ + 通讯频造的基础抽象类 (例如 CLI, WhatsApp, WeChat 等) + 每一个 Channel 都有责任生产输入,并消费属于自己的输出。 + """ + def __init__(self, bus: MessageBus, channel_name: str): + self.bus = bus + self.channel_name = channel_name + self.outbound_queue = bus.subscribe_outbound(channel_name) + + @abstractmethod + async def start(self): + """需要通过 asyncio.gather 并发运行读写逻辑""" + pass diff --git a/lagent/apps/channels/cli.py b/lagent/apps/channels/cli.py new file mode 100644 index 00000000..b393f302 --- /dev/null +++ b/lagent/apps/channels/cli.py @@ -0,0 +1,63 @@ +import asyncio +import sys +import logging +from .base import BaseChannel +from lagent.apps.bus import MessageBus, InboundEvent + +logger = logging.getLogger("lagent.interclaw.channels.cli") + +class CLIChannel(BaseChannel): + """ + 终端互动频道:真实可用的互动命令行端点。 + 不仅负责把你的键盘敲击传给 Agent,也负责把 Agent 回收的话打印回控制台。 + """ + def __init__(self, bus: MessageBus): + super().__init__(bus, channel_name="cli") + + async def _read_loop(self): + """Producer: 将用户的回车敲击送如总线""" + loop = asyncio.get_running_loop() + while True: + try: + # 使用 run_in_executor 防止 input() 阻塞 asyncio 事件循环 + line = await loop.run_in_executor(None, input, "\n[You] ➤ ") + text = line.strip() + if text.lower() in ('exit', 'quit', 'q'): + logger.info("User requested exit.") + # 优雅地通知主程序终止 + import os, signal + os.kill(os.getpid(), signal.SIGINT) + break + + if text: + event = InboundEvent( + channel=self.channel_name, + session_id="cli:local_master", + content=text + ) + await self.bus.publish_inbound(event) + except EOFError: + break + except asyncio.CancelledError: + break + + async def _write_loop(self): + """Consumer: 监听分配给 CLI 的 OutboundEvent 并打印""" + while True: + try: + event = await self.outbound_queue.get() + print(f"\n[Agent] ⚡ {event.content}\n", end="", flush=True) + except asyncio.CancelledError: + break + + async def start(self): + logger.info("CLI Channel activated. You can start typing.") + tasks = [ + asyncio.create_task(self._read_loop(), name="CLI_Reader"), + asyncio.create_task(self._write_loop(), name="CLI_Writer"), + ] + try: + await asyncio.gather(*tasks) + except asyncio.CancelledError: + for t in tasks: + t.cancel() diff --git a/lagent/apps/channels/feishu.py b/lagent/apps/channels/feishu.py new file mode 100644 index 00000000..ad9b9eb7 --- /dev/null +++ b/lagent/apps/channels/feishu.py @@ -0,0 +1,190 @@ +import asyncio +import json +import logging +import threading +from typing import Any + +from lagent.apps.bus import MessageBus, InboundEvent +from lagent.apps.channels.base import BaseChannel + +logger = logging.getLogger("lagent.interclaw.channels.feishu") + +class FeishuChannel(BaseChannel): + """ + 基于 lark-oapi 实现的飞书全双工通道。 + 内部维护一个独立的 WebSocket 线程监听飞书事件,用 async loop 发送回调回复。 + """ + def __init__(self, bus: MessageBus, app_id: str, app_secret: str, encrypt_key: str = "", verification_token: str = ""): + super().__init__(bus, channel_name="feishu") + self.app_id = app_id + self.app_secret = app_secret + self.encrypt_key = encrypt_key + self.verification_token = verification_token + + self._client: Any = None + self._ws_client: Any = None + self._ws_thread: threading.Thread | None = None + self._loop: asyncio.AbstractEventLoop | None = None + + async def _producer_loop(self): + """Producer: 将飞书的回调变成 InboundEvent (使用 Lark WS 客户端挂载后台线程)""" + try: + import lark_oapi as lark + except ImportError: + logger.error("lark-oapi sdk is missing! Please pip install lark-oapi") + return + + self._loop = asyncio.get_running_loop() + + # 1. 构造主动发信 Client + self._client = lark.Client.builder() \ + .app_id(self.app_id) \ + .app_secret(self.app_secret) \ + .log_level(lark.LogLevel.INFO) \ + .build() + + # 2. 构造事件处理器 (收信器) + def _on_message_callback(data: Any) -> None: + # 这个回调是在飞书内置的多线程里触发的,所以必须丢回 asyncio 中枢 + if self._loop and self._loop.is_running(): + asyncio.run_coroutine_threadsafe(self._handle_feishu_inbound(data), self._loop) + + event_handler = lark.EventDispatcherHandler.builder( + self.encrypt_key, + self.verification_token, + ).register_p2_im_message_receive_v1( + _on_message_callback + ).build() + + # 3. 构造长连接 Web Socket 客户端 + self._ws_client = lark.ws.Client( + self.app_id, + self.app_secret, + event_handler=event_handler, + log_level=lark.LogLevel.INFO + ) + + def run_ws(): + logger.info("Feishu WebSocket is starting via Lark-OAPI...") + + # 1. 为当前的后台线程创建一个独立的、全新的事件循环 + ws_loop = asyncio.new_event_loop() + asyncio.set_event_loop(ws_loop) + + # 2. 核心修复:强制覆盖 lark_oapi.ws.client 模块内部缓存的 loop 变量 + # 防止其去调度主线程的 loop 而引发 "This event loop is already running" 错误 + import lark_oapi.ws.client + lark_oapi.ws.client.loop = ws_loop + + try: + self._ws_client.start() + except Exception as e: + logger.error(f"Feishu WebSocket failed: {e}") + + # 挂载到一个守护线程执行 + self._ws_thread = threading.Thread(target=run_ws, daemon=True) + self._ws_thread.start() + + # 主协程挂起,维持存活 + while True: + await asyncio.sleep(3600) + + async def _handle_feishu_inbound(self, data: Any): + """将飞书原生的数据格式转化成 Agent OS 的内部格式并丢入总线""" + try: + event = data.event + message = event.message + sender = event.sender + + if sender.sender_type == "bot": + return # 不自己跟自己聊天 + + sender_id = sender.sender_id.open_id if sender.sender_id else "unknown" + chat_id = message.chat_id + msg_type = message.message_type + + # 构建会话 ID: 例如 feishu:group_xabc123 或者 feishu:user_888 + session_id = f"feishu:{chat_id}" + + # 解析内容 (简化版文本提取,实际你需要保留 nanobot 里的全面解析如图片、Card等) + content = "" + if msg_type == "text": + content_json = json.loads(message.content) if message.content else {} + content = content_json.get("text", "") + + if content.strip(): + logger.info(f"[Feishu Inbound] Receive MSG -> session:{session_id}") + await self.bus.publish_inbound(InboundEvent( + channel=self.channel_name, + session_id=session_id, + content=content, + # 可以通过 Metadata 保存原始的飞书数据对象以备高级工具使用 + metadata={"feishu_chat_id": chat_id, "feishu_receive_id_type": "chat_id"} + )) + except Exception as e: + logger.error(f"Failed to process Feishu inbound message: {e}") + + def _sync_send_text(self, receive_id: str, text: str): + from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody + request = CreateMessageRequest.builder() \ + .receive_id_type("chat_id") \ + .request_body( + CreateMessageRequestBody.builder() + .receive_id(receive_id) + .msg_type("text") + .content(json.dumps({"text": text}, ensure_ascii=False)) + .build() + ).build() + self._client.im.v1.message.create(request) + + async def _consumer_loop(self): + """Consumer: 监听由 Dispatcher 分发给专属 Feishu Channel 的回复""" + while True: + try: + event = await self.outbound_queue.get() + logger.info(f"[Feishu Outbound] Send MSG -> {event.session_id}") + + # 从 session_id 中解析出飞书认的 chat_id (feishu:chat_xxxxx) + if ":" in event.session_id: + _, chat_id = event.session_id.split(":", 1) + else: + chat_id = event.session_id + + # 非阻塞投递 + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._sync_send_text, chat_id, event.content) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Failed to send outbound to Feishu: {e}") + + async def start(self): + logger.info("Feishu Channel initialized. Connecting pipelines...") + tasks = [ + asyncio.create_task(self._producer_loop(), name="Feishu_Producer"), + asyncio.create_task(self._consumer_loop(), name="Feishu_Consumer"), + ] + try: + await asyncio.gather(*tasks) + except asyncio.CancelledError: + for t in tasks: + t.cancel() + +if __name__ == "__main__": + # 这个 main 只是为了本地测试飞书通道的收发功能,实际使用时不需要运行这个文件 + import os + from lagent.apps.bus import MessageBus + logging.basicConfig( + level=logging.INFO, # 允许打印 INFO 及以上级别的日志 (如果你想看 debug 就改成 logging.DEBUG) + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + bus = MessageBus() + feishu_channel = FeishuChannel( + bus=bus, + app_id=os.getenv("FEISHU_APP_ID", "cli_a92538846ff99cd2"), + app_secret=os.getenv("FEISHU_APP_SECRET", "EpDAW3TCnqUpyUltr1Q5WfR27j0vX13F"), + encrypt_key=os.getenv("FEISHU_ENCRYPT_KEY", ""), + verification_token=os.getenv("FEISHU_VERIFICATION_TOKEN", "") + ) + + asyncio.run(feishu_channel.start()) \ No newline at end of file diff --git a/lagent/apps/dispatcher.py b/lagent/apps/dispatcher.py new file mode 100644 index 00000000..3d842b38 --- /dev/null +++ b/lagent/apps/dispatcher.py @@ -0,0 +1,56 @@ +import asyncio +import logging +from .bus import MessageBus, OutboundEvent +from .session import SessionManager + +logger = logging.getLogger("lagent.interclaw.dispatcher") + +class InterclawDispatcher: + """ + 核心调度网关。 + 将异步事件总线 (MessageBus) 与同步/异步的大模型计算 (Lagent) 完美缝合。 + """ + def __init__(self, bus: MessageBus, session_manager: SessionManager, agent_engine=None): + self.bus = bus + self.session_manager = session_manager + self._running = False + self.agent_engine = agent_engine + + async def start(self) -> None: + self._running = True + logger.info("Interclaw Dispatcher initialized. Ready to orchestrate the Agents.") + + while self._running: + try: + event = await self.bus.consume_inbound() + logger.info(f"==> Inbound Hit | Channel:[{event.channel}] Session:[{event.session_id}]") + + # 1. 挂载持久化记忆 + state = self.session_manager.load_state(event.session_id) + + # 2. 从“休眠仓”唤醒 Agent (纯粹无状态) + agent = self.agent_engine + # self.agent_engine.load_state_dict(state) + + # 3. 运行思考 + reply = await agent(event.content) + + # 4. 把变更后的神经节点冷冻回冰柜 + # self.session_manager.save_state(event.session_id, agent.state_dict()) + + # 5. 返还结果给各自负责分发的人 + if reply: + await self.bus.publish_outbound(OutboundEvent( + channel=event.channel, + session_id=event.session_id, + content=reply.content + )) + + except asyncio.CancelledError: + logger.info("Dispatcher task cancelled properly.") + break + except Exception as e: + logger.error(f"Dispatcher internal error: {e}", exc_info=True) + + def stop(self): + self._running = False diff --git a/lagent/apps/main.py b/lagent/apps/main.py new file mode 100644 index 00000000..7e0031e6 --- /dev/null +++ b/lagent/apps/main.py @@ -0,0 +1,84 @@ +import asyncio +import logging +import signal +import sys +import os + +# 将 Lagent 项目根目录插入 SYS PATH,确保支持正确引用 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +from lagent.apps.app import InterclawApp +from lagent.apps.channels.cli import CLIChannel +from lagent.apps.channels.feishu import FeishuChannel +from lagent.services.cron import CronService +from lagent.services.heartbeat import HeartbeatService + + +# 配置工程化日志格式(生产级别格式) +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-7s | %(name)-20s | %(message)s", + datefmt="%H:%M:%S" +) +logger = logging.getLogger("interclaw.main") + + +def main(): + logger.info("Initializing Interclaw Agent OS...") + + from pathlib import Path + from lagent.agents.aggregator.context import InternClawContextBuilder + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + from lagent.actions.filesystem import ReadFileAction, WriteFileAction, EditFileAction + from lagent.actions.shell import ShellAction + from lagent.agents.internclaw_agent import InternClawAgent, AsyncPolicyAgent, AsyncEnvAgent, get_tool_prompt + from lagent.hooks.logger import MessageLogger + model_name = "/mnt/shared-storage-user/puyudelivery/user/puyudilivery/ckpts/xtuner_saved_model/interns1_1_mini_official/interns1_1_mini_sft_based_cpt_bs512_epoch1_maxlr3e-5_minlr1e-6_max16k-hf/20260207101512/hf-4374" + # model_name = "gpt-4o-2024-08-06" + api_base = "http://10.102.218.28:23333/v1/" + # api_base = f"http://35.220.164.252:3888/v1beta/models/{model_name}:generateContent" + api_key = "" + extra_body = {'enable_thinking': True, 'spaces_between_special_tokens': False} + proxies = None + + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=api_base, api_key=api_key, proxy=proxies), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=5, + sleep_interval=5, + extra_body=extra_body, + ) + workspace = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent/workspace" + actions = [ReadFileAction(workspace=workspace), WriteFileAction(workspace=workspace), EditFileAction(workspace=workspace), ShellAction(working_dir=workspace)] + aggregator = InternClawContextBuilder(Path(workspace), tools=get_tool_prompt(actions)) + policy = AsyncPolicyAgent( + llm=model, + aggregator=aggregator, + hooks=[MessageLogger()]) + env = AsyncEnvAgent(actions=actions) + agent = InternClawAgent(policy_agent=policy, env_agent=env) + + # 2. 创立 App 容器 + app = InterclawApp(agent_engine=agent) + + # 3. 动态注册外部通道 (External Channels) + # app自动负责把总线 bus 依赖注入给这些组件 + app.register_channel(FeishuChannel, + app_id=os.getenv("FEISHU_APP_ID", "cli_a92538846ff99cd2"), + app_secret=os.getenv("FEISHU_APP_SECRET", "EpDAW3TCnqUpyUltr1Q5WfR27j0vX13F"), + encrypt_key=os.getenv("FEISHU_ENCRYPT_KEY", ""), + verification_token=os.getenv("FEISHU_VERIFICATION_TOKEN", "")) + + # 如果有飞书 credentials,可动态注册飞书 channel + # app.register_channel(FeishuChannel, app_id="TODO_APP_ID", app_secret="TODO_APP_SECRET") + + # 4. 动态注册后台服务 (Services & Daemons) + app.register_service(CronService) + app.register_service(HeartbeatService, pulse_seconds=10) + + # 5. 启动总控制器 + app.run() + +if __name__ == "__main__": + main() diff --git a/lagent/apps/session.py b/lagent/apps/session.py new file mode 100644 index 00000000..2114d57e --- /dev/null +++ b/lagent/apps/session.py @@ -0,0 +1,49 @@ +import json +import logging +from pathlib import Path +from typing import Dict, Any + +logger = logging.getLogger("lagent.interclaw.session") + +class SessionManager: + """ + 真实落地的会话管理器,提供物理隔离与硬盘持久化。 + 参考 nanobot 的 session 实现,使用 JSON 保存。 + """ + def __init__(self, data_dir: str = ".interclaw_data/sessions"): + self.data_dir = Path(data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + self._cache: Dict[str, Dict[str, Any]] = {} + + def _get_file_path(self, session_id: str) -> Path: + # 清理可能引起路径风险的符号 + safe_name = session_id.replace(":", "_").replace("/", "_") + return self.data_dir / f"{safe_name}.json" + + def load_state(self, session_id: str) -> Dict[str, Any]: + """加载时优先查缓存,再读硬盘""" + if session_id in self._cache: + return self._cache[session_id] + + file_path = self._get_file_path(session_id) + if file_path.exists(): + try: + with open(file_path, "r", encoding="utf-8") as f: + state = json.load(f) + self._cache[session_id] = state + return state + except json.JSONDecodeError as e: + logger.error(f"Failed to decode session file {file_path}: {e}") + return {} + return {} + + def save_state(self, session_id: str, state: Dict[str, Any]) -> None: + """更新缓存并持久化到硬盘""" + self._cache[session_id] = state + file_path = self._get_file_path(session_id) + try: + with open(file_path, "w", encoding="utf-8") as f: + json.dump(state, f, ensure_ascii=False, indent=2) + logger.debug(f"Session [{session_id}] persisted to {file_path}") + except Exception as e: + logger.error(f"Failed to save session state: {e}") diff --git a/lagent/distributed/__init__.py b/lagent/distributed/__init__.py deleted file mode 100644 index 98ad5210..00000000 --- a/lagent/distributed/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .http_serve import AgentAPIServer, AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer -from .ray_serve import AgentRayActor, AsyncAgentRayActor - -__all__ = [ - 'AsyncAgentRayActor', 'AgentRayActor', 'HTTPAgentServer', - 'HTTPAgentClient', 'AsyncHTTPAgentServer', 'AsyncHTTPAgentClient', - 'AgentAPIServer' -] diff --git a/lagent/distributed/http_serve/api_server.py b/lagent/distributed/http_serve/api_server.py deleted file mode 100644 index e78dcacc..00000000 --- a/lagent/distributed/http_serve/api_server.py +++ /dev/null @@ -1,131 +0,0 @@ -import json -import os -import subprocess -import sys -import time -import threading - -import aiohttp -import requests - -from lagent.schema import AgentMessage - - -class HTTPAgentClient: - - def __init__(self, host='127.0.0.1', port=8090, timeout=None): - self.host = host - self.port = port - self.timeout = timeout - - @property - def is_alive(self): - try: - resp = requests.get( - f'http://{self.host}:{self.port}/health_check', - timeout=self.timeout) - return resp.status_code == 200 - except: - return False - - def __call__(self, *message, session_id: int = 0, **kwargs): - response = requests.post( - f'http://{self.host}:{self.port}/chat_completion', - json={ - 'message': [ - m if isinstance(m, str) else m.model_dump() - for m in message - ], - 'session_id': session_id, - **kwargs, - }, - headers={'Content-Type': 'application/json'}, - timeout=self.timeout) - resp = response.json() - if response.status_code != 200: - return resp - return AgentMessage.model_validate(resp) - - def state_dict(self, session_id: int = 0): - resp = requests.get( - f'http://{self.host}:{self.port}/memory/{session_id}', - timeout=self.timeout) - return resp.json() - - -class HTTPAgentServer(HTTPAgentClient): - - def __init__(self, gpu_id, config, host='127.0.0.1', port=8090): - super().__init__(host, port) - self.gpu_id = gpu_id - self.config = config - self.start_server() - - def start_server(self): - # set CUDA_VISIBLE_DEVICES in subprocess - env = os.environ.copy() - env['CUDA_VISIBLE_DEVICES'] = self.gpu_id - cmds = [ - sys.executable, 'lagent/distributed/http_serve/app.py', '--host', - self.host, '--port', - str(self.port), '--config', - json.dumps(self.config) - ] - self.process = subprocess.Popen( - cmds, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True) - - self.service_started = False - - def log_output(stream): - if stream is not None: - for line in iter(stream.readline, ''): - print(line, end='') - if 'Uvicorn running on' in line: - self.service_started = True - - # Start log output thread - threading.Thread(target=log_output, args=(self.process.stdout,), daemon=True).start() - threading.Thread(target=log_output, args=(self.process.stderr,), daemon=True).start() - - # Waiting for the service to start - while not self.service_started: - time.sleep(0.1) - - def shutdown(self): - self.process.terminate() - self.process.wait() - - -class AsyncHTTPAgentMixin: - - async def __call__(self, *message, session_id: int = 0, **kwargs): - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(self.timeout)) as session: - async with session.post( - f'http://{self.host}:{self.port}/chat_completion', - json={ - 'message': [ - m if isinstance(m, str) else m.model_dump() - for m in message - ], - 'session_id': session_id, - **kwargs, - }, - headers={'Content-Type': 'application/json'}, - ) as response: - resp = await response.json() - if response.status != 200: - return resp - return AgentMessage.model_validate(resp) - - -class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient): - pass - - -class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer): - pass diff --git a/lagent/distributed/http_serve/app.py b/lagent/distributed/http_serve/app.py deleted file mode 100644 index 7d05660a..00000000 --- a/lagent/distributed/http_serve/app.py +++ /dev/null @@ -1,96 +0,0 @@ -import argparse -import json -import logging -import time - -import uvicorn -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -from fastapi.requests import Request - -from lagent.schema import AgentMessage -from lagent.utils import load_class_from_string - - -class AgentAPIServer: - - def __init__(self, - config: dict, - host: str = '127.0.0.1', - port: int = 8090): - self.app = FastAPI(docs_url='/') - self.app.add_middleware( - CORSMiddleware, - allow_origins=['*'], - allow_credentials=True, - allow_methods=['*'], - allow_headers=['*'], - ) - cls_name = config.pop('type') - python_path = config.pop('python_path', None) - cls_name = load_class_from_string(cls_name, python_path) if isinstance( - cls_name, str) else cls_name - self.agent = cls_name(**config) - self.setup_routes() - self.run(host, port) - - def setup_routes(self): - - def heartbeat(): - return {'status': 'success', 'timestamp': time.time()} - - async def process_message(request: Request): - try: - body = await request.json() - message = [ - m if isinstance(m, str) else AgentMessage.model_validate(m) - for m in body.pop('message') - ] - result = await self.agent(*message, **body) - return result - except Exception as e: - logging.error(f'Error processing message: {str(e)}') - raise HTTPException( - status_code=500, detail='Internal Server Error') - - def get_memory(session_id: int = 0): - try: - result = self.agent.state_dict(session_id) - return result - except KeyError: - raise HTTPException( - status_code=404, detail="Session ID not found") - except Exception as e: - logging.error(f'Error processing message: {str(e)}') - raise HTTPException( - status_code=500, detail='Internal Server Error') - - self.app.add_api_route('/health_check', heartbeat, methods=['GET']) - self.app.add_api_route( - '/chat_completion', process_message, methods=['POST']) - self.app.add_api_route( - '/memory/{session_id}', get_memory, methods=['GET']) - - def run(self, host='127.0.0.1', port=8090): - logging.info(f'Starting server at {host}:{port}') - uvicorn.run(self.app, host=host, port=port) - - -def parse_args(): - parser = argparse.ArgumentParser(description='Async Agent API Server') - parser.add_argument('--host', type=str, default='127.0.0.1') - parser.add_argument('--port', type=int, default=8090) - parser.add_argument( - '--config', - type=json.loads, - required=True, - help='JSON configuration for the agent') - args = parser.parse_args() - - return args - - -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - args = parse_args() - AgentAPIServer(args.config, host=args.host, port=args.port) diff --git a/lagent/distributed/ray_serve/__init__.py b/lagent/distributed/ray_serve/__init__.py deleted file mode 100644 index 8739b2a1..00000000 --- a/lagent/distributed/ray_serve/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .ray_warpper import AgentRayActor, AsyncAgentRayActor - -__all__ = ['AsyncAgentRayActor', 'AgentRayActor'] diff --git a/lagent/distributed/ray_serve/ray_warpper.py b/lagent/distributed/ray_serve/ray_warpper.py deleted file mode 100644 index d401c944..00000000 --- a/lagent/distributed/ray_serve/ray_warpper.py +++ /dev/null @@ -1,48 +0,0 @@ -import importlib -import sys -from typing import Dict - -import ray - -from lagent.schema import AgentMessage -from lagent.utils import load_class_from_string - - -class AsyncAgentRayActor: - - def __init__( - self, - config: Dict, - num_gpus: int, - ): - cls_name = config.pop('type') - python_path = config.pop('python_path', None) - cls_name = load_class_from_string(cls_name, python_path) if isinstance( - cls_name, str) else cls_name - AsyncAgentActor = ray.remote(num_gpus=num_gpus)(cls_name) - self.agent_actor = AsyncAgentActor.remote(**config) - - async def __call__(self, *message: AgentMessage, session_id=0, **kwargs): - response = await self.agent_actor.__call__.remote( - *message, session_id=session_id, **kwargs) - return response - - -class AgentRayActor: - - def __init__( - self, - config: Dict, - num_gpus: int, - ): - cls_name = config.pop('type') - python_path = config.pop('python_path', None) - cls_name = load_class_from_string(cls_name, python_path) if isinstance( - cls_name, str) else cls_name - AgentActor = ray.remote(num_gpus=num_gpus)(cls_name) - self.agent_actor = AgentActor.remote(**config) - - def __call__(self, *message: AgentMessage, session_id=0, **kwargs): - response = self.agent_actor.__call__.remote( - *message, session_id=session_id, **kwargs) - return ray.get(response) diff --git a/lagent/hooks/__init__.py b/lagent/hooks/__init__.py index 0e935281..0b27623d 100644 --- a/lagent/hooks/__init__.py +++ b/lagent/hooks/__init__.py @@ -1,8 +1,7 @@ -from .action_preprocessor import ActionPreprocessor, InternLMActionProcessor from .hook import Hook, RemovableHandle from .logger import MessageLogger __all__ = [ - 'Hook', 'RemovableHandle', 'ActionPreprocessor', 'InternLMActionProcessor', + 'Hook', 'RemovableHandle', 'MessageLogger' ] diff --git a/lagent/hooks/action_preprocessor.py b/lagent/hooks/action_preprocessor.py deleted file mode 100644 index 51083aa8..00000000 --- a/lagent/hooks/action_preprocessor.py +++ /dev/null @@ -1,62 +0,0 @@ -from copy import deepcopy - -from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall -from .hook import Hook - - -class ActionPreprocessor(Hook): - """The ActionPreprocessor is a hook that preprocesses the action message - and postprocesses the action return message. - - """ - - def before_action(self, executor, message, session_id): - assert isinstance(message.formatted, FunctionCall) or ( - isinstance(message.formatted, dict) and 'name' in message.content - and 'parameters' in message.formatted) or ( - 'action' in message.formatted - and 'parameters' in message.formatted['action'] - and 'name' in message.formatted['action']) - if isinstance(message.formatted, dict): - name = message.formatted.get('name', - message.formatted['action']['name']) - parameters = message.formatted.get( - 'parameters', message.formatted['action']['parameters']) - else: - name = message.formatted.name - parameters = message.formatted.parameters - message.content = dict(name=name, parameters=parameters) - return message - - def after_action(self, executor, message, session_id): - action_return = message.content - if isinstance(action_return, ActionReturn): - if action_return.state == ActionStatusCode.SUCCESS: - response = action_return.format_result() - else: - response = action_return.errmsg - else: - response = action_return - message.content = response - return message - - -class InternLMActionProcessor(ActionPreprocessor): - - def __init__(self, code_parameter: str = 'command'): - self.code_parameter = code_parameter - - def before_action(self, executor, message, session_id): - message = deepcopy(message) - assert isinstance(message.formatted, dict) and set( - message.formatted).issuperset( - {'tool_type', 'thought', 'action', 'status'}) - if isinstance(message.formatted['action'], str): - # encapsulate code interpreter arguments - action_name = next(iter(executor.actions)) - parameters = {self.code_parameter: message.formatted['action']} - if action_name in ['AsyncIPythonInterpreter']: - parameters['session_id'] = session_id - message.formatted['action'] = dict( - name=action_name, parameters=parameters) - return super().before_action(executor, message, session_id) diff --git a/lagent/hooks/hook.py b/lagent/hooks/hook.py index c3b3e303..dc906862 100644 --- a/lagent/hooks/hook.py +++ b/lagent/hooks/hook.py @@ -10,7 +10,6 @@ def before_agent( self, agent, message: Tuple[AgentMessage], - session_id: int, ): pass @@ -18,7 +17,6 @@ def after_agent( self, agent, message: AgentMessage, - session_id: int, ): pass @@ -26,7 +24,6 @@ def before_action( self, executor, message: AgentMessage, - session_id: int, ): pass @@ -34,7 +31,6 @@ def after_action( self, executor, message: AgentMessage, - session_id: int, ): pass diff --git a/lagent/hooks/logger.py b/lagent/hooks/logger.py index ccdb8012..33376e17 100644 --- a/lagent/hooks/logger.py +++ b/lagent/hooks/logger.py @@ -13,20 +13,28 @@ def __init__(self, name: str = 'lagent', add_file_handler: bool = False): ) self.sender2color = {} - def before_agent(self, agent, messages, session_id): + def before_agent(self, agent, messages): for message in messages: - self._process_message(message, session_id) + self._process_message(message) - def after_agent(self, agent, message, session_id): - self._process_message(message, session_id) + def after_agent(self, agent, message): + self._process_message(message) - def before_action(self, executor, message, session_id): - self._process_message(message, session_id) + def before_action(self, executor, message): + self._process_message(message) - def after_action(self, executor, message, session_id): - self._process_message(message, session_id) + def after_action(self, executor, message): + self._process_message(message) - def _process_message(self, message, session_id): + def _process_message(self, message): sender = message.sender color = self.sender2color.setdefault(sender, random.choice(list(COLORS))) - self.logger.info(colored(f'session id: {session_id}, message sender: {sender}\n' f'{message.content}', color)) + msg_str = f'message sender: {sender}' + if getattr(message, 'reasoning_content', None): + msg_str += f'\nReasoning:{message.reasoning_content}' + if getattr(message, 'content', None): + msg_str += f'\nContent:{message.content}' + if getattr(message, 'tool_calls', None): + msg_str += f'\nTool Calls:{message.tool_calls}' + + self.logger.info(colored(msg_str, color)) diff --git a/lagent/llms/__init__.py b/lagent/llms/__init__.py index 95679b15..3b64e4ab 100644 --- a/lagent/llms/__init__.py +++ b/lagent/llms/__init__.py @@ -12,7 +12,6 @@ ) from .meta_template import INTERNLM2_META from .openai import GPTAPI, AsyncGPTAPI -from .sensenova import SensenovaAPI from .vllm_wrapper import AsyncVllmModel, VllmModel __all__ = [ @@ -34,7 +33,6 @@ 'HFTransformerChat', 'VllmModel', 'AsyncVllmModel', - 'SensenovaAPI', 'AsyncClaudeAPI', 'ClaudeAPI', ] diff --git a/lagent/llms/anthropic_llm.py b/lagent/llms/anthropic_llm.py index 8aac28cf..426dfe41 100644 --- a/lagent/llms/anthropic_llm.py +++ b/lagent/llms/anthropic_llm.py @@ -1,12 +1,17 @@ import asyncio import json import os -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, TYPE_CHECKING + -import anthropic import httpcore import httpx -from anthropic import NOT_GIVEN +if TYPE_CHECKING: + from anthropic import NOT_GIVEN +try: + from anthropic import NOT_GIVEN +except ImportError: + NOT_GIVEN = object() from requests.exceptions import ProxyError from .base_api import AsyncBaseAPILLM, BaseAPILLM @@ -47,7 +52,7 @@ def __init__( repetition_penalty=repetition_penalty, stop_words=stop_words, ) - + import anthropic key = os.getenv('Claude_API_KEY') if key == 'ENV' else key if isinstance(key, str): @@ -67,7 +72,6 @@ def __init__( def chat( self, inputs: Union[List[dict], List[List[dict]]], - session_ids: Union[int, List[int]] = None, **gen_params, ) -> Union[str, List[str]]: """Generate responses given the contexts. @@ -267,7 +271,6 @@ def __init__( async def chat( self, inputs: Union[List[dict], List[List[dict]]], - session_ids: Union[int, List[int]] = None, **gen_params, ) -> Union[str, List[str]]: """Generate responses given the contexts. diff --git a/lagent/llms/model.py b/lagent/llms/model.py new file mode 100644 index 00000000..06829b10 --- /dev/null +++ b/lagent/llms/model.py @@ -0,0 +1,209 @@ +import asyncio +import json +import random +import traceback +from logging import getLogger +from typing import List, Union, Optional, Dict, TypedDict + +import aiohttp +from lagent.llms.openai import AsyncGPTAPI + +logger = getLogger(__name__) + +import httpx +from openai import APITimeoutError, AsyncOpenAI, NOT_GIVEN +# from pdp_ext.fc_inferencer import ModelConfig, SampleParameters +class SampleParameters(TypedDict): + temperature: float + top_p: float + top_k: int + +class ModelConfig(TypedDict): + model: str + base_url: str | List[str] + api_key: Optional[str] + + +class AsyncAPIClient(AsyncGPTAPI): + def __init__( + self, + model: ModelConfig, + sample_params: SampleParameters, + timeout: int = 600, + max_retry: int = 50, + sleep_interval: int = 5, + extra_body: Optional[dict] = None, + max_tool_response_length: Optional[int] = 4096, + max_tool_calls_per_turn: int = 5, + **kwargs, + ): + super().__init__(**kwargs) + http_client = httpx.AsyncClient(proxy=model.get('proxy'), timeout=timeout, trust_env=False) if model.get('proxy') else httpx.AsyncClient(timeout=timeout) + self.clients = [ + AsyncOpenAI(api_key=model["api_key"], base_url=url, http_client=http_client) + for url in (model['base_url'] if isinstance(model['base_url'], list) else [model['base_url']]) + ] + self.model_name = model["model"] + self.sample_params = sample_params + self.max_retry = max_retry + self.timeout = timeout + self.sleep_interval = sleep_interval + self.extra_body = extra_body + self.max_tool_response_length = max_tool_response_length + self.max_tool_calls_per_turn = max_tool_calls_per_turn + + async def chat(self, messages: List[dict], tools=None, **gen_params) -> str: + """Generate completion from a list of templates. + + Args: + messages (List[dict]): a list of prompt dictionaries + gen_params: additional generation configuration + + Returns: + str: The generated string. + """ + assert isinstance(messages, list) + for attempt in range(self.max_retry): + try: + client = random.choice(self.clients) + response = await client.chat.completions.create( + model=self.model_name, + messages=messages, + tools=tools, + stream=False, + temperature=self.sample_params.get("temperature", 0.7), + top_p=self.sample_params.get("top_p", 1.0), + timeout=self.timeout, + extra_body=self.extra_body, + max_tokens= self.sample_params.get("max_tokens", 64 * 1024), + reasoning_effort=self.sample_params.get("reasoning_effort", NOT_GIVEN) + ) + break + except (APITimeoutError, TimeoutError) as e: + logger.error(f"LLM Call Timeout: {e}") + if attempt == self.max_retry - 1: + assistant_msg_dict = {"role": "assistant", "content": f"LLM Call Timeout: {e}"} + return assistant_msg_dict + await asyncio.sleep(self.sleep_interval) + except Exception as e: + for val in [ + "用户额度不足", + "剩余额度", + "TimeoutError", + "litellm.BadRequestError", + "litellm.APIError: APIError", + "Failed to parse fc related info to json format!", + "Error code" + ]: + if val in str(e): + import traceback + traceback.print_exc() + logger.error(f"[Retry] {attempt} LLM Call Error: {e}") + if attempt == self.max_retry - 1: + assistant_msg_dict = {"role": "assistant", "content": f"LLM Call Error: {e}"} + return assistant_msg_dict + await asyncio.sleep(self.sleep_interval) + break + else: + import traceback + traceback.print_exc() + assistant_msg_dict = {"role": "assistant", "content": f"LLM Call Error: {e}"} + return assistant_msg_dict + + choice = response.choices[0] + message_data = choice.message + return message_data.model_dump() + +if __name__ == '__main__': + tools = [{ + 'type': 'function', + 'function': { + 'name': 'get_current_temperature', + 'description': 'Get current temperature at a location.', + 'parameters': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' + }, + 'unit': { + 'type': 'string', + 'enum': [ + 'celsius', + 'fahrenheit' + ], + 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' + } + }, + 'required': [ + 'location' + ] + } + } + }, { + 'type': 'function', + 'function': { + 'name': 'get_temperature_date', + 'description': 'Get temperature at a location and date.', + 'parameters': { + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The location to get the temperature for, in the format \'City, State, Country\'.' + }, + 'date': { + 'type': 'string', + 'description': 'The date to get the temperature for, in the format \'Year-Month-Day\'.' + }, + 'unit': { + 'type': 'string', + 'enum': [ + 'celsius', + 'fahrenheit' + ], + 'description': 'The unit to return the temperature in. Defaults to \'celsius\'.' + } + }, + 'required': [ + 'location', + 'date' + ] + } + } + }] + + + + messages = [ + {'role': 'user', 'content': 'Today is 2024-11-14, What\'s the temperature in San Francisco now? How about tomorrow?'} + ] + messages = [ + {'role': 'user', 'content': '上海温度'} + ] + # model_name = "claude-opus-4-6" + # api_base = "http://35.220.164.252:3888/v1" + # api_key = "" + # proxy = "http://100.100.72.89:8899" + # extra_body = {} + + extra_body = {'enable_thinking': True, 'spaces_between_special_tokens': False} + model_name = "/mnt/shared-storage-user/llmit1/user/liujiangning/exp/s2_preview/agent_rl/s2-preview-thinker_sft_0228b_rl0312rc1_fix_klmismatch/20260331212858/hf-15" + api_base = "http://10.102.252.171:23333/v1" + api_key = "sk-admin" + proxy = None + + async def main(): + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=api_base, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, ), + timeout=600, + max_retry=5, + sleep_interval=5, + extra_body=extra_body, + ) + response = await model.chat(messages, tools=tools) + print("Response:", response) + + asyncio.run(main()) \ No newline at end of file diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index 7418a65c..a99a0904 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -12,9 +12,9 @@ import aiohttp import requests -from ..schema import ModelStatusCode -from ..utils import filter_suffix -from .base_api import AsyncBaseAPILLM, BaseAPILLM +from lagent.schema import ModelStatusCode +from lagent.utils import filter_suffix +from lagent.llms.base_api import AsyncBaseAPILLM, BaseAPILLM warnings.simplefilter('default') diff --git a/lagent/llms/sensenova.py b/lagent/llms/sensenova.py deleted file mode 100644 index 54a9f9ee..00000000 --- a/lagent/llms/sensenova.py +++ /dev/null @@ -1,406 +0,0 @@ -import json -import os -import time -import warnings -from concurrent.futures import ThreadPoolExecutor -from logging import getLogger -from threading import Lock -from typing import Dict, Generator, List, Optional, Tuple, Union - -import requests - -from lagent.schema import ModelStatusCode -from lagent.utils.util import filter_suffix -from .base_api import BaseAPILLM - -warnings.simplefilter('default') - -SENSENOVA_API_BASE = 'https://api.sensenova.cn/v1/llm/chat-completions' - -sensechat_models = {'SenseChat-5': 131072, 'SenseChat-5-Cantonese': 32768} - - -class SensenovaAPI(BaseAPILLM): - """Model wrapper around SenseTime's models. - - Args: - model_type (str): The name of SenseTime's model. - retry (int): Number of retires if the API call fails. Defaults to 2. - key (str or List[str]): SenseTime key(s). In particular, when it - is set to "ENV", the key will be fetched from the environment - variable $SENSENOVA_API_KEY. If it's a list, the keys will be - used in round-robin manner. Defaults to 'ENV'. - meta_template (Dict, optional): The model's meta prompt - template if needed, in case the requirement of injecting or - wrapping of any meta instructions. - sensenova_api_base (str): The base url of SenseTime's API. Defaults to - 'https://api.sensenova.cn/v1/llm/chat-completions'. - gen_params: Default generation configuration which could be overridden - on the fly of generation. - """ - - is_api: bool = True - - def __init__( - self, - model_type: str = 'SenseChat-5-Cantonese', - retry: int = 2, - json_mode: bool = False, - key: Union[str, List[str]] = 'ENV', - meta_template: Optional[Dict] = [ - dict(role='system', api_role='system'), - dict(role='user', api_role='user'), - dict(role='assistant', api_role='assistant'), - dict(role='environment', api_role='system'), - ], - sensenova_api_base: str = SENSENOVA_API_BASE, - proxies: Optional[Dict] = None, - **gen_params, - ): - - super().__init__( - model_type=model_type, - meta_template=meta_template, - retry=retry, - **gen_params, - ) - self.logger = getLogger(__name__) - - if isinstance(key, str): - # First, apply for SenseNova's ak and sk from SenseTime staff - # Then, generated SENSENOVA_API_KEY using lagent.utils.gen_key.auto_gen_jwt_token(ak, sk) - self.keys = [ - os.getenv('SENSENOVA_API_KEY') if key == 'ENV' else key - ] - else: - self.keys = key - - # record invalid keys and skip them when requesting API - # - keys have insufficient_quota - self.invalid_keys = set() - - self.key_ctr = 0 - self.url = sensenova_api_base - self.model_type = model_type - self.proxies = proxies - self.json_mode = json_mode - - def chat( - self, - inputs: Union[List[dict], List[List[dict]]], - **gen_params, - ) -> Union[str, List[str]]: - """Generate responses given the contexts. - - Args: - inputs (Union[List[dict], List[List[dict]]]): a list of messages - or list of lists of messages - gen_params: additional generation configuration - - Returns: - Union[str, List[str]]: generated string(s) - """ - assert isinstance(inputs, list) - if 'max_tokens' in gen_params: - raise NotImplementedError('unsupported parameter: max_tokens') - gen_params = {**self.gen_params, **gen_params} - with ThreadPoolExecutor(max_workers=20) as executor: - tasks = [ - executor.submit(self._chat, - self.template_parser._prompt2api(messages), - **gen_params) - for messages in ( - [inputs] if isinstance(inputs[0], dict) else inputs) - ] - ret = [task.result() for task in tasks] - return ret[0] if isinstance(inputs[0], dict) else ret - - def stream_chat( - self, - inputs: List[dict], - **gen_params, - ) -> Generator[Tuple[ModelStatusCode, str, Optional[str]], None, None]: - """Generate responses given the contexts. - - Args: - inputs (List[dict]): a list of messages - gen_params: additional generation configuration - - Yields: - Tuple[ModelStatusCode, str, Optional[str]]: Status code, generated string, and optional metadata - """ - assert isinstance(inputs, list) - if 'max_tokens' in gen_params: - raise NotImplementedError('unsupported parameter: max_tokens') - gen_params = self.update_gen_params(**gen_params) - gen_params['stream'] = True - - resp = '' - finished = False - stop_words = gen_params.get('stop_words') or [] - messages = self.template_parser._prompt2api(inputs) - for text in self._stream_chat(messages, **gen_params): - # TODO 测试 resp = text 还是 resp += text - resp += text - if not resp: - continue - # remove stop_words - for sw in stop_words: - if sw in resp: - resp = filter_suffix(resp, stop_words) - finished = True - break - yield ModelStatusCode.STREAM_ING, resp, None - if finished: - break - yield ModelStatusCode.END, resp, None - - def _chat(self, messages: List[dict], **gen_params) -> str: - """Generate completion from a list of templates. - - Args: - messages (List[dict]): a list of prompt dictionaries - gen_params: additional generation configuration - - Returns: - str: The generated string. - """ - assert isinstance(messages, list) - - header, data = self.generate_request_data( - model_type=self.model_type, - messages=messages, - gen_params=gen_params, - json_mode=self.json_mode, - ) - - max_num_retries = 0 - while max_num_retries < self.retry: - self._wait() - - with Lock(): - if len(self.invalid_keys) == len(self.keys): - raise RuntimeError('All keys have insufficient quota.') - - # find the next valid key - while True: - self.key_ctr += 1 - if self.key_ctr == len(self.keys): - self.key_ctr = 0 - - if self.keys[self.key_ctr] not in self.invalid_keys: - break - - key = self.keys[self.key_ctr] - header['Authorization'] = f'Bearer {key}' - - response = dict() - try: - raw_response = requests.post( - self.url, - headers=header, - data=json.dumps(data), - proxies=self.proxies, - ) - response = raw_response.json() - return response['choices'][0]['message']['content'].strip() - except requests.ConnectionError: - print('Got connection error, retrying...') - continue - except requests.JSONDecodeError: - print('JsonDecode error, got', str(raw_response.content)) - continue - except KeyError: - if 'error' in response: - if response['error']['code'] == 'rate_limit_exceeded': - time.sleep(1) - continue - elif response['error']['code'] == 'insufficient_quota': - self.invalid_keys.add(key) - self.logger.warn(f'insufficient_quota key: {key}') - continue - - print('Find error message in response: ', - str(response['error'])) - except Exception as error: - print(str(error)) - max_num_retries += 1 - - raise RuntimeError('Calling SenseTime failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - 'details.') - - def _stream_chat(self, messages: List[dict], **gen_params) -> str: - """Generate completion from a list of templates. - - Args: - messages (List[dict]): a list of prompt dictionaries - gen_params: additional generation configuration - - Returns: - str: The generated string. - """ - - def streaming(raw_response): - for chunk in raw_response.iter_lines(): - if chunk: - try: - decoded_chunk = chunk.decode('utf-8') - # print(f"Decoded chunk: {decoded_chunk}") - - if decoded_chunk == 'data:[DONE]': - # print("Stream ended") - break - - if decoded_chunk.startswith('data:'): - json_str = decoded_chunk[5:] - chunk_data = json.loads(json_str) - - if 'data' in chunk_data and 'choices' in chunk_data[ - 'data']: - choice = chunk_data['data']['choices'][0] - if 'delta' in choice: - content = choice['delta'] - yield content - else: - print(f'Unexpected format: {decoded_chunk}') - - except json.JSONDecodeError as e: - print(f'JSON parsing error: {e}') - except Exception as e: - print( - f'An error occurred while processing the chunk: {e}' - ) - - assert isinstance(messages, list) - - header, data = self.generate_request_data( - model_type=self.model_type, - messages=messages, - gen_params=gen_params, - json_mode=self.json_mode, - ) - - max_num_retries = 0 - while max_num_retries < self.retry: - if len(self.invalid_keys) == len(self.keys): - raise RuntimeError('All keys have insufficient quota.') - - # find the next valid key - while True: - self.key_ctr += 1 - if self.key_ctr == len(self.keys): - self.key_ctr = 0 - - if self.keys[self.key_ctr] not in self.invalid_keys: - break - - key = self.keys[self.key_ctr] - header['Authorization'] = f'Bearer {key}' - - response = dict() - try: - raw_response = requests.post( - self.url, - headers=header, - data=json.dumps(data), - proxies=self.proxies, - ) - return streaming(raw_response) - except requests.ConnectionError: - print('Got connection error, retrying...') - continue - except requests.JSONDecodeError: - print('JsonDecode error, got', str(raw_response.content)) - continue - except KeyError: - if 'error' in response: - if response['error']['code'] == 'rate_limit_exceeded': - time.sleep(1) - continue - elif response['error']['code'] == 'insufficient_quota': - self.invalid_keys.add(key) - self.logger.warn(f'insufficient_quota key: {key}') - continue - - print('Find error message in response: ', - str(response['error'])) - except Exception as error: - print(str(error)) - max_num_retries += 1 - - raise RuntimeError('Calling SenseTime failed after retrying for ' - f'{max_num_retries} times. Check the logs for ' - 'details.') - - def generate_request_data(self, - model_type, - messages, - gen_params, - json_mode=False): - """ - Generates the request data for different model types. - - Args: - model_type (str): The type of the model (e.g., 'sense'). - messages (list): The list of messages to be sent to the model. - gen_params (dict): The generation parameters. - json_mode (bool): Flag to determine if the response format should be JSON. - - Returns: - tuple: A tuple containing the header and the request data. - """ - # Copy generation parameters to avoid modifying the original dictionary - gen_params = gen_params.copy() - - # Hold out 100 tokens due to potential errors in token calculation - max_tokens = min(gen_params.pop('max_new_tokens'), 4096) - if max_tokens <= 0: - return '', '' - - # Initialize the header - header = { - 'content-type': 'application/json', - } - - # Common parameters processing - gen_params['max_tokens'] = max_tokens - if 'stop_words' in gen_params: - gen_params['stop'] = gen_params.pop('stop_words') - if 'repetition_penalty' in gen_params: - gen_params['frequency_penalty'] = gen_params.pop( - 'repetition_penalty') - - # Model-specific processing - data = {} - if model_type.lower().startswith('sense'): - gen_params.pop('skip_special_tokens', None) - gen_params.pop('session_id', None) - data = { - 'model': model_type, - 'messages': messages, - 'n': 1, - **gen_params - } - if json_mode: - data['response_format'] = {'type': 'json_object'} - else: - raise NotImplementedError( - f'Model type {model_type} is not supported') - - return header, data - - def tokenize(self, prompt: str) -> list: - """Tokenize the input prompt. - - Args: - prompt (str): Input string. - - Returns: - list: token ids - """ - import tiktoken - - self.tiktoken = tiktoken - enc = self.tiktoken.encoding_for_model('gpt-4o') - return enc.encode(prompt) diff --git a/lagent/memory/__init__.py b/lagent/memory/__init__.py index 13534ada..c9c83372 100644 --- a/lagent/memory/__init__.py +++ b/lagent/memory/__init__.py @@ -1,4 +1,10 @@ from .base_memory import Memory -from .manager import MemoryManager +from .openclaw_provider import OpenClawMemoryProvider, SandboxOpenClawMemoryProvider +from .claude_code_provider import ClaudeCodeMemoryProvider -__all__ = ['Memory', 'MemoryManager'] +__all__ = [ + 'Memory', + 'OpenClawMemoryProvider', + 'SandboxOpenClawMemoryProvider', + 'ClaudeCodeMemoryProvider', +] diff --git a/lagent/memory/base_memory.py b/lagent/memory/base_memory.py index 3c8fcf04..40e92b30 100644 --- a/lagent/memory/base_memory.py +++ b/lagent/memory/base_memory.py @@ -4,27 +4,34 @@ class Memory: + """Session message buffer. The only memory primitive in lagent. + + A simple append-only list of AgentMessage. No windowing, no + boundary tracking — those concerns belong to ContextBuilder + (which reads compact state from env_info). + """ _item_cls = AgentMessage - def __init__(self, recent_n=None) -> None: + def __init__(self) -> None: self.memory: List[AgentMessage] = [] - self.recent_n = recent_n + + def reset(self) -> None: + """Clear all messages.""" + self.memory = [] def get_memory( self, - recent_n: Optional[int] = None, - filter_func: Optional[Callable[[int, dict], bool]] = None, + filter_func: Optional[Callable[[int, AgentMessage], bool]] = None, ) -> list: - recent_n = recent_n or self.recent_n - if recent_n is not None: - memory = self.memory[-recent_n:] - else: - memory = self.memory + memory = self.memory if filter_func is not None: memory = [m for i, m in enumerate(memory) if filter_func(i, m)] return memory + # Alias for backward compatibility + get = get_memory + def add(self, memories: Union[List[Dict], Dict, None]) -> None: for memory in memories if isinstance(memories, (list, tuple)) else [memories]: if isinstance(memory, str): @@ -38,7 +45,7 @@ def delete(self, index: Union[List, int]) -> None: if isinstance(index, int): del self.memory[index] else: - for i in index: + for i in sorted(index, reverse=True): del self.memory[i] def load( @@ -57,7 +64,4 @@ def load( raise TypeError(f'{type(memories)} is not supported') def save(self) -> List[dict]: - memory = [] - for m in self.memory: - memory.append(m.model_dump()) - return memory + return [m.model_dump() for m in self.memory] diff --git a/lagent/memory/claude_code_provider.py b/lagent/memory/claude_code_provider.py new file mode 100644 index 00000000..c4f7f947 --- /dev/null +++ b/lagent/memory/claude_code_provider.py @@ -0,0 +1,127 @@ +"""Claude Code memory provider. + +Claude Code uses a **directory of markdown files** managed by the +model itself through standard file read/write tools: + + memory/ + ├── MEMORY.md ← index file (pointers to memory files) + ├── user_role.md ← individual memory with frontmatter + ├── feedback_testing.md + └── project_context.md + +Each memory file has YAML frontmatter:: + + --- + name: user role + description: user is a senior engineer working on lagent + type: user + --- + (memory content) + +``MEMORY.md`` is an index — each entry is one line:: + + - [User Role](user_role.md) — senior engineer on lagent project + +``get_info()`` reads the index + all referenced files, returning +the assembled content for env_info injection. + +``actions`` is empty — Claude Code relies on the model using +standard file read/write tools (guided by prompt instructions) +to manage memory. No special memory actions needed. +""" + +from __future__ import annotations + +import asyncio +import re +from pathlib import Path +from typing import List + + +class ClaudeCodeMemoryProvider: + """Claude Code style memory: index + individual markdown files. + + The model manages memory files autonomously via read/write tools. + This provider only handles the **read side** — loading memory + content into env_info each turn. + + Usage:: + + provider = ClaudeCodeMemoryProvider(memory_dir) + env = AsyncEnvAgent( + actions=actions, # no special memory actions needed + long_term_memory=provider, + ) + + Parameters + ---------- + memory_dir : Path + Directory containing MEMORY.md index and memory files. + """ + + def __init__(self, memory_dir: Path): + self._dir = Path(memory_dir) + self._index_file = self._dir / "MEMORY.md" + + async def get_info(self) -> dict: + """Load index + all referenced memory files.""" + index_content = await self._read_index() + if not index_content: + return {} + + # Parse referenced files from index + memory_files = self._parse_index_links(index_content) + + # Load each memory file + memories = [] + for filename in memory_files: + filepath = self._dir / filename + content = await self._read_file(filepath) + if content: + memories.append(content) + + return { + "available": True, + "index": index_content, + "memories": memories, + } + + @property + def actions(self) -> list: + """Claude Code doesn't need special memory actions. + + The model uses standard file read/write tools, guided by + prompt instructions that describe the memory directory format. + """ + return [] + + # ── Internal helpers ── + + async def _read_index(self) -> str: + if self._index_file.exists(): + return await asyncio.to_thread( + self._index_file.read_text, encoding="utf-8" + ) + return "" + + async def _read_file(self, path: Path) -> str: + if path.exists(): + return await asyncio.to_thread( + path.read_text, encoding="utf-8" + ) + return "" + + def _parse_index_links(self, index_content: str) -> List[str]: + """Extract markdown link targets from index lines. + + Parses lines like: + - [Title](filename.md) — description + Returns list of filenames. + """ + pattern = re.compile(r'\[.*?\]\((.+?\.md)\)') + filenames = [] + for line in index_content.splitlines(): + match = pattern.search(line) + if match: + filenames.append(match.group(1)) + return filenames diff --git a/lagent/memory/manager.py b/lagent/memory/manager.py deleted file mode 100644 index 1d42ac96..00000000 --- a/lagent/memory/manager.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Dict - -from ..utils import create_object -from .base_memory import Memory - - -class MemoryManager: - - def __init__(self, cfg) -> None: - self.cfg = cfg - self.memory_map: Dict[str, Memory] = {} - - def create_instance(self, session_id): - self.memory_map[session_id] = create_object(self.cfg) - - def get_memory(self, session_id=0, **kwargs) -> list: - return self.memory_map[session_id].get_memory(**kwargs) - - def add(self, memory, session_id=0, **kwargs) -> None: - if session_id not in self.memory_map: - self.create_instance(session_id) - self.memory_map[session_id].add(memory, **kwargs) - - def get(self, session_id=0) -> Memory: - return self.memory_map.get(session_id, None) - - def reset(self, session_id=0) -> None: - if session_id in self.memory_map: - del self.memory_map[session_id] diff --git a/lagent/memory/openclaw_provider.py b/lagent/memory/openclaw_provider.py new file mode 100644 index 00000000..c0c718cf --- /dev/null +++ b/lagent/memory/openclaw_provider.py @@ -0,0 +1,78 @@ +"""OpenClaw memory provider. + +Reads memory context from OpenClaw's MEMORY.md for env_info injection. +Write operations are handled by ``SaveMemoryAction`` (separate action). +Both operate on the same workspace directory independently. +""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +from typing import Any + + +class OpenClawMemoryProvider: + """OpenClaw memory reader: MEMORY.md on filesystem. + + Usage:: + + provider = OpenClawMemoryProvider(workspace_path) + save_action = SaveMemoryAction(workspace_path) + + env = AsyncEnvAgent( + actions=[save_action, ...], + long_term_memory=provider, + ) + """ + + def __init__(self, workspace: Path): + memory_dir = Path(workspace) / "memory" + memory_dir.mkdir(parents=True, exist_ok=True) + self._memory_file = memory_dir / "MEMORY.md" + + async def get_info(self) -> dict: + content = await self._read() + if not content: + return {} + return {"available": True, "long_term": content} + + async def _read(self) -> str: + if self._memory_file.exists(): + return await asyncio.to_thread( + self._memory_file.read_text, encoding="utf-8" + ) + return "" + + +class SandboxOpenClawMemoryProvider: + """OpenClaw memory reader on a remote sandbox.""" + + def __init__(self, action: Any, *, workspace_root: str = "."): + self._shell = action + root = workspace_root.rstrip("/") or "." + self._memory_file = f"{root}/memory/MEMORY.md" + + async def get_info(self) -> dict: + content = await self._read() + if not content: + return {} + return {"available": True, "long_term": content} + + async def _read(self) -> str: + from lagent.schema import ActionStatusCode + result = await self._shell.run( + command=f"cat {self._memory_file!r} 2>/dev/null || true" + ) + if result.state != ActionStatusCode.SUCCESS: + return "" + try: + if isinstance(result.result, list) and result.result: + content_str = result.result[0].get("content", "") + content_dict = json.loads(content_str) + if content_dict.get("exit_code") == 0: + return content_dict.get("stdout", "").strip() + except Exception: + pass + return "" diff --git a/lagent/schema.py b/lagent/schema.py index 668846fb..feaf8b79 100644 --- a/lagent/schema.py +++ b/lagent/schema.py @@ -86,12 +86,21 @@ class AgentStatusCode(IntEnum): CODE_END = 7 # end python CODE_RETURN = 8 # python return - +from datetime import datetime class AgentMessage(BaseModel): content: Any sender: str = 'user' + role: Optional[str] = None formatted: Optional[Any] = None extra_info: Optional[Any] = None + env_info: Optional[Dict[str, Any]] = None type: Optional[str] = None receiver: Optional[str] = None stream_state: Union[ModelStatusCode, AgentStatusCode] = AgentStatusCode.END + tool_calls: Optional[List[Dict]] = None + timestamp: str = datetime.now().isoformat() + reasoning_content: Optional[str] = None + + def model_post_init(self, context): + if self.role is None: + self.role = self.sender \ No newline at end of file diff --git a/lagent/services/__init__.py b/lagent/services/__init__.py new file mode 100644 index 00000000..dd034119 --- /dev/null +++ b/lagent/services/__init__.py @@ -0,0 +1,6 @@ +from .agent import AgentService +from .cron import CronService +from .mailbox import Mailbox +from .task import TaskBoard + +__all__ = ["AgentService", "CronService", "Mailbox", "TaskBoard"] diff --git a/lagent/services/agent.py b/lagent/services/agent.py new file mode 100644 index 00000000..54b54d5b --- /dev/null +++ b/lagent/services/agent.py @@ -0,0 +1,540 @@ +"""AgentService — lifecycle manager for agent instances. + +Core interface: ``config_dict + build → Agent``. + +Responsibilities +---------------- +* **Template registry**: loads :class:`AgentSpec` via :class:`AgentLoader`. +* **Instance creation**: ``AgentSpec.acreate()`` (or custom build). +* **Registry**: tracks running/stopped instances with IDs. +* **Execution**: sync (await) or async (``asyncio.Task``). +* **Persistence**: ``save_all()`` / ``load_all()`` via ``state_dict()``. + +Design decisions +---------------- +* No fork — all agents are built from config. State transfer uses + ``state_dict()`` / ``load_state_dict()`` when needed. +* ID is auto-generated (``uuid.uuid4().hex[:8]``). +* Service is decoupled from channels/bus. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Awaitable, Callable + +from lagent.agents.agent import Agent + +from .agent_loader import AgentLoader, AgentSpec + +logger = logging.getLogger("lagent.interclaw.services.agent") + + +# ── helpers ─────────────────────────────────────────────────────────── + +def _now_ms() -> int: + return int(time.time() * 1000) + + +# ── data model ──────────────────────────────────────────────────────── + +class AgentStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + STOPPED = "stopped" + FAILED = "failed" + + +@dataclass +class AgentEntry: + """Registry record for a managed agent instance.""" + + id: str = field(default_factory=lambda: uuid.uuid4().hex[:8]) + agent_type: str = "" + label: str = "" + task: str = "" + status: str = AgentStatus.PENDING + result: str | None = None + error: str | None = None + created_at_ms: int = field(default_factory=_now_ms) + finished_at_ms: int | None = None + + def to_dict(self) -> dict: + return { + "id": self.id, + "agent_type": self.agent_type, + "label": self.label, + "task": self.task, + "status": self.status, + "result": self.result, + "error": self.error, + "created_at_ms": self.created_at_ms, + "finished_at_ms": self.finished_at_ms, + } + + @classmethod + def from_dict(cls, data: dict) -> AgentEntry: + return cls( + id=data.get("id", uuid.uuid4().hex[:8]), + agent_type=data.get("agent_type", ""), + label=data.get("label", ""), + task=data.get("task", ""), + status=data.get("status", AgentStatus.PENDING), + result=data.get("result"), + error=data.get("error"), + created_at_ms=data.get("created_at_ms", _now_ms()), + finished_at_ms=data.get("finished_at_ms"), + ) + + +# ── AgentService ────────────────────────────────────────────────────── + +class AgentService: + """Manages the full lifecycle of dynamically spawned agents. + + Responsibilities: spec registry, instance creation from specs, + lifecycle tracking (run/stop/resume), and persistence. + + Agent *construction* logic (LLM injection, tool selection) lives + in the caller (e.g. :class:`SubAgentAction`), not here. This + service only needs a built :class:`Agent` or a registered + :class:`AgentSpec`. + + Parameters + ---------- + agent_loader : AgentLoader, optional + Discovers and parses agent templates from the filesystem. + max_concurrent : int + Maximum number of concurrently running async agents. + on_complete : callable, optional + ``async (entry: AgentEntry) -> None`` called when an async agent + finishes. + """ + + def __init__( + self, + agent_loader: AgentLoader | None = None, + max_concurrent: int = 5, + on_complete: Callable[[AgentEntry], Awaitable[None]] | None = None, + ): + self._loader = agent_loader + self._max_concurrent = max_concurrent + self._on_complete = on_complete + self._semaphore = asyncio.Semaphore(max_concurrent) + + # Registry + self._entries: dict[str, AgentEntry] = {} + self._agents: dict[str, Agent] = {} # live instances + self._tasks: dict[str, asyncio.Task] = {} # async tasks + + # Spec cache + self._specs: dict[str, AgentSpec] = {} + + # Saved states (for resume) + self._saved_states: dict[str, dict] = {} + + # ── initialisation ──────────────────────────────────────────── + + async def load_specs(self) -> None: + """Load all agent templates from the AgentLoader.""" + if self._loader is None: + return + self._specs = await self._loader.load_all() + logger.info( + "Loaded %d agent specs: %s", + len(self._specs), list(self._specs.keys()), + ) + + def register_spec(self, spec: AgentSpec) -> None: + """Manually register an AgentSpec.""" + self._specs[spec.name] = spec + + def get_spec(self, agent_type: str) -> AgentSpec | None: + """Get a registered spec by name.""" + return self._specs.get(agent_type) + + @property + def available_types(self) -> list[str]: + """List all registered agent type names.""" + return list(self._specs.keys()) + + # ── spawn ───────────────────────────────────────────────────── + + async def spawn( + self, + task: str, + agent_type: str, + *, + label: str | None = None, + mode: str = "sync", + state: dict | None = None, + **spec_kwargs, + ) -> AgentEntry: + """Create and run a new agent from a registered spec. + + Parameters + ---------- + task : str + Task description / instruction. + agent_type : str + Registered agent type name. + label : str, optional + Human-readable label. + mode : str + ``"sync"`` — block until done. + ``"async"`` / ``"background"`` — run as asyncio.Task. + state : dict, optional + If provided, ``load_state_dict(state)`` is called on the + new agent before running. + **spec_kwargs + Extra keyword arguments forwarded to + ``AgentSpec.acreate()`` (e.g. ``llm=``, ``actions=``). + + Returns + ------- + AgentEntry + """ + agent = await self._create_from_spec(agent_type, **spec_kwargs) + + label = label or (task[:40] + ("…" if len(task) > 40 else "")) + entry = AgentEntry( + agent_type=agent_type, + label=label, + task=task, + status=AgentStatus.PENDING, + ) + self._entries[entry.id] = entry + + logger.info( + "Spawning agent [%s] type=%s mode=%s: %s", + entry.id, agent_type, mode, label, + ) + + # Optional state transfer + if state is not None: + try: + agent.load_state_dict(state) + except Exception as exc: + logger.warning( + "Failed to load state for [%s]: %s", entry.id, exc, + ) + + if mode == "sync": + await self._run_sync(entry, agent, task) + else: + self._run_async(entry, agent, task) + + return entry + + async def spawn_agent( + self, + agent: Agent, + task: str, + *, + label: str | None = None, + mode: str = "sync", + agent_type: str = "_custom", + ) -> AgentEntry: + """Run a pre-built Agent instance directly. + + Use this when you have an already-constructed Agent (e.g. from + a custom factory) and want it managed by the service. + + Parameters + ---------- + agent : Agent + A fully constructed agent instance. + task : str + Task description. + label : str, optional + Human-readable label. + mode : str + ``"sync"`` or ``"async"``. + agent_type : str + Label for the entry's ``agent_type`` field. + """ + label = label or (task[:40] + ("…" if len(task) > 40 else "")) + entry = AgentEntry( + agent_type=agent_type, + label=label, + task=task, + status=AgentStatus.PENDING, + ) + self._entries[entry.id] = entry + + logger.info( + "Spawning pre-built agent [%s] type=%s mode=%s: %s", + entry.id, agent_type, mode, label, + ) + + if mode == "sync": + await self._run_sync(entry, agent, task) + else: + self._run_async(entry, agent, task) + + return entry + + async def _create_from_spec(self, agent_type: str, **kwargs) -> Agent: + """Create agent from a registered spec.""" + spec = self._specs.get(agent_type) + if spec is None and self._loader is not None: + spec = await self._loader.load(agent_type) + if spec is not None: + self._specs[agent_type] = spec + if spec is None: + raise ValueError( + f"Unknown agent type: {agent_type!r}. " + f"Available: {self.available_types}" + ) + return await spec.acreate(**kwargs) + + async def _run_sync( + self, entry: AgentEntry, agent: Agent, task: str, + ) -> None: + """Run agent synchronously.""" + entry.status = AgentStatus.RUNNING + try: + async with self._semaphore: + self._agents[entry.id] = agent + response = await agent(task) + entry.result = ( + response.content if hasattr(response, "content") + else str(response) + ) + entry.status = AgentStatus.STOPPED + entry.finished_at_ms = _now_ms() + except Exception as exc: + entry.status = AgentStatus.FAILED + entry.error = str(exc) + entry.finished_at_ms = _now_ms() + logger.error("Agent [%s] sync failed: %s", entry.id, exc) + finally: + finished_agent = self._agents.pop(entry.id, None) + if finished_agent is not None: + try: + self._saved_states[entry.id] = finished_agent.state_dict() + except Exception as exc: + logger.warning( + "Failed to save state for [%s]: %s", entry.id, exc, + ) + + def _run_async( + self, entry: AgentEntry, agent: Agent, task: str, + ) -> None: + """Run agent asynchronously — create asyncio.Task.""" + async def _run() -> None: + entry.status = AgentStatus.RUNNING + try: + async with self._semaphore: + self._agents[entry.id] = agent + response = await agent(task) + entry.result = ( + response.content if hasattr(response, "content") + else str(response) + ) + entry.status = AgentStatus.STOPPED + entry.finished_at_ms = _now_ms() + except Exception as exc: + entry.status = AgentStatus.FAILED + entry.error = str(exc) + entry.finished_at_ms = _now_ms() + logger.error("Agent [%s] async failed: %s", entry.id, exc) + finally: + finished_agent = self._agents.pop(entry.id, None) + if finished_agent is not None: + try: + self._saved_states[entry.id] = finished_agent.state_dict() + except Exception as exc: + logger.warning( + "Failed to save state for [%s]: %s", + entry.id, exc, + ) + self._tasks.pop(entry.id, None) + if self._on_complete: + try: + await self._on_complete(entry) + except Exception as cb_exc: + logger.error( + "on_complete failed for [%s]: %s", + entry.id, cb_exc, + ) + + task_obj = asyncio.create_task(_run(), name=f"agent-{entry.id}") + self._tasks[entry.id] = task_obj + + # ── query ───────────────────────────────────────────────────── + + def list( + self, + *, + status: str | None = None, + agent_type: str | None = None, + ) -> list[AgentEntry]: + """List managed agent entries, optionally filtered.""" + entries = list(self._entries.values()) + if status is not None: + entries = [e for e in entries if e.status == status] + if agent_type is not None: + entries = [e for e in entries if e.agent_type == agent_type] + return entries + + def get(self, agent_id: str) -> AgentEntry | None: + """Get a single entry by ID.""" + return self._entries.get(agent_id) + + # ── resume ──────────────────────────────────────────────────── + + async def resume(self, agent_id: str, message: str) -> AgentEntry: + """Resume a stopped agent with a new message. + + Recreates the agent from spec, restores saved state, and sends + the new message. Reuses the existing AgentEntry rather than + creating a new one. + """ + entry = self._entries.get(agent_id) + if entry is None: + raise ValueError(f"Agent {agent_id!r} not found") + if entry.status == AgentStatus.RUNNING: + raise ValueError(f"Agent {agent_id!r} is still running") + + spec = self._specs.get(entry.agent_type) + if spec is None: + raise ValueError( + f"Agent spec {entry.agent_type!r} not found, " + f"cannot resume agent {agent_id!r}" + ) + + agent = await spec.acreate() + + # Restore saved conversation state if available + saved_state = self._saved_states.get(agent_id) + if saved_state is not None: + try: + agent.load_state_dict(saved_state) + except Exception as exc: + logger.warning( + "Failed to restore state for [%s]: %s", agent_id, exc, + ) + + # Reset entry for the new run + entry.result = None + entry.error = None + entry.finished_at_ms = None + + await self._run_sync(entry, agent, message) + return entry + + # ── stop ────────────────────────────────────────────────────── + + async def stop(self, agent_id: str) -> bool: + """Stop a running async agent. Returns True if cancelled.""" + task = self._tasks.get(agent_id) + if task is None: + return False + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + entry = self._entries.get(agent_id) + if entry and entry.status == AgentStatus.RUNNING: + entry.status = AgentStatus.STOPPED + entry.finished_at_ms = _now_ms() + + self._tasks.pop(agent_id, None) + self._agents.pop(agent_id, None) + logger.info("Agent [%s] stopped", agent_id) + return True + + # ── persistence ─────────────────────────────────────────────── + + async def save_all(self, path: Path) -> None: + """Save all entries and live agent states to disk.""" + path.mkdir(parents=True, exist_ok=True) + + entries_data = [e.to_dict() for e in self._entries.values()] + entries_file = path / "agents.json" + await asyncio.to_thread( + entries_file.write_text, + json.dumps(entries_data, ensure_ascii=False, indent=2), + "utf-8", + ) + + states_dir = path / "states" + states_dir.mkdir(exist_ok=True) + for agent_id, agent in self._agents.items(): + try: + state = agent.state_dict() + state_file = states_dir / f"{agent_id}.json" + await asyncio.to_thread( + state_file.write_text, + json.dumps(state, ensure_ascii=False, indent=2), + "utf-8", + ) + except Exception as exc: + logger.warning( + "Failed to save state for [%s]: %s", agent_id, exc, + ) + + logger.info("Saved %d entries to %s", len(entries_data), path) + + async def load_all(self, path: Path) -> None: + """Load entries and states from disk.""" + entries_file = path / "agents.json" + if not entries_file.exists(): + return + + raw = await asyncio.to_thread(entries_file.read_text, "utf-8") + for data in json.loads(raw): + entry = AgentEntry.from_dict(data) + self._entries[entry.id] = entry + + states_dir = path / "states" + if states_dir.exists(): + for state_file in states_dir.iterdir(): + if state_file.suffix == ".json": + agent_id = state_file.stem + try: + state_raw = await asyncio.to_thread( + state_file.read_text, "utf-8", + ) + self._saved_states[agent_id] = json.loads(state_raw) + except Exception as exc: + logger.warning( + "Failed to load state for [%s]: %s", + agent_id, exc, + ) + + logger.info( + "Loaded %d entries from %s", + len(self._entries), path, + ) + + # ── cleanup ─────────────────────────────────────────────────── + + async def shutdown(self) -> None: + """Gracefully stop all running agents.""" + running_ids = list(self._tasks.keys()) + for agent_id in running_ids: + await self.stop(agent_id) + logger.info( + "AgentService shut down, stopped %d agents", len(running_ids), + ) + + def remove(self, agent_id: str) -> bool: + """Remove a finished entry. Cannot remove running agents.""" + entry = self._entries.get(agent_id) + if entry is None or entry.status == AgentStatus.RUNNING: + return False + self._entries.pop(agent_id, None) + self._saved_states.pop(agent_id, None) + return True diff --git a/lagent/services/agent_loader.py b/lagent/services/agent_loader.py new file mode 100644 index 00000000..4e7b4e07 --- /dev/null +++ b/lagent/services/agent_loader.py @@ -0,0 +1,303 @@ +"""AgentLoader -- discover and load agent projects. + +An agent project is a directory containing a ``config.py`` that +exports ``agent_config`` (a dict for ``create_object()``). + +Optionally, a project may also export: + +* ``build`` -- custom factory ``(config_dict) -> Agent``. + When absent the default ``create_object(agent_config)`` is used. +* ``name`` -- agent type name (defaults to directory name). +* ``description`` -- one-line description. +* ``background`` -- whether to run async (default False). + +Usage:: + + loader = AgentLoader(Path("workspace")) + spec = await loader.load("my-agent") + agent = spec.create() # uses build or create_object + +.. note:: + + TODO: Support AGENT.md (markdown + YAML frontmatter) as a simplified + format for declaring sub-agents. This requires an ActionRegistry or + similar mechanism to resolve tool names to Action instances, which is + not yet implemented. For now, only pyconfig (config.py) is supported. +""" + +from __future__ import annotations + +import asyncio +import importlib +import importlib.util +import logging +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Optional + +from lagent.agents.agent import Agent +from lagent.utils import create_object + +logger = logging.getLogger("lagent.interclaw.agent_loader") + +BUILTIN_AGENTS_DIR = Path(__file__).parent.parent / "agents_builtin" + + +# ── AgentSpec ───────────────────────────────────────────────────────── + +@dataclass +class AgentSpec: + """Everything needed to instantiate an agent. + + The only truly essential field is ``agent_config``. Everything else + is metadata or optional override. + """ + + name: str = "" + description: str = "" + background: bool = False + + # Core: the PyConfig dict for create_object(). + agent_config: dict[str, Any] | None = None + + # Optional custom build function: (config_dict) -> Agent. + # When set, create() calls build(agent_config) instead of create_object(). + build: Callable | None = None + + # Where this spec was loaded from (informational). + project_dir: str | None = None + + # Pass-through metadata (AGENT.md fields like tools, model, etc.). + extra: dict[str, Any] = field(default_factory=dict) + + def create(self) -> Agent: + """Instantiate the agent from this spec. + + Uses ``build(agent_config)`` if a custom build function is set, + otherwise falls back to ``create_object(agent_config)``. + """ + if self.agent_config is None: + raise ValueError( + f"AgentSpec {self.name!r} has no agent_config" + ) + if self.build is not None and callable(self.build): + return self.build(self.agent_config) + return create_object(self.agent_config) + + async def acreate(self, **kwargs) -> Agent: + """Async version of :meth:`create`. + + Handles both sync and async build functions. + Extra *kwargs* (e.g. ``llm=``, ``actions=``) are forwarded + to the build function if it accepts them. + """ + import inspect + + if self.agent_config is None: + raise ValueError( + f"AgentSpec {self.name!r} has no agent_config" + ) + if self.build is not None and callable(self.build): + result = self.build(self.agent_config, **kwargs) + if inspect.isawaitable(result): + result = await result + return result + return create_object(self.agent_config) + + def to_dict(self) -> dict: + """Serialize to plain dict (build is excluded).""" + d = { + "name": self.name, + "description": self.description, + "background": self.background, + } + if self.project_dir is not None: + d["project_dir"] = self.project_dir + if self.extra: + d["extra"] = self.extra + return d + + @classmethod + def from_dict(cls, data: dict) -> AgentSpec: + """Deserialize from plain dict.""" + return cls( + name=data.get("name", ""), + description=data.get("description", ""), + background=data.get("background", False), + agent_config=data.get("agent_config"), + project_dir=data.get("project_dir"), + extra=data.get("extra", {}), + ) + + +# ── Module loader ───────────────────────────────────────────────────── + +def _import_module_from_path( + module_name: str, + file_path: Path, + package_dir: Path, +) -> dict[str, Any]: + """Import a Python module and return its public attributes. + + The module is temporarily registered in ``sys.modules`` during + execution (required for relative imports inside config.py) and + removed afterwards to avoid polluting the module namespace. + """ + parent = str(package_dir.parent) + added = parent not in sys.path + if added: + sys.path.insert(0, parent) + prev_module = sys.modules.get(module_name) + try: + spec = importlib.util.spec_from_file_location( + module_name, str(file_path), + submodule_search_locations=[str(package_dir)], + ) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load spec from {file_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return {k: getattr(module, k) for k in dir(module) if not k.startswith("_")} + finally: + # Restore or clean up sys.modules + if prev_module is not None: + sys.modules[module_name] = prev_module + else: + sys.modules.pop(module_name, None) + if added and parent in sys.path: + sys.path.remove(parent) + + +def _spec_from_module( + attrs: dict[str, Any], name: str, path: str, +) -> AgentSpec: + """Build AgentSpec from a config.py module's exported attributes.""" + agent_config = attrs.get("agent_config") + if agent_config is None: + raise ValueError(f"Agent {name!r}: config.py must export 'agent_config'") + + return AgentSpec( + name=attrs.get("name", name), + description=attrs.get("description", ""), + background=bool(attrs.get("background", False)), + agent_config=agent_config, + build=attrs.get("build"), # None if not defined + project_dir=path, + extra=attrs.get("extra", {}), + ) + + + +# ── AgentLoader ─────────────────────────────────────────────────────── + + +class AgentLoader: + """Discover and load agent projects from the filesystem. + + Scans ``workspace/agents/`` (and optional user/builtin dirs) for + agent project directories containing a ``config.py``. + + Parameters + ---------- + workspace : Path + Workspace root. Agents are at ``workspace/agents//``. + user_agents_dir : Path, optional + User-level agents directory. + builtin_agents_dir : Path, optional + Built-in agents (defaults to ``lagent/agents_builtin/``). + """ + + def __init__( + self, + workspace: Path, + user_agents_dir: Path | None = None, + builtin_agents_dir: Path | None = None, + ): + self.workspace = workspace + self._dirs = [ + d for d in [ + workspace / "agents", + user_agents_dir, + builtin_agents_dir or BUILTIN_AGENTS_DIR, + ] if d is not None + ] + + async def list(self) -> list[dict[str, str]]: + """List available agent projects: [{name, path}, ...].""" + agents = [] + seen: set[str] = set() + for directory in self._dirs: + if not directory.exists(): + continue + for agent_dir in sorted(directory.iterdir()): + if not agent_dir.is_dir() or agent_dir.name in seen: + continue + if (agent_dir / "config.py").exists(): + agents.append({ + "name": agent_dir.name, + "path": str(agent_dir), + }) + seen.add(agent_dir.name) + return agents + + async def load(self, name: str) -> AgentSpec | None: + """Load a single agent by name. + + Returns AgentSpec or None if not found. + """ + for directory in self._dirs: + if directory is None or not directory.exists(): + continue + agent_dir = directory / name + if not agent_dir.is_dir(): + continue + + config_file = agent_dir / "config.py" + if not config_file.exists(): + continue + + try: + attrs = await asyncio.to_thread( + _import_module_from_path, + f"_agentproject_{name}_config", + config_file, + agent_dir, + ) + return _spec_from_module(attrs, name, str(agent_dir)) + except Exception: + logger.exception("Failed to load agent %r", name) + return None + + return None + + async def load_all(self) -> dict[str, AgentSpec]: + """Load all available agents. Returns name -> AgentSpec mapping.""" + entries = await self.list() + specs = {} + for entry in entries: + spec = await self.load(entry["name"]) + if spec is not None: + specs[spec.name] = spec + return specs + + async def build_agents_summary(self) -> str: + """Build XML summary of available agents for prompt injection.""" + specs = await self.load_all() + if not specs: + return "" + + def esc(s: str) -> str: + return s.replace("&", "&").replace("<", "<").replace(">", ">") + + lines = [""] + for spec in specs.values(): + lines.append(f' ') + if spec.description: + lines.append(f" {esc(spec.description)}") + if spec.background: + lines.append(" true") + lines.append(" ") + lines.append("") + return "\n".join(lines) diff --git a/lagent/services/cron.py b/lagent/services/cron.py new file mode 100644 index 00000000..b3bd4612 --- /dev/null +++ b/lagent/services/cron.py @@ -0,0 +1,389 @@ +"""Cron scheduling service. + +A self-contained timer-based scheduler that stores jobs as JSON and fires +callbacks when jobs come due. Completely decoupled from agent / bus — the +caller decides what happens when a job fires via the ``on_job`` callback. + +Design heavily inspired by the python-cron-demo reference implementation. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any, Awaitable, Callable, Literal + +logger = logging.getLogger("lagent.interclaw.services.cron") + + +# ── helpers ─────────────────────────────────────────────────────────── + +def _now_ms() -> int: + return int(time.time() * 1000) + + +# ── data model ──────────────────────────────────────────────────────── + +class ScheduleKind(str, Enum): + AT = "at" + EVERY = "every" + CRON = "cron" + + +@dataclass +class Schedule: + """Unified schedule descriptor. + + Exactly one of ``at``, ``every_seconds``, or ``expr`` should be set. + """ + kind: Literal["at", "every", "cron"] = "at" + at: str | None = None # ISO-8601 datetime for one-shot + every_seconds: float | None = None # interval in seconds + expr: str | None = None # 5-field cron expression + tz: str | None = None # IANA timezone (cron only) + + +@dataclass +class JobState: + next_run_at_ms: int | None = None + last_run_at_ms: int | None = None + last_status: str | None = None + last_error: str | None = None + consecutive_errors: int = 0 + + +@dataclass +class CronJob: + id: str = field(default_factory=lambda: uuid.uuid4().hex[:8]) + name: str = "" + enabled: bool = True + schedule: Schedule = field(default_factory=Schedule) + payload: dict[str, Any] = field(default_factory=dict) + state: JobState = field(default_factory=JobState) + created_at_ms: int = field(default_factory=_now_ms) + updated_at_ms: int = field(default_factory=_now_ms) + delete_after_run: bool = False + + # ── serialisation ───────────────────────────────────────────── + + def to_dict(self) -> dict: + return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> CronJob: + sched_data = data.get("schedule", {}) + schedule = Schedule( + **{k: v for k, v in sched_data.items() + if k in Schedule.__dataclass_fields__} + ) + state_data = data.get("state", {}) + state = JobState( + **{k: v for k, v in state_data.items() + if k in JobState.__dataclass_fields__} + ) + return cls( + id=data.get("id", uuid.uuid4().hex[:8]), + name=data.get("name", ""), + enabled=data.get("enabled", True), + schedule=schedule, + payload=data.get("payload", {}), + state=state, + created_at_ms=data.get("created_at_ms", _now_ms()), + updated_at_ms=data.get("updated_at_ms", _now_ms()), + delete_after_run=data.get("delete_after_run", False), + ) + + +# ── schedule computation ────────────────────────────────────────────── + +def compute_next_run(schedule: Schedule, now_ms: int) -> int | None: + """Return the next fire time in epoch-ms, or *None* if no more fires.""" + + if schedule.kind == "at": + if not schedule.at: + return None + try: + dt = datetime.fromisoformat(schedule.at) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + at_ms = int(dt.timestamp() * 1000) + except (ValueError, OverflowError): + return None + return at_ms if at_ms > now_ms else None + + if schedule.kind == "every": + if not schedule.every_seconds or schedule.every_seconds <= 0: + return None + return now_ms + int(schedule.every_seconds * 1000) + + if schedule.kind == "cron" and schedule.expr: + try: + from zoneinfo import ZoneInfo + from croniter import croniter + tz = ZoneInfo(schedule.tz) if schedule.tz else timezone.utc + base_dt = datetime.fromtimestamp(now_ms / 1000, tz=tz) + nxt = croniter(schedule.expr, base_dt).get_next(datetime) + return int(nxt.timestamp() * 1000) + except Exception: + return None + + return None + + +# ── persistence ─────────────────────────────────────────────────────── + +def _load_jobs(path: Path) -> list[CronJob]: + if not path.exists(): + return [] + try: + data = json.loads(path.read_text("utf-8")) + if not isinstance(data, dict) or data.get("version") != 1: + return [] + return [ + CronJob.from_dict(j) + for j in data.get("jobs", []) + if isinstance(j, dict) + ] + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to load cron store %s: %s", path, exc) + return [] + + +def _save_jobs(jobs: list[CronJob], path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + data = {"version": 1, "jobs": [j.to_dict() for j in jobs]} + tmp = path.with_suffix(".tmp") + try: + tmp.write_text( + json.dumps(data, indent=2, ensure_ascii=False, default=str), + "utf-8", + ) + tmp.replace(path) + except OSError as exc: + logger.error("Failed to save cron store: %s", exc) + tmp.unlink(missing_ok=True) + + +# ── service ─────────────────────────────────────────────────────────── + +# Backoff schedule (seconds) for consecutive errors +_BACKOFF = [30, 60, 300, 900, 3600] +_MAX_TIMER_DELAY_S = 60.0 +_MIN_REFIRE_GAP_S = 2.0 + + +class CronService: + """Async cron scheduler with JSON persistence. + + Parameters + ---------- + store_path : Path + Where to persist ``jobs.json``. + on_job : callable, optional + ``async (CronJob) -> None`` called when a job fires. The caller + decides what to do (e.g. publish an inbound event on the bus). + """ + + def __init__( + self, + store_path: Path, + on_job: Callable[[CronJob], Awaitable[None]] | None = None, + ): + self.store_path = store_path + self.on_job = on_job + self._jobs: list[CronJob] = [] + self._timer: asyncio.TimerHandle | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._running = False + + # ── lifecycle ───────────────────────────────────────────────── + + async def start(self) -> None: + self._loop = asyncio.get_running_loop() + self._jobs = _load_jobs(self.store_path) + self._recompute_all() + self._persist() + self._running = True + self._arm_timer() + logger.info( + "Cron service started (%d jobs, %d enabled)", + len(self._jobs), + sum(1 for j in self._jobs if j.enabled), + ) + + def stop(self) -> None: + self._running = False + self._cancel_timer() + logger.info("Cron service stopped") + + # ── public CRUD ─────────────────────────────────────────────── + + def add_job( + self, + name: str, + schedule: Schedule, + message: str, + channel: str | None = None, + chat_id: str | None = None, + delete_after_run: bool = False, + ) -> CronJob: + # Validate schedule consistency + if schedule.kind == "at" and not schedule.at: + raise ValueError("Schedule kind 'at' requires 'at' (ISO-8601 datetime)") + if schedule.kind == "every" and ( + not schedule.every_seconds or schedule.every_seconds <= 0 + ): + raise ValueError("Schedule kind 'every' requires positive 'every_seconds'") + if schedule.kind == "cron" and not schedule.expr: + raise ValueError("Schedule kind 'cron' requires 'expr' (5-field cron expression)") + + now = _now_ms() + job = CronJob( + name=name, + schedule=schedule, + payload=dict(message=message, channel=channel, chat_id=chat_id), + state=JobState(next_run_at_ms=compute_next_run(schedule, now)), + created_at_ms=now, + updated_at_ms=now, + delete_after_run=delete_after_run, + ) + self._jobs.append(job) + self._persist() + self._arm_timer() + logger.info("Cron: added job '%s' (%s)", name, job.id) + return job + + def remove_job(self, job_id: str) -> bool: + before = len(self._jobs) + self._jobs = [j for j in self._jobs if j.id != job_id] + removed = len(self._jobs) < before + if removed: + self._persist() + self._arm_timer() + logger.info("Cron: removed job %s", job_id) + return removed + + def list_jobs(self, include_disabled: bool = False) -> list[CronJob]: + jobs = ( + self._jobs + if include_disabled + else [j for j in self._jobs if j.enabled] + ) + return sorted( + jobs, key=lambda j: j.state.next_run_at_ms or float("inf") + ) + + def get_job(self, job_id: str) -> CronJob | None: + return next((j for j in self._jobs if j.id == job_id), None) + + # ── timer engine ────────────────────────────────────────────── + + def _arm_timer(self) -> None: + self._cancel_timer() + if not self._running or not self._loop: + return + wake = self._next_wake_ms() + if wake is None: + return + delay_s = max((wake - _now_ms()) / 1000, _MIN_REFIRE_GAP_S) + delay_s = min(delay_s, _MAX_TIMER_DELAY_S) + self._timer = self._loop.call_later( + delay_s, lambda: asyncio.ensure_future(self._on_timer()) + ) + + def _cancel_timer(self) -> None: + if self._timer: + self._timer.cancel() + self._timer = None + + async def _on_timer(self) -> None: + if not self._running: + return + self._jobs = _load_jobs(self.store_path) # pick up external edits + now = _now_ms() + due = [ + j for j in self._jobs + if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms + ] + for job in due: + await self._execute_job(job) + self._persist() + self._arm_timer() + + async def _execute_job(self, job: CronJob) -> None: + start = _now_ms() + logger.info("Cron: executing '%s' (%s)", job.name, job.id) + try: + if self.on_job: + await self.on_job(job) + job.state.last_status = "ok" + job.state.last_error = None + job.state.consecutive_errors = 0 + except Exception as exc: + job.state.last_status = "error" + job.state.last_error = str(exc) + job.state.consecutive_errors += 1 + logger.error("Cron: job '%s' failed: %s", job.name, exc) + + job.state.last_run_at_ms = start + job.updated_at_ms = _now_ms() + + # reschedule or retire + if job.schedule.kind == "at": + if job.delete_after_run: + self._jobs = [j for j in self._jobs if j.id != job.id] + else: + job.enabled = False + job.state.next_run_at_ms = None + else: + if job.state.consecutive_errors > 0: + idx = min( + job.state.consecutive_errors - 1, len(_BACKOFF) - 1 + ) + job.state.next_run_at_ms = _now_ms() + _BACKOFF[idx] * 1000 + else: + job.state.next_run_at_ms = compute_next_run( + job.schedule, _now_ms() + ) + + # ── helpers ─────────────────────────────────────────────────── + + def _next_wake_ms(self) -> int | None: + times = [ + j.state.next_run_at_ms + for j in self._jobs + if j.enabled and j.state.next_run_at_ms + ] + return min(times) if times else None + + def _recompute_all(self) -> None: + now = _now_ms() + for job in self._jobs: + if job.enabled: + job.state.next_run_at_ms = compute_next_run(job.schedule, now) + + def _persist(self) -> None: + _save_jobs(self._jobs, self.store_path) + + # ── serialisation ──────────────────────────────────────────── + + def state_dict(self) -> dict: + """Export service state for Module tree integration.""" + return {"version": 1, "jobs": [j.to_dict() for j in self._jobs]} + + def load_state_dict(self, state: dict) -> None: + """Restore service state from a previous :meth:`state_dict`.""" + self._jobs = [ + CronJob.from_dict(j) + for j in state.get("jobs", []) + if isinstance(j, dict) + ] + self._recompute_all() + self._persist() diff --git a/lagent/services/heartbeat.py b/lagent/services/heartbeat.py new file mode 100644 index 00000000..76e0df54 --- /dev/null +++ b/lagent/services/heartbeat.py @@ -0,0 +1,42 @@ +import asyncio +import logging +from pathlib import Path +from lagent.apps.bus import MessageBus, InboundEvent + +logger = logging.getLogger("lagent.interclaw.services.heartbeat") + +class HeartbeatService: + """ + 状态心跳服务:定期扫描某个物理介质 (如 Markdown 文件) + 发现未处理内容,即封装为一个内部指令并发送给大脑处理。 + """ + def __init__(self, bus: MessageBus, pulse_seconds: int = 15, file_path: str = ".interclaw_data/HEARTBEAT.md"): + self.bus = bus + self.pulse_seconds = pulse_seconds + self.file_path = Path(file_path) + self.file_path.parent.mkdir(parents=True, exist_ok=True) + if not self.file_path.exists(): + self.file_path.write_text("No background tasks for now.\n") + + async def start(self): + logger.info(f"Heartbeat Service started, pulse every {self.pulse_seconds}s for {self.file_path.name}") + while True: + try: + await asyncio.sleep(self.pulse_seconds) + + if self.file_path.exists(): + content = self.file_path.read_text(encoding="utf-8").strip() + # 避免空白文件造成无效触发,如果包含特定的关键字可以考虑唤醒 + if content and "No background tasks" not in content and "TODO:" in content: + logger.info("Heartbeat found actionable items in Markdown!") + + await self.bus.publish_inbound(InboundEvent( + channel="heartbeat", + session_id="system:heartbeat", + content=f"请梳理处理并更新我的后台记录:\n{content}" + )) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Heartbeat Error: {e}") diff --git a/lagent/services/mailbox.py b/lagent/services/mailbox.py new file mode 100644 index 00000000..afe82e9b --- /dev/null +++ b/lagent/services/mailbox.py @@ -0,0 +1,252 @@ +"""Mailbox -- shared message service for multi-agent communication. + +A centralised message store with per-agent queues. Agents send +:class:`~lagent.schema.AgentMessage` objects to named recipients; +recipients drain their queue at natural checkpoints in their own +execution loop (typically inside ``EnvAgent.get_env_info()``). + +Design +------ +* **Shared resource** -- one Mailbox instance shared across all agents + in a team, like :class:`TaskBoard`. +* **Centralised storage, decentralised routing** -- sender writes + directly into receiver's queue, no middleman. +* **asyncio.Event notification** -- receivers can ``await + wait_for_message()`` instead of polling with ``asyncio.sleep()``. +* **state_dict / load_state_dict** -- serialisable for persistence. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from lagent.schema import AgentMessage + +logger = logging.getLogger("lagent.services.mailbox") + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +# ── serialisation helpers ──────────────────────────────────────────── + +def _msg_to_dict(msg: AgentMessage) -> dict: + """Convert an AgentMessage to a plain dict for persistence.""" + return { + "sender": msg.sender, + "receiver": msg.receiver, + "content": msg.content if isinstance(msg.content, str) else str(msg.content), + "timestamp": msg.timestamp, + } + + +def _msg_from_dict(data: dict) -> AgentMessage: + """Reconstruct an AgentMessage from a persistence dict.""" + return AgentMessage( + sender=data.get("sender", "unknown"), + receiver=data.get("receiver"), + content=data.get("content", ""), + timestamp=data.get("timestamp", ""), + ) + + +# ── persistence ────────────────────────────────────────────────────── + +def _load_mailbox(path: Path) -> dict[str, list[AgentMessage]]: + if not path.exists(): + return {} + try: + data = json.loads(path.read_text("utf-8")) + if not isinstance(data, dict) or data.get("version") != 1: + return {} + queues = {} + for name, msgs in data.get("queues", {}).items(): + queues[name] = [ + _msg_from_dict(m) for m in msgs if isinstance(m, dict) + ] + return queues + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to load mailbox store %s: %s", path, exc) + return {} + + +def _save_mailbox(queues: dict[str, list[AgentMessage]], path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "version": 1, + "queues": { + name: [_msg_to_dict(m) for m in msgs] + for name, msgs in queues.items() + }, + } + tmp = path.with_suffix(".tmp") + try: + tmp.write_text( + json.dumps(payload, indent=2, ensure_ascii=False, default=str), + "utf-8", + ) + tmp.replace(path) + except OSError as exc: + logger.error("Failed to save mailbox store: %s", exc) + tmp.unlink(missing_ok=True) + + +# ── Mailbox ────────────────────────────────────────────────────────── + +class Mailbox: + """Shared message service for multi-agent communication. + + Parameters + ---------- + store_path : Path or None + Where to persist mailbox state. ``None`` means purely + in-memory (no file I/O). + """ + + def __init__(self, store_path: Path | None = None): + self._queues: dict[str, list[AgentMessage]] = {} + self._notify: dict[str, asyncio.Event] = {} + self._store_path = store_path + + if store_path is not None: + loaded = _load_mailbox(store_path) + if loaded: + self._queues = loaded + + # ── registration ───────────────────────────────────────────── + + def register(self, agent_name: str) -> None: + """Register an agent, initialising its queue and notification + event. Idempotent — safe to call multiple times.""" + self._queues.setdefault(agent_name, []) + if agent_name not in self._notify: + self._notify[agent_name] = asyncio.Event() + + @property + def agents(self) -> list[str]: + """Names of all registered agents.""" + return list(self._queues.keys()) + + # ── send ───────────────────────────────────────────────────── + + def send(self, message: AgentMessage) -> None: + """Send a message. Uses ``message.receiver`` for routing. + + If ``receiver`` is ``"*"``, the message is broadcast to every + registered agent except the sender. + """ + receiver = message.receiver + if not receiver: + raise ValueError("AgentMessage.receiver must be set") + + if receiver == "*": + for name in list(self._queues.keys()): + if name != message.sender: + self._queues[name].append(message) + self._signal(name) + else: + self._queues.setdefault(receiver, []).append(message) + self._signal(receiver) + + self._persist() + + # ── receive ────────────────────────────────────────────────── + + def drain(self, agent_name: str) -> list[AgentMessage]: + """Take all pending messages for *agent_name* and clear its + queue. Intended to be called from ``EnvAgent.get_env_info()``. + """ + msgs = self._queues.get(agent_name) + if not msgs: + return [] + taken = list(msgs) + msgs.clear() + self._persist() + return taken + + def has_messages(self, agent_name: str) -> bool: + """Quick non-destructive check.""" + return bool(self._queues.get(agent_name)) + + def peek(self, agent_name: str) -> list[AgentMessage]: + """View pending messages without removing them.""" + return list(self._queues.get(agent_name, [])) + + # ── wait (async notification) ──────────────────────────────── + + async def wait_for_message( + self, agent_name: str, timeout: float | None = None, + ) -> bool: + """Block until a message arrives for *agent_name*. + + Returns ``True`` if a message arrived, ``False`` on timeout. + Use this in worker loops to avoid ``asyncio.sleep()`` polling:: + + if not board.list_available(): + await mailbox.wait_for_message(name, timeout=5.0) + """ + if agent_name not in self._notify: + self.register(agent_name) + event = self._notify[agent_name] + event.clear() + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + return True + except asyncio.TimeoutError: + return False + + def _signal(self, agent_name: str) -> None: + event = self._notify.get(agent_name) + if event is not None: + event.set() + + # ── management ─────────────────────────────────────────────── + + def clear(self, agent_name: str) -> None: + """Discard all pending messages for one agent.""" + self._queues.pop(agent_name, None) + self._persist() + + def clear_all(self) -> None: + """Discard all messages for all agents.""" + for q in self._queues.values(): + q.clear() + self._persist() + + def message_count(self, agent_name: str | None = None) -> int: + """Total pending messages, optionally filtered by agent.""" + if agent_name is not None: + return len(self._queues.get(agent_name, [])) + return sum(len(q) for q in self._queues.values()) + + # ── serialisation ──────────────────────────────────────────── + + def state_dict(self) -> dict: + return { + "version": 1, + "queues": { + name: [_msg_to_dict(m) for m in msgs] + for name, msgs in self._queues.items() + }, + } + + def load_state_dict(self, state: dict) -> None: + self._queues = {} + for name, msgs in state.get("queues", {}).items(): + self._queues[name] = [ + _msg_from_dict(d) for d in msgs if isinstance(d, dict) + ] + self._persist() + + # ── persistence ────────────────────────────────────────────── + + def _persist(self) -> None: + if self._store_path is not None: + _save_mailbox(self._queues, self._store_path) diff --git a/lagent/services/task.py b/lagent/services/task.py new file mode 100644 index 00000000..9ff42402 --- /dev/null +++ b/lagent/services/task.py @@ -0,0 +1,471 @@ +"""TaskBoard -- structured task management with dependency tracking. + +A shared resource object that agents use to organise work into trackable +tasks. Inspired by Claude Code's production task system, adapted for +lagent's PyTorch-style architecture. + +Design +------ +* **Independent resource** -- not part of Agent; injected into Actions. +* **In-memory by default** -- optional JSON file persistence via *store_path*. +* **state_dict / load_state_dict** -- integrates with lagent's Module tree. +* **Dependency graph** -- bidirectional ``blocks`` / ``blocked_by`` with + cascading cleanup on delete. +* **High water mark** -- task IDs are monotonically increasing and never + reused, even after deletion. +""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger("lagent.services.task") + +_UNSET = object() +"""Sentinel to distinguish 'not provided' from 'set to None'.""" + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +# ── Status constants ───────────────────────────────────────────────── + +TASK_STATUSES = ("pending", "in_progress", "completed") + + +# ── Data model ─────────────────────────────────────────────────────── + +@dataclass +class Task: + """A single task on the board.""" + + id: str = "" + subject: str = "" + description: str = "" + status: str = "pending" + active_form: str | None = None + owner: str | None = None + blocks: list[str] = field(default_factory=list) + blocked_by: list[str] = field(default_factory=list) + metadata: dict[str, Any] | None = None + created_at_ms: int = field(default_factory=_now_ms) + updated_at_ms: int = field(default_factory=_now_ms) + + def to_dict(self) -> dict: + return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> Task: + return cls( + id=data.get("id", ""), + subject=data.get("subject", ""), + description=data.get("description", ""), + status=data.get("status", "pending"), + active_form=data.get("active_form"), + owner=data.get("owner"), + blocks=list(data.get("blocks") or []), + blocked_by=list(data.get("blocked_by") or []), + metadata=data.get("metadata"), + created_at_ms=data.get("created_at_ms", _now_ms()), + updated_at_ms=data.get("updated_at_ms", _now_ms()), + ) + + +@dataclass +class ClaimResult: + """Result of a :meth:`TaskBoard.claim` attempt.""" + + success: bool + task: Task | None = None + reason: str = "" + blocked_by_tasks: list[str] | None = None + busy_with_tasks: list[str] | None = None + + +# ── Persistence helpers ────────────────────────────────────────────── + +def _load_store(path: Path) -> tuple[list[Task], int]: + if not path.exists(): + return [], 1 + try: + data = json.loads(path.read_text("utf-8")) + if not isinstance(data, dict) or data.get("version") != 1: + return [], 1 + tasks = [ + Task.from_dict(t) + for t in data.get("tasks", []) + if isinstance(t, dict) + ] + next_id = data.get("next_id", 1) + return tasks, next_id + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to load task store %s: %s", path, exc) + return [], 1 + + +def _save_store(tasks: list[Task], next_id: int, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "version": 1, + "next_id": next_id, + "tasks": [t.to_dict() for t in tasks], + } + tmp = path.with_suffix(".tmp") + try: + tmp.write_text( + json.dumps(payload, indent=2, ensure_ascii=False, default=str), + "utf-8", + ) + tmp.replace(path) + except OSError as exc: + logger.error("Failed to save task store: %s", exc) + tmp.unlink(missing_ok=True) + + +# ── TaskBoard ──────────────────────────────────────────────────────── + +class TaskBoard: + """In-memory task board with optional JSON persistence. + + Parameters + ---------- + store_path : Path or None + Where to persist ``tasks.json``. ``None`` means purely + in-memory (no file I/O). + """ + + def __init__(self, store_path: Path | None = None): + self._store_path = store_path + self._tasks: list[Task] = [] + self._next_id: int = 1 + + if store_path is not None: + loaded, hwm = _load_store(store_path) + if loaded: + self._tasks = loaded + self._next_id = max(hwm, self._next_id) + + # ── helpers ─────────────────────────────────────────────────── + + def _find(self, task_id: str) -> Task | None: + return next((t for t in self._tasks if t.id == task_id), None) + + def _persist(self) -> None: + if self._store_path is not None: + _save_store(self._tasks, self._next_id, self._store_path) + + def _ensure_next_id(self) -> None: + """Safety: make sure _next_id is past any existing task ID.""" + if self._tasks: + max_existing = max(int(t.id) for t in self._tasks) + self._next_id = max(self._next_id, max_existing + 1) + + def _add_block(self, from_id: str, to_id: str) -> None: + """Bidirectional: from_task.blocks += to_id, to_task.blocked_by += from_id.""" + from_task = self._find(from_id) + to_task = self._find(to_id) + if from_task is not None and to_id not in from_task.blocks: + from_task.blocks.append(to_id) + if to_task is not None and from_id not in to_task.blocked_by: + to_task.blocked_by.append(from_id) + + # ── CRUD ────────────────────────────────────────────────────── + + def create( + self, + subject: str, + description: str, + active_form: str | None = None, + blocked_by: list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> Task: + """Create a new task. Returns the created :class:`Task`.""" + task_id = str(self._next_id) + self._next_id += 1 + now = _now_ms() + + task = Task( + id=task_id, + subject=subject, + description=description, + active_form=active_form, + status="pending", + blocked_by=list(blocked_by or []), + metadata=metadata, + created_at_ms=now, + updated_at_ms=now, + ) + + self._tasks.append(task) + + # Bidirectional: update blockers' "blocks" lists + for dep_id in task.blocked_by: + dep = self._find(dep_id) + if dep is not None and task_id not in dep.blocks: + dep.blocks.append(task_id) + + self._persist() + return task + + def update( + self, + task_id: str, + status=_UNSET, + subject=_UNSET, + description=_UNSET, + active_form=_UNSET, + owner=_UNSET, + metadata=_UNSET, + add_blocks: list[str] | None = None, + add_blocked_by: list[str] | None = None, + ) -> Task | None: + """Update a task. Returns the updated :class:`Task`, or None if + not found. + + Pass ``status="deleted"`` to permanently remove the task + (equivalent to :meth:`delete`). + """ + # Handle "deleted" status as hard delete + if status is not _UNSET and status == "deleted": + self.delete(task_id) + return None + + task = self._find(task_id) + if task is None: + return None + + if status is not _UNSET: + task.status = status + if subject is not _UNSET: + task.subject = subject + if description is not _UNSET: + task.description = description + if active_form is not _UNSET: + task.active_form = active_form + if owner is not _UNSET: + task.owner = owner + if metadata is not _UNSET and metadata is not None: + if task.metadata is None: + task.metadata = {} + for k, v in metadata.items(): + if v is None: + task.metadata.pop(k, None) + else: + task.metadata[k] = v + + # Dependency additions + if add_blocks: + for to_id in add_blocks: + self._add_block(task_id, to_id) + if add_blocked_by: + for from_id in add_blocked_by: + self._add_block(from_id, task_id) + + task.updated_at_ms = _now_ms() + self._persist() + return task + + def delete(self, task_id: str) -> bool: + """Permanently remove a task and cascade-clean all references.""" + task = self._find(task_id) + if task is None: + return False + + # Cascading cleanup + for other in self._tasks: + if task_id in other.blocks: + other.blocks.remove(task_id) + if task_id in other.blocked_by: + other.blocked_by.remove(task_id) + + self._tasks = [t for t in self._tasks if t.id != task_id] + # Do NOT decrement _next_id -- high water mark + self._persist() + return True + + def get(self, task_id: str) -> Task | None: + """Get a task by ID.""" + return self._find(task_id) + + def list(self, status: str | None = None) -> list[Task]: + """List tasks, optionally filtered by status.""" + if status is None: + return list(self._tasks) + return [t for t in self._tasks if t.status == status] + + # ── Multi-agent coordination ────────────────────────────────── + + def claim( + self, + task_id: str, + agent_name: str, + check_busy: bool = True, + ) -> ClaimResult: + """Attempt to claim a task for an agent. + + Checks + ------ + 1. Task exists. + 2. Task not already completed. + 3. Task not already claimed by a *different* agent. + 4. All ``blocked_by`` tasks are completed. + 5. (optional) Agent has no other ``in_progress`` tasks. + """ + task = self._find(task_id) + if task is None: + return ClaimResult(False, reason="task_not_found") + + if task.status == "completed": + return ClaimResult(False, task=task, reason="already_completed") + + if task.owner is not None and task.owner != agent_name: + return ClaimResult( + False, task=task, + reason=f"already_claimed by {task.owner}", + ) + + # Check unresolved blockers + unresolved = [ + bid for bid in task.blocked_by + if (b := self._find(bid)) is not None + and b.status != "completed" + ] + if unresolved: + return ClaimResult( + False, task=task, reason="blocked", + blocked_by_tasks=unresolved, + ) + + # Check if agent is already busy + if check_busy: + busy = [ + t.id for t in self._tasks + if t.owner == agent_name + and t.status == "in_progress" + and t.id != task_id + ] + if busy: + return ClaimResult( + False, task=task, reason="agent_busy", + busy_with_tasks=busy, + ) + + task.owner = agent_name + task.status = "in_progress" + task.updated_at_ms = _now_ms() + self._persist() + return ClaimResult(True, task=task) + + def release_agent(self, agent_name: str) -> list[Task]: + """Release all non-completed tasks owned by *agent_name*. + + Resets ``owner`` to None and ``status`` to ``"pending"``. + Used when an agent shuts down or fails. + """ + released = [] + for task in self._tasks: + if task.owner == agent_name and task.status != "completed": + task.owner = None + task.status = "pending" + task.updated_at_ms = _now_ms() + released.append(task) + if released: + self._persist() + return released + + # ── Query helpers ───────────────────────────────────────────── + + def all_completed(self) -> bool: + """True if every task on the board is completed (or board is empty).""" + return len(self._tasks) == 0 or all( + t.status == "completed" for t in self._tasks + ) + + def list_available(self) -> list[Task]: + """List tasks that can be claimed: pending, no unresolved blockers, + no owner.""" + available = [] + for task in self._tasks: + if task.status != "pending" or task.owner is not None: + continue + unresolved = [ + bid for bid in task.blocked_by + if (b := self._find(bid)) is not None + and b.status != "completed" + ] + if not unresolved: + available.append(task) + return available + + def get_summary(self) -> str: + """One-line-per-task summary, suitable for prompt injection. + + Format:: + + Summary: 2 pending, 1 in_progress, 3 completed + + #1. [completed] Audit existing code + #2. [in_progress] Design JWT schema @coder + #3. [pending] Implement JWT ▶ blocked by #2 + """ + if not self._tasks: + return "No tasks." + + counts: dict[str, int] = {} + for t in self._tasks: + counts[t.status] = counts.get(t.status, 0) + 1 + + header = "Summary: " + ", ".join( + f"{v} {k}" for k, v in counts.items() + ) + + completed_ids = { + t.id for t in self._tasks if t.status == "completed" + } + lines = [] + for t in self._tasks: + # Status icon + icon = {"completed": "completed", "in_progress": "in_progress", + "pending": "pending"}.get(t.status, t.status) + line = f"#{t.id}. [{icon}] {t.subject}" + + # Owner annotation + if t.owner: + line += f" @{t.owner}" + + # Active blocker annotation (filter out completed) + active_blockers = [ + bid for bid in t.blocked_by if bid not in completed_ids + ] + if active_blockers: + refs = ", ".join(f"#{bid}" for bid in active_blockers) + line += f" ▶ blocked by {refs}" + + lines.append(line) + + return header + "\n\n" + "\n".join(lines) + + # ── Serialisation ───────────────────────────────────────────── + + def state_dict(self) -> dict: + return { + "version": 1, + "next_id": self._next_id, + "tasks": [t.to_dict() for t in self._tasks], + } + + def load_state_dict(self, state: dict) -> None: + self._tasks = [ + Task.from_dict(t) + for t in state.get("tasks", []) + if isinstance(t, dict) + ] + self._next_id = state.get("next_id", 1) + self._ensure_next_id() + self._persist() diff --git a/lagent/serving/__init__.py b/lagent/serving/__init__.py new file mode 100644 index 00000000..72405849 --- /dev/null +++ b/lagent/serving/__init__.py @@ -0,0 +1,8 @@ +"""lagent.serving — agent deployment backends. + +Each sub-package provides a server + client pair for a specific transport: + +- ``http`` — HTTP API (Starlette/uvicorn) +- ``ray`` — Ray Serve +- ``sandbox`` — Unix socket daemon via bash channel +""" diff --git a/lagent/distributed/http_serve/__init__.py b/lagent/serving/http/__init__.py similarity index 100% rename from lagent/distributed/http_serve/__init__.py rename to lagent/serving/http/__init__.py diff --git a/lagent/serving/http/api_server.py b/lagent/serving/http/api_server.py new file mode 100644 index 00000000..a866f232 --- /dev/null +++ b/lagent/serving/http/api_server.py @@ -0,0 +1,249 @@ +import json +import os +import subprocess +import sys +import time +import threading +from typing import Dict, Optional + +import aiohttp +import requests + +from lagent.schema import AgentMessage + + +class HTTPAgentClient: + """Stateless HTTP client for the Agent API. + + The client manages an optional local ``_state_dict`` that is automatically + sent with each ``__call__`` and updated from the response. This makes the + *server* stateless while keeping a familiar stateful feel on the client + side. + """ + + def __init__(self, host='127.0.0.1', port=8090, timeout=None): + self.host = host + self.port = port + self.timeout = timeout + self._state_dict: Optional[Dict] = None + + @property + def base_url(self): + return f'http://{self.host}:{self.port}' + + @property + def is_alive(self): + try: + resp = requests.get( + f'{self.base_url}/health_check', + timeout=self.timeout) + return resp.status_code == 200 + except Exception: + return False + + # ------------------------------------------------------------------ + # Core call – send state, receive new state + # ------------------------------------------------------------------ + def __call__(self, *message, session_state: Optional[Dict] = None, **kwargs): + """Run the remote agent. + + Parameters + ---------- + *message : str | AgentMessage + Messages to send. + session_state : dict, optional + Explicit state to use for this call. If *None* the client's + internal ``_state_dict`` is used (which may also be *None* for a + fresh session). + **kwargs + Extra keyword arguments forwarded to the agent. + + Returns + ------- + AgentMessage + The agent's response. + """ + state = session_state if session_state is not None else self._state_dict + payload = { + 'message': [ + m if isinstance(m, str) else m.model_dump() + for m in message + ], + **kwargs, + } + if state is not None: + payload['state_dict'] = state + + response = requests.post( + f'{self.base_url}/chat_completion', + json=payload, + headers={'Content-Type': 'application/json'}, + timeout=self.timeout) + resp = response.json() + if response.status_code != 200: + return resp + + # Update local state from server response + self._state_dict = resp.get('state_dict', self._state_dict) + return AgentMessage.model_validate(resp['response']) + + # ------------------------------------------------------------------ + # State management + # ------------------------------------------------------------------ + def state_dict(self) -> Dict: + """Return the current local state dict.""" + if self._state_dict is not None: + return self._state_dict + # Fetch a fresh empty state from the server + resp = requests.post( + f'{self.base_url}/state_dict', + json={}, + headers={'Content-Type': 'application/json'}, + timeout=self.timeout) + return resp.json().get('state_dict', {}) + + def load_state_dict(self, state_dict: Dict): + """Load a state dict into the client (validated by the server).""" + resp = requests.post( + f'{self.base_url}/load_state_dict', + json={'state_dict': state_dict}, + headers={'Content-Type': 'application/json'}, + timeout=self.timeout) + data = resp.json() + if resp.status_code != 200: + raise RuntimeError(f'Failed to load state_dict: {data}') + self._state_dict = data.get('state_dict', state_dict) + + def reset(self): + """Reset the client to a fresh state.""" + resp = requests.post( + f'{self.base_url}/reset', + json={}, + headers={'Content-Type': 'application/json'}, + timeout=self.timeout) + data = resp.json() + self._state_dict = data.get('state_dict', None) + + +class HTTPAgentServer(HTTPAgentClient): + + def __init__(self, gpu_id, config, host='127.0.0.1', port=8090): + super().__init__(host, port) + self.gpu_id = gpu_id + self.config = config + self.start_server() + + def start_server(self): + # set CUDA_VISIBLE_DEVICES in subprocess + env = os.environ.copy() + env['CUDA_VISIBLE_DEVICES'] = self.gpu_id + cmds = [ + sys.executable, 'lagent/distributed/http_serve/app.py', '--host', + self.host, '--port', + str(self.port), '--config', + json.dumps(self.config) + ] + self.process = subprocess.Popen( + cmds, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + + self.service_started = False + + def log_output(stream): + if stream is not None: + for line in iter(stream.readline, ''): + print(line, end='') + if 'Uvicorn running on' in line: + self.service_started = True + + # Start log output thread + threading.Thread(target=log_output, args=(self.process.stdout,), daemon=True).start() + threading.Thread(target=log_output, args=(self.process.stderr,), daemon=True).start() + + # Waiting for the service to start + while not self.service_started: + time.sleep(0.1) + + def shutdown(self): + self.process.terminate() + self.process.wait() + + +class AsyncHTTPAgentMixin: + + async def __call__(self, *message, session_state: Optional[Dict] = None, **kwargs): + """Async version of the stateless agent call.""" + state = session_state if session_state is not None else self._state_dict + payload = { + 'message': [ + m if isinstance(m, str) else m.model_dump() + for m in message + ], + **kwargs, + } + if state is not None: + payload['state_dict'] = state + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(self.timeout)) as session: + async with session.post( + f'{self.base_url}/chat_completion', + json=payload, + headers={'Content-Type': 'application/json'}, + ) as response: + resp = await response.json() + if response.status != 200: + return resp + self._state_dict = resp.get('state_dict', self._state_dict) + return AgentMessage.model_validate(resp['response']) + + async def async_state_dict(self) -> Dict: + """Async version of state_dict.""" + if self._state_dict is not None: + return self._state_dict + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(self.timeout)) as session: + async with session.post( + f'{self.base_url}/state_dict', + json={}, + headers={'Content-Type': 'application/json'}, + ) as response: + data = await response.json() + return data.get('state_dict', {}) + + async def async_load_state_dict(self, state_dict: Dict): + """Async version of load_state_dict.""" + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(self.timeout)) as session: + async with session.post( + f'{self.base_url}/load_state_dict', + json={'state_dict': state_dict}, + headers={'Content-Type': 'application/json'}, + ) as response: + data = await response.json() + if response.status != 200: + raise RuntimeError(f'Failed to load state_dict: {data}') + self._state_dict = data.get('state_dict', state_dict) + + async def async_reset(self): + """Async version of reset.""" + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(self.timeout)) as session: + async with session.post( + f'{self.base_url}/reset', + json={}, + headers={'Content-Type': 'application/json'}, + ) as response: + data = await response.json() + self._state_dict = data.get('state_dict', None) + + +class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient): + pass + + +class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer): + pass diff --git a/lagent/serving/http/app.py b/lagent/serving/http/app.py new file mode 100644 index 00000000..b0523b29 --- /dev/null +++ b/lagent/serving/http/app.py @@ -0,0 +1,191 @@ +import argparse +import json +import logging +import time + +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.requests import Request + +from lagent.schema import AgentMessage +from lagent.utils import load_class_from_string + + +class AgentAPIServer: + """Stateless Agent API Server. + + Each request can optionally carry a ``state_dict`` so the server restores + agent state before execution and returns the updated state afterwards. + The server itself holds only a *template* agent used for creating fresh + instances via ``load_state_dict``. + + API routes + ---------- + GET /health_check – liveness probe + POST /chat_completion – run agent with optional state round-trip + POST /state_dict – export current agent state (from given state) + POST /load_state_dict – validate & echo back a state_dict + POST /reset – return a fresh (empty) state_dict + """ + + def __init__(self, + config: dict, + host: str = '127.0.0.1', + port: int = 8090): + self.app = FastAPI(docs_url='/') + self.app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + cls_name = config.pop('type') + python_path = config.pop('python_path', None) + cls_name = load_class_from_string(cls_name, python_path) if isinstance( + cls_name, str) else cls_name + # Keep both the class and default config so we can mint new instances + self._agent_cls = cls_name + self._agent_config = config + # Template agent – used as the prototype for every request + self.agent = cls_name(**config) + self.setup_routes() + self.run(host, port) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _prepare_agent(self, state_dict=None): + """Create a fresh agent instance with shared heavy resources + (llm, actions, skills) but a new empty memory, then optionally + restore state from *state_dict*.""" + agent = self.agent.new_instance() + if state_dict: + agent.load_state_dict(state_dict) + return agent + + # ------------------------------------------------------------------ + # Routes + # ------------------------------------------------------------------ + def setup_routes(self): + + def heartbeat(): + return {'status': 'success', 'timestamp': time.time()} + + async def process_message(request: Request): + """Run the agent. + + Request body + ------------ + - ``message``: list of str / AgentMessage dicts (required) + - ``state_dict``: dict – agent state to restore before running + - any other keys are forwarded as ``**kwargs`` to the agent + + Response + -------- + - ``response``: the AgentMessage returned by the agent + - ``state_dict``: updated agent state after execution + """ + try: + body = await request.json() + message = [ + m if isinstance(m, str) else AgentMessage.model_validate(m) + for m in body.pop('message') + ] + incoming_state = body.pop('state_dict', None) + agent = self._prepare_agent(incoming_state) + result = await agent(*message, **body) + return { + 'response': result, + 'state_dict': agent.state_dict(), + } + except Exception as e: + logging.error(f'Error processing message: {str(e)}', + exc_info=True) + raise HTTPException( + status_code=500, detail=str(e)) + + async def get_state_dict(request: Request): + """Export agent state. + + If a ``state_dict`` is provided in the body, load it first then + re-export (useful for normalisation). Otherwise return a fresh + empty state. + """ + try: + body = await request.json() if request.headers.get( + 'content-length', '0') != '0' else {} + incoming_state = body.get('state_dict', None) + agent = self._prepare_agent(incoming_state) + return {'state_dict': agent.state_dict()} + except Exception as e: + logging.error(f'Error in state_dict: {str(e)}', exc_info=True) + raise HTTPException( + status_code=500, detail=str(e)) + + async def load_state_dict(request: Request): + """Validate a state_dict by loading it and returning the result.""" + try: + body = await request.json() + state_dict = body.get('state_dict') + if state_dict is None: + raise HTTPException( + status_code=400, + detail='state_dict is required') + agent = self._prepare_agent(state_dict) + return { + 'status': 'success', + 'state_dict': agent.state_dict(), + } + except HTTPException: + raise + except Exception as e: + logging.error(f'Error in load_state_dict: {str(e)}', + exc_info=True) + raise HTTPException( + status_code=500, detail=str(e)) + + async def reset_state(): + """Return a fresh empty state_dict.""" + try: + agent = self._prepare_agent() + return {'state_dict': agent.state_dict()} + except Exception as e: + logging.error(f'Error in reset: {str(e)}', exc_info=True) + raise HTTPException( + status_code=500, detail=str(e)) + + self.app.add_api_route('/health_check', heartbeat, methods=['GET']) + self.app.add_api_route( + '/chat_completion', process_message, methods=['POST']) + self.app.add_api_route( + '/state_dict', get_state_dict, methods=['POST']) + self.app.add_api_route( + '/load_state_dict', load_state_dict, methods=['POST']) + self.app.add_api_route( + '/reset', reset_state, methods=['POST']) + + def run(self, host='127.0.0.1', port=8090): + logging.info(f'Starting server at {host}:{port}') + uvicorn.run(self.app, host=host, port=port) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Async Agent API Server') + parser.add_argument('--host', type=str, default='127.0.0.1') + parser.add_argument('--port', type=int, default=8090) + parser.add_argument( + '--config', + type=json.loads, + required=True, + help='JSON configuration for the agent') + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + args = parse_args() + AgentAPIServer(args.config, host=args.host, port=args.port) diff --git a/lagent/serving/ray/__init__.py b/lagent/serving/ray/__init__.py new file mode 100644 index 00000000..bec3e237 --- /dev/null +++ b/lagent/serving/ray/__init__.py @@ -0,0 +1,3 @@ +from .ray_wrapper import AgentRayActor, AsyncAgentRayActor + +__all__ = ['AsyncAgentRayActor', 'AgentRayActor'] diff --git a/lagent/serving/ray/ray_wrapper.py b/lagent/serving/ray/ray_wrapper.py new file mode 100644 index 00000000..c2737a27 --- /dev/null +++ b/lagent/serving/ray/ray_wrapper.py @@ -0,0 +1,123 @@ +import copy +import importlib +import sys +from typing import Dict, Optional + +import ray + +from lagent.schema import AgentMessage +from lagent.utils import load_class_from_string + + +class _StatelessAgentWrapper: + """Wrapper that runs inside a Ray actor. Holds a *template* agent and + creates a fresh instance (via ``new_instance``) per call so that + concurrent ``__call__`` invocations are isolated from each other while + sharing heavy resources (llm, actions, connections).""" + + def __init__(self, agent): + self.agent = agent + + async def stateless_call(self, messages, state_dict=None, **kwargs): + """Create a new agent, optionally load state, run, return result + new state.""" + agent = self.agent.new_instance() + if state_dict: + agent.load_state_dict(state_dict) + response = await agent(*messages, **kwargs) + return {'response': response, 'state_dict': agent.state_dict()} + + def sync_stateless_call(self, messages, state_dict=None, **kwargs): + agent = self.agent.new_instance() + if state_dict: + agent.load_state_dict(state_dict) + response = agent(*messages, **kwargs) + return {'response': response, 'state_dict': agent.state_dict()} + + def get_state_dict(self): + """Return a fresh empty state_dict from the template agent.""" + agent = self.agent.new_instance() + return agent.state_dict() + + +class AsyncAgentRayActor: + """Stateless async Ray actor wrapper for agents. + + Each call can carry a ``state_dict`` to restore the agent's state before + execution. The response includes the updated ``state_dict``. + Heavy resources (llm, actions, connections) are shared across calls + inside the actor while memory is isolated per request. + """ + + def __init__( + self, + config: Dict, + num_gpus: int, + ): + config = copy.deepcopy(config) + cls_name = config.pop('type') + python_path = config.pop('python_path', None) + cls_name = load_class_from_string(cls_name, python_path) if isinstance( + cls_name, str) else cls_name + + # The wrapper holds the template agent; Ray actor wraps the wrapper + WrappedActor = ray.remote(num_gpus=num_gpus)(_StatelessAgentWrapper) + template_agent = cls_name(**config) + self.actor = WrappedActor.remote(template_agent) + + async def __call__( + self, + *message: AgentMessage, + state_dict: Optional[Dict] = None, + **kwargs, + ): + """Run the remote agent with optional state round-trip. + + Returns + ------- + dict ``{'response': AgentMessage, 'state_dict': dict}`` + """ + return await self.actor.stateless_call.remote( + list(message), state_dict=state_dict, **kwargs) + + async def state_dict(self) -> Dict: + return await self.actor.get_state_dict.remote() + + async def reset(self): + """No-op for stateless wrapper – each call already starts fresh.""" + pass + + +class AgentRayActor: + """Stateless sync Ray actor wrapper for agents.""" + + def __init__( + self, + config: Dict, + num_gpus: int, + ): + config = copy.deepcopy(config) + cls_name = config.pop('type') + python_path = config.pop('python_path', None) + cls_name = load_class_from_string(cls_name, python_path) if isinstance( + cls_name, str) else cls_name + + WrappedActor = ray.remote(num_gpus=num_gpus)(_StatelessAgentWrapper) + template_agent = cls_name(**config) + self.actor = WrappedActor.remote(template_agent) + + def __call__( + self, + *message: AgentMessage, + state_dict: Optional[Dict] = None, + **kwargs, + ): + """Run the remote agent with optional state round-trip.""" + return ray.get(self.actor.sync_stateless_call.remote( + list(message), state_dict=state_dict, **kwargs)) + + def state_dict(self) -> Dict: + return ray.get(self.actor.get_state_dict.remote()) + + def reset(self): + """No-op for stateless wrapper.""" + pass diff --git a/lagent/serving/sandbox/__init__.py b/lagent/serving/sandbox/__init__.py new file mode 100644 index 00000000..b1472e82 --- /dev/null +++ b/lagent/serving/sandbox/__init__.py @@ -0,0 +1,26 @@ +"""lagent.serving.sandbox — sandbox deployment via Unix socket daemon. + +Server side (runs inside sandbox): + - :class:`BaseDaemon` — socket server + protocol + - :class:`ActionDaemon` — action execution + - :class:`SkillsDaemon` — skills loading + - :class:`AgentDaemon` — full agent + +Client side: + - :class:`SandboxAgent` — drop-in for AsyncAgent + - ``SandboxActionExecutor`` — see ``lagent.actions.sandbox_executor`` + - ``SandboxSkillsLoader`` — see ``lagent.skills.sandbox_skills`` +""" + +from .agent import SandboxAgent +from .daemon import ActionDaemon, AgentDaemon, BaseDaemon, SkillsDaemon, async_lagent_call, lagent_call + +__all__ = [ + "BaseDaemon", + "ActionDaemon", + "SkillsDaemon", + "AgentDaemon", + "SandboxAgent", + "lagent_call", + "async_lagent_call", +] diff --git a/lagent/serving/sandbox/agent.py b/lagent/serving/sandbox/agent.py new file mode 100644 index 00000000..2d316b54 --- /dev/null +++ b/lagent/serving/sandbox/agent.py @@ -0,0 +1,223 @@ +"""SandboxAgent — drop-in replacement for ``AsyncAgent`` that proxies +to an :class:`~lagent.serving.sandbox.daemon.AgentDaemon` +running inside a sandbox. + +The daemon holds a full agent (LLM + Actions + Skills). SandboxAgent +sends ``chat`` / ``state_dict`` / ``reset`` commands via the sandbox's +bash channel and returns the results as regular ``AgentMessage`` objects. + +Usage:: + + from lagent.serving.sandbox.agent import SandboxAgent + + agent = SandboxAgent( + sandbox_client=my_sandbox_client, + agent_config={ + "type": "lagent.agents.internclaw_agent.InternClawAgent", + "policy_agent": {...}, + "env_agent": {...}, + }, + ) + await agent.connect() + + response = await agent("Fix the bug in main.py") + print(response.content) + + state = await agent.get_state_dict() + await agent.reset() +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import logging +from typing import Any, Dict, List, Optional, Union + +from lagent.schema import AgentMessage + +logger = logging.getLogger(__name__) + + +class SandboxAgent: + """Drop-in for ``AsyncAgent`` that proxies to a Level 2 daemon. + + Parameters + ---------- + sandbox_client + Object with ``execute(command, ...)`` that runs bash in the sandbox. + May be sync or async. + agent_config : dict + Agent configuration (passed to daemon, which calls ``create_object``). + sock_path : str + Unix socket path inside the sandbox. + cwd : str + Working directory for bash commands. + daemon_module : str + Python module path for the daemon. + """ + + def __init__( + self, + sandbox_client, + agent_config: Dict, + sock_path: str = "/tmp/lagent_action.sock", + cwd: str = "/root", + daemon_module: str = "lagent.serving.sandbox.daemon", + ): + self.sandbox_client = sandbox_client + self.agent_config = agent_config + self.sock_path = sock_path + self.cwd = cwd + self.daemon_module = daemon_module + self._connected = False + self._lock = asyncio.Lock() + + # -- sandbox communication -- + + async def _exec(self, command: str, **kwargs) -> str: + kwargs.setdefault("cwd", self.cwd) + execute_fn = self.sandbox_client.execute + if inspect.iscoroutinefunction(execute_fn): + result = await execute_fn(command, **kwargs) + else: + result = await asyncio.to_thread(execute_fn, command, **kwargs) + if isinstance(result, dict): + stdout = result.get("stdout", "") + # If stdout is empty but stderr has content, raise so caller can debug + if not stdout.strip() and result.get("stderr", "").strip(): + raise RuntimeError(f"Command stderr: {result['stderr'][:500]}") + return stdout + return result + + async def _daemon_call(self, request: dict, timeout_sec: int = 600) -> dict: + """Send a request to daemon. Uses longer timeout for chat commands.""" + request_json = json.dumps(request, ensure_ascii=False) + escaped = request_json.replace("'", "'\\''") + output = await self._exec( + f"/mnt/llm-ai-infra/miniconda3/envs/train/bin/python -m {self.daemon_module} call " + f"--sock {self.sock_path} " + f"'{escaped}'", + timeout_sec=timeout_sec, + ) + if not output.strip(): + raise RuntimeError("Daemon returned empty response (may have crashed or timed out)") + return json.loads(output.strip()) + + # -- lifecycle -- + + async def connect(self) -> None: + """Start the Level 2 daemon inside the sandbox (idempotent).""" + async with self._lock: + if self._connected: + return + + # Write agent config + config_json = json.dumps(self.agent_config, ensure_ascii=False) + escaped = config_json.replace("'", "'\\''") + await self._exec( + f"echo '{escaped}' > /tmp/lagent_agent_config.json" + ) + + # Check if daemon already running, start if not + check = await self._exec( + f"pgrep -f '[l]agent.serving.sandbox.daemon.*--sock {self.sock_path}'" + f" > /dev/null 2>&1 && echo 'running' || echo 'stopped'" + ) + if "stopped" in check: + await self._exec( + f"nohup python -m {self.daemon_module} start " + f"--sock {self.sock_path} " + f"--agent-config /tmp/lagent_agent_config.json " + f"> /tmp/lagent_daemon.log 2>&1 &" + ) + + # Wait for socket + for _ in range(30): + try: + output = await self._exec( + f"test -S {self.sock_path} && echo 'ready' || echo 'waiting'" + ) + if "ready" in output: + break + except Exception: + pass + await asyncio.sleep(0.5) + else: + raise TimeoutError( + f"Daemon did not start within 15s. " + f"Check /tmp/lagent_daemon.log inside sandbox." + ) + + # Verify it's an AgentDaemon + result = await self._daemon_call({"cmd": "ping"}) + assert result.get("type") == "agent", ( + f"Expected AgentDaemon, got: {result}" + ) + self._connected = True + logger.info("SandboxAgent connected (Level 2 daemon)") + + async def close(self) -> None: + if not self._connected: + return + try: + await self._daemon_call({"cmd": "shutdown"}) + except Exception: + pass + self._connected = False + + # -- Agent-compatible interface -- + + async def __call__(self, *messages, **kwargs) -> AgentMessage: + """Send messages to the agent and return its response.""" + if not self._connected: + await self.connect() + + str_messages = [] + for m in messages: + if isinstance(m, AgentMessage): + str_messages.append(m.content if isinstance(m.content, str) else m.model_dump()) + else: + str_messages.append(str(m)) + + result = await self._daemon_call({ + "cmd": "chat", + "messages": str_messages, + }) + + if "error" in result: + return AgentMessage( + sender="SandboxAgent", + content=f"Agent error: {result['error']}", + ) + + return AgentMessage(**{ + k: v for k, v in result.items() + if k in AgentMessage.model_fields + }) + + async def get_state_dict(self) -> Dict: + """Get the agent's full state (memory + traces).""" + if not self._connected: + await self.connect() + result = await self._daemon_call({"cmd": "state_dict"}) + if "error" in result: + raise RuntimeError(result["error"]) + return result.get("state_dict", {}) + + async def load_state_dict(self, state_dict: Dict) -> None: + result = await self._daemon_call({ + "cmd": "load_state_dict", + "state_dict": state_dict, + }) + if "error" in result: + raise RuntimeError(result["error"]) + + async def reset(self, recursive: bool = True) -> None: + result = await self._daemon_call({ + "cmd": "reset", + "recursive": recursive, + }) + if "error" in result: + raise RuntimeError(result["error"]) diff --git a/lagent/serving/sandbox/daemon.py b/lagent/serving/sandbox/daemon.py new file mode 100644 index 00000000..6c558f0b --- /dev/null +++ b/lagent/serving/sandbox/daemon.py @@ -0,0 +1,466 @@ +"""Lagent Daemon — long-running process inside a sandbox that serves +lagent components over a Unix socket. + +Architecture:: + + BaseDaemon ← socket server + protocol + ping/shutdown + ├── ActionDaemon ← Level 1: action calls + list_tools + └── AgentDaemon ← Level 2: chat + state_dict + reset + +Protocol (length-prefixed JSON over Unix stream socket):: + + Request → 4-byte big-endian length + JSON payload + Response ← 4-byte big-endian length + JSON payload + +Usage:: + + # Level 1: actions only + python -m lagent.actions.action_daemon start \\ + --mode actions --config actions.json --sock /tmp/lagent.sock + + # Level 2: full agent + python -m lagent.actions.action_daemon start \\ + --mode agent --config agent.json --sock /tmp/lagent.sock + + # Call (same for both) + python -m lagent.actions.action_daemon call \\ + --sock /tmp/lagent.sock '{"cmd":"ping"}' +""" + +from __future__ import annotations + +import argparse +import asyncio +import inspect +import json +import logging +import os +import struct +import sys +from typing import Any, Dict, List, Optional, Union + +from lagent.actions.action_executor import ActionExecutor, AsyncActionExecutor +from lagent.actions.base_action import BaseAction +from lagent.schema import ActionReturn, ActionStatusCode, AgentMessage, dataclass2dict +from lagent.utils import create_object + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Wire protocol helpers +# --------------------------------------------------------------------------- + +_HEADER_FMT = "!I" # 4-byte unsigned big-endian +_HEADER_SIZE = struct.calcsize(_HEADER_FMT) +_MAX_MSG_SIZE = 64 * 1024 * 1024 # 64 MiB safety cap + + +async def _send_msg(writer: asyncio.StreamWriter, data: bytes) -> None: + writer.write(struct.pack(_HEADER_FMT, len(data))) + writer.write(data) + await writer.drain() + + +async def _recv_msg(reader: asyncio.StreamReader) -> Optional[bytes]: + header = await reader.readexactly(_HEADER_SIZE) + (length,) = struct.unpack(_HEADER_FMT, header) + if length > _MAX_MSG_SIZE: + raise ValueError(f"Message too large: {length} bytes") + return await reader.readexactly(length) + + +# --------------------------------------------------------------------------- +# BaseDaemon — socket server + protocol +# --------------------------------------------------------------------------- + + +class BaseDaemon: + """Base class: asyncio Unix-socket server with JSON protocol. + + Subclasses implement ``_dispatch(request)`` to handle domain-specific + commands. Common commands (ping, shutdown) are handled here. + + Parameters + ---------- + sock_path : str + Path for the Unix domain socket. + """ + + daemon_type: str = "base" + + def __init__(self, sock_path: str = "/tmp/lagent_action.sock"): + self.sock_path = sock_path + self._server: Optional[asyncio.AbstractServer] = None + + async def start(self) -> None: + """Start listening. Removes stale socket file if present.""" + if os.path.exists(self.sock_path): + os.unlink(self.sock_path) + self._server = await asyncio.start_unix_server( + self._handle_client, path=self.sock_path + ) + os.chmod(self.sock_path, 0o777) + logger.info("%s listening on %s", self.__class__.__name__, self.sock_path) + await self._server.serve_forever() + + async def stop(self) -> None: + if self._server: + self._server.close() + await self._server.wait_closed() + if os.path.exists(self.sock_path): + os.unlink(self.sock_path) + logger.info("%s stopped", self.__class__.__name__) + + async def _handle_client( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + try: + raw = await _recv_msg(reader) + request = json.loads(raw) + response = await self._dispatch(request) + await _send_msg(writer, json.dumps(response, ensure_ascii=False).encode()) + except asyncio.IncompleteReadError: + pass + except Exception as e: + logger.exception("Error handling client request") + try: + await _send_msg(writer, json.dumps({"error": str(e)}).encode()) + except Exception: + pass + finally: + writer.close() + await writer.wait_closed() + + async def _dispatch(self, request: dict) -> dict: + """Handle a request. Override in subclasses for domain logic.""" + cmd = request.get("cmd") + if cmd == "ping": + return {"status": "ok", "type": self.daemon_type} + if cmd == "shutdown": + async def _delayed_close(): + await asyncio.sleep(0.1) + if self._server: + self._server.close() + asyncio.create_task(_delayed_close()) + return {"status": "shutting_down"} + return {"error": f"Unknown command: {cmd}"} + + +# --------------------------------------------------------------------------- +# ActionDaemon — Level 1: action execution +# --------------------------------------------------------------------------- + + +class ActionDaemon(BaseDaemon): + """Serves an ``AsyncActionExecutor`` over Unix socket. + + Usage:: + + daemon = ActionDaemon( + actions=[ShellAction(), ReadFileAction(), ...], + ) + await daemon.start() + + Parameters + ---------- + actions : list + Action instances or config dicts. + sock_path : str + Unix socket path. + """ + + daemon_type = "action" + + def __init__( + self, + actions: Union[List[BaseAction], List[Dict]], + sock_path: str = "/tmp/lagent_action.sock", + ): + super().__init__(sock_path=sock_path) + for i, action in enumerate(actions): + actions[i] = create_object(action) + self.executor = AsyncActionExecutor(actions) + + async def _dispatch(self, request: dict) -> dict: + cmd = request.get("cmd") + + # Common commands + if cmd in ("ping", "shutdown"): + return await super()._dispatch(request) + + if cmd == "list_tools": + return {"tools": self.executor.description()} + + # Action call + name = request.get("name") + parameters = request.get("parameters", {}) + if not name: + return dataclass2dict(ActionReturn( + errmsg="Missing 'name' in request", + state=ActionStatusCode.ARGS_ERROR, + )) + + try: + action_return = await self.executor.forward(name, parameters) + except Exception as e: + logger.exception("Action %s failed", name) + action_return = ActionReturn( + args=parameters, type=name, + errmsg=str(e), state=ActionStatusCode.API_ERROR, + ) + return dataclass2dict(action_return) + + +# --------------------------------------------------------------------------- +# SkillsDaemon — skills loading +# --------------------------------------------------------------------------- + + +class SkillsDaemon(BaseDaemon): + """Serves a ``SkillsLoader`` over Unix socket. + + Usage:: + + daemon = SkillsDaemon( + skills_loader=SkillsLoader(workspace), + ) + await daemon.start() + + Parameters + ---------- + skills_loader : SkillsLoader + Skills loader for the sandbox workspace. + sock_path : str + Unix socket path. + """ + + daemon_type = "skills" + + def __init__(self, skills_loader, sock_path: str = "/tmp/lagent_skills.sock"): + super().__init__(sock_path=sock_path) + self.skills = skills_loader + + async def _dispatch(self, request: dict) -> dict: + cmd = request.get("cmd") + + if cmd in ("ping", "shutdown"): + return await super()._dispatch(request) + + if cmd == "list_skills": + filter_unavailable = request.get("filter_unavailable", True) + return {"skills": await self.skills.list_skills(filter_unavailable=filter_unavailable)} + + if cmd == "skills_summary": + return {"summary": await self.skills.build_skills_summary()} + + if cmd == "load_skill": + content = await self.skills.load_skill(request.get("name", "")) + return {"content": content} + + if cmd == "load_skills_for_context": + content = await self.skills.load_skills_for_context(request.get("names", [])) + return {"content": content} + + if cmd == "get_always_skills": + return {"skills": await self.skills.get_always_skills()} + + return {"error": f"Unknown command: {cmd}"} + + +# --------------------------------------------------------------------------- +# AgentDaemon — Level 2: full agent +# --------------------------------------------------------------------------- + + +class AgentDaemon(BaseDaemon): + """Serves a full ``Agent`` over Unix socket. + + The agent runs entirely inside the sandbox — LLM calls, action + execution, skills, memory — everything is local to the daemon. + + Usage:: + + daemon = AgentDaemon( + agent=InternClawAgent( + policy_agent=..., + env_agent=AsyncEnvAgent(actions=[ShellAction(), ...]), + ), + ) + await daemon.start() + + Parameters + ---------- + agent : Agent or dict + Agent instance or config dict (passed to ``create_object``). + sock_path : str + Unix socket path. + """ + + daemon_type = "agent" + + def __init__(self, agent, sock_path: str = "/tmp/lagent_action.sock"): + super().__init__(sock_path=sock_path) + self.agent = create_object(agent) + + async def _dispatch(self, request: dict) -> dict: + cmd = request.get("cmd") + + # Common commands + if cmd in ("ping", "shutdown"): + return await super()._dispatch(request) + + # Tool introspection (via EnvAgent if available) + if cmd == "list_tools": + env = getattr(self.agent, 'env_agent', None) + executor = getattr(env, 'actions', None) if env else None + if executor: + return {"tools": executor.description()} + return {"tools": []} + + # Agent commands + if cmd == "chat": + messages = request.get("messages", []) + try: + response = await self.agent(*messages) + return self._serialize_agent_message(response) + except Exception as e: + logger.exception("Agent chat failed") + return {"error": str(e)} + + if cmd == "state_dict": + try: + return {"state_dict": self.agent.state_dict()} + except Exception as e: + return {"error": str(e)} + + if cmd == "load_state_dict": + try: + self.agent.load_state_dict(request["state_dict"]) + return {"status": "ok"} + except Exception as e: + return {"error": str(e)} + + if cmd == "reset": + try: + self.agent.reset(recursive=request.get("recursive", True)) + return {"status": "ok"} + except Exception as e: + return {"error": str(e)} + + return {"error": f"Unknown command: {cmd}"} + + @staticmethod + def _serialize_agent_message(msg: AgentMessage) -> dict: + data = msg.model_dump() + if isinstance(msg.content, ActionReturn): + data["content"] = dataclass2dict(msg.content) + return data + + +# --------------------------------------------------------------------------- +# lagent-call: one-shot CLI client +# --------------------------------------------------------------------------- + + +def lagent_call(sock_path: str, request_json: str) -> str: + """Send a single request to the daemon and return the response JSON. + + Synchronous — suitable for CLI and subprocess usage. + """ + return asyncio.run(async_lagent_call(sock_path, request_json)) + + +async def async_lagent_call(sock_path: str, request_json: str) -> str: + """Async variant of :func:`lagent_call`.""" + reader, writer = await asyncio.open_unix_connection(sock_path) + try: + await _send_msg(writer, request_json.encode()) + raw = await _recv_msg(reader) + return raw.decode() + finally: + writer.close() + await writer.wait_closed() + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + + +def _load_config(config_path: str) -> Union[List[Dict], Dict]: + with open(config_path) as f: + return json.load(f) + + +def main(): + parser = argparse.ArgumentParser( + prog="lagent.actions.action_daemon", + description="Lagent Daemon: serve actions or agents over Unix socket", + ) + sub = parser.add_subparsers(dest="command") + + # -- start -- + p_start = sub.add_parser("start", help="Start the daemon") + p_start.add_argument( + "--sock", default="/tmp/lagent_action.sock", + help="Unix socket path", + ) + p_start.add_argument( + "--mode", choices=["actions", "agent"], default="actions", + help="'actions' for Level 1 (ActionDaemon), 'agent' for Level 2 (AgentDaemon)", + ) + p_start.add_argument( + "--config", + help="Path to JSON config (action list for Level 1, agent dict for Level 2)", + ) + # Backward compat + p_start.add_argument("--actions-config", help=argparse.SUPPRESS) + p_start.add_argument("--agent-config", help=argparse.SUPPRESS) + + # -- call -- + p_call = sub.add_parser("call", help="Send a one-shot request") + p_call.add_argument("--sock", default="/tmp/lagent_action.sock") + p_call.add_argument("request", help="JSON request string") + + args = parser.parse_args() + + if args.command == "start": + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s") + + # Handle backward compat flags + mode = args.mode + config_path = args.config + if args.actions_config: + mode = "actions" + config_path = args.actions_config + elif args.agent_config: + mode = "agent" + config_path = args.agent_config + + if not config_path: + parser.error("Must provide --config (or --actions-config / --agent-config)") + + config = _load_config(config_path) + + if mode == "agent": + daemon = AgentDaemon(agent=config, sock_path=args.sock) + else: + daemon = ActionDaemon(actions=config, sock_path=args.sock) + + logger.info("Starting %s", daemon.__class__.__name__) + try: + asyncio.run(daemon.start()) + except (KeyboardInterrupt, asyncio.CancelledError): + logger.info("Interrupted, shutting down") + elif args.command == "call": + result = lagent_call(args.sock, args.request) + print(result, flush=True) + else: + parser.print_help() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/lagent/serving/sandbox/providers/__init__.py b/lagent/serving/sandbox/providers/__init__.py new file mode 100644 index 00000000..cbc3244c --- /dev/null +++ b/lagent/serving/sandbox/providers/__init__.py @@ -0,0 +1,22 @@ +"""Sandbox providers — create, manage, and connect to sandbox environments.""" + +from .base import SandboxClient, SandboxProvider +from .gateway import GatewayProvider +from .local import LocalClient, LocalProvider + +__all__ = [ + "SandboxClient", + "SandboxProvider", + "GatewayProvider", + "LocalProvider", + "LocalClient", + "ClusterXProvider", +] + + +def __getattr__(name): + # Lazy import ClusterXProvider — clusterx is an optional dependency + if name == "ClusterXProvider": + from .clusterx import ClusterXProvider + return ClusterXProvider + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/lagent/serving/sandbox/providers/base.py b/lagent/serving/sandbox/providers/base.py new file mode 100644 index 00000000..c886e9b3 --- /dev/null +++ b/lagent/serving/sandbox/providers/base.py @@ -0,0 +1,110 @@ +"""Sandbox client and provider base definitions. + +SandboxClient — unified client for interacting with any sandbox. +SandboxProvider — protocol for sandbox lifecycle management. +""" + +from __future__ import annotations + +import base64 +import logging +import shutil +from typing import List, Protocol, Tuple, runtime_checkable + +import requests + +logger = logging.getLogger(__name__) + + +class SandboxClient: + """Unified HTTP client for sandbox interaction. + + Every provider returns a ``SandboxClient`` pointing to a sandbox's + HTTP API. The API contract (``/exec``, ``/upload``, ``/download``, + ``/health``) is the same regardless of the underlying infrastructure. + + Parameters + ---------- + url : str + Base URL of the sandbox HTTP API. + """ + + def __init__(self, url: str): + self.url = url.rstrip("/") + self.session = requests.Session() + self.session.headers.update({ + "Connection": "keep-alive", + "Content-Type": "application/json", + }) + + def execute(self, command: str, cwd: str = "/root", timeout_sec: int = 60) -> dict: + """Execute a bash command inside the sandbox.""" + resp = self.session.post( + f"{self.url}/exec", + json={"command": command, "cwd": cwd, "timeout_sec": timeout_sec}, + ) + resp.raise_for_status() + return resp.json() + + def upload_file(self, local_path: str, remote_path: str) -> dict: + """Upload a local file to the sandbox.""" + with open(local_path, "rb") as f: + content_b64 = base64.b64encode(f.read()).decode("utf-8") + resp = self.session.post( + f"{self.url}/upload", + json={"target_path": remote_path, "content_b64": content_b64}, + ) + resp.raise_for_status() + return resp.json() + + def download_file(self, remote_path: str) -> bytes: + """Download a file from the sandbox.""" + resp = self.session.post( + f"{self.url}/download", + json={"source_path": remote_path}, + ) + resp.raise_for_status() + data = resp.json() + if data.get("ok"): + return base64.b64decode(data["content_b64"]) + raise RuntimeError(data.get("error", "Download failed")) + + def health_check(self) -> dict: + """Check if the sandbox is alive.""" + try: + resp = self.session.get(f"{self.url}/health", timeout=5) + return resp.json() + except Exception as e: + return {"ok": False, "error": str(e)} + + def close(self): + """Close the HTTP session.""" + self.session.close() + + def __repr__(self): + return f"SandboxClient(url={self.url!r})" + + +@runtime_checkable +class SandboxProvider(Protocol): + """Protocol for sandbox lifecycle management. + + Different implementations manage different infrastructure: + k8s Gateway, ClusterX, Docker, local subprocess, etc. + """ + + def create(self, **kwargs) -> Tuple[SandboxClient, str]: + """Create a new sandbox. + + Returns + ------- + client : SandboxClient + Client connected to the new sandbox. + sandbox_id : str + Identifier for lifecycle management (delete, status). + """ + ... + + def delete(self, sandbox_id: str) -> None: + """Delete a sandbox.""" + ... diff --git a/lagent/serving/sandbox/providers/clusterx.py b/lagent/serving/sandbox/providers/clusterx.py new file mode 100644 index 00000000..adc97fed --- /dev/null +++ b/lagent/serving/sandbox/providers/clusterx.py @@ -0,0 +1,244 @@ +"""ClusterX sandbox provider. + +Submits a job to the ClusterX scheduler that runs a +:class:`~lagent.serving.sandbox.server.SandboxServer` inside the +allocated node. The shared filesystem means lagent code is already +accessible — no uploading needed. + +Usage:: + + from lagent.serving.sandbox.providers.clusterx import ClusterXProvider + + provider = ClusterXProvider(partition="puyu-gpu") + client, job_id = provider.create() + # client is a SandboxClient pointing to http://:8080 + result = client.execute("echo hello") + provider.delete(job_id) + +Requires the ``clusterx`` package and a valid cluster configuration. +""" + +from __future__ import annotations + +import logging +import time +from typing import Optional, Tuple + +from .base import SandboxClient + +logger = logging.getLogger(__name__) + +# Default port for SandboxServer inside the job +DEFAULT_PORT = 8080 + + +class ClusterXProvider: + """Manages sandboxes via ClusterX job scheduler. + + Each ``create()`` submits a job that starts a SandboxServer. + The server is accessible via the node IP on the shared network. + + Parameters + ---------- + partition : str + ClusterX partition to submit jobs to. + image : str, optional + Docker image for the job. + port : int + Port for the SandboxServer (default 8080). + server_module : str + Python module to run as the sandbox server. + conda_env : str, optional + Conda environment to activate before starting server. + python_path : str, optional + Extra PYTHONPATH to prepend. + extra_run_kwargs : dict, optional + Extra kwargs passed to the ClusterX run params (cpus_per_task, + memory_per_task, etc.). + """ + + def __init__( + self, + partition: Optional[str] = None, + image: Optional[str] = None, + port: int = DEFAULT_PORT, + server_module: str = "lagent.serving.sandbox.server", + conda_env: Optional[str] = None, + conda_activate_path: str = "/mnt/shared-storage-user/liukuikun/miniconda3/bin/activate", + python_path: Optional[str] = None, + extra_run_kwargs: Optional[dict] = None, + ): + self.port = port + self.server_module = server_module + self.conda_env = conda_env + self.conda_activate_path = conda_activate_path + self.python_path = python_path + self.extra_run_kwargs = extra_run_kwargs or {} + self._jobs: dict[str, dict] = {} + + # Lazy import clusterx + from clusterx import CLUSTER, CLUSTER_MAPPING + self._cluster_type = CLUSTER_MAPPING[CLUSTER]['type'] + self._params_type = CLUSTER_MAPPING[CLUSTER]['params'] + self._client = self._cluster_type() + + # Read defaults from clusterx config (~/.config/clusterx.yaml) + self._cluster_defaults = self._load_cluster_defaults(CLUSTER) + self.partition = partition or self._cluster_defaults.get('partition', '') + self.image = image or self._cluster_defaults.get('image') + + @staticmethod + def _load_cluster_defaults(cluster_name: str) -> dict: + """Read defaults from ~/.config/clusterx.yaml.""" + import os + try: + import yaml + config_path = os.path.expanduser("~/.config/clusterx.yaml") + if os.path.exists(config_path): + with open(config_path) as f: + config = yaml.safe_load(f) + return config.get(cluster_name, {}) or {} + except Exception: + pass + return {} + + def _build_cmd(self) -> str: + """Build the shell command that starts SandboxServer. + + Uses direct file execution (``python /path/to/server.py``) + instead of ``python -m`` to avoid triggering full lagent package + imports which may fail if optional dependencies are missing. + """ + parts = [] + if self.conda_env: + parts.append(f"source {self.conda_activate_path} {self.conda_env} &&") + # Resolve server.py path relative to this file + import lagent.serving.sandbox.server as _mod + server_path = _mod.__file__ + # Use stdlib backend to avoid fastapi/pydantic version issues + parts.append(f"python {server_path} --port {self.port} --backend stdlib") + return " ".join(parts) + + def create( + self, + job_name: Optional[str] = None, + timeout: int = 300, + poll_interval: int = 5, + **kwargs, + ) -> Tuple[SandboxClient, str]: + """Submit a job and wait for the SandboxServer to be ready. + + Parameters + ---------- + job_name : str, optional + Job name. Auto-generated if not provided. + timeout : int + Max seconds to wait for the server to be ready. + poll_interval : int + Seconds between status checks. + **kwargs + Override any ClusterX run params. + + Returns + ------- + client : SandboxClient + Connected to the SandboxServer. + job_id : str + Job identifier for lifecycle management. + """ + from clusterx.launcher.base import JobStatus + + # Build run params — merge config defaults + provider settings + overrides + run_kwargs = { + "partition": self.partition, + "cmd": self._build_cmd(), + "num_nodes": 1, + "tasks_per_node": 1, + "gpus_per_task": 0, + **self.extra_run_kwargs, + **kwargs, + } + if job_name: + run_kwargs["job_name"] = job_name + if self.image: + run_kwargs["image"] = self.image + # Apply mount from clusterx config if not explicitly set + if "mount" not in run_kwargs: + mounts = self._cluster_defaults.get("mount") + if mounts: + run_kwargs["data_sources"] = ",".join(mounts) if isinstance(mounts, list) else mounts + run_kwargs["image"] = self.image + + params = self._params_type.model_construct(**run_kwargs) + + # Submit job + status = self._client.run(params) + job_id = status.job_id + logger.info("Submitted ClusterX job: %s", job_id) + + # Wait for job to be RUNNING and get node IP + deadline = time.monotonic() + timeout + node_ip = None + while time.monotonic() < deadline: + info = self._client.get_job_info(job_id) + if info.status == JobStatus.RUNNING: + if info.nodes_ip: + node_ip = info.nodes_ip[0] + break + elif info.status in (JobStatus.FAILED, JobStatus.STOPPED): + raise RuntimeError( + f"ClusterX job {job_id} failed with status: {info.status}" + ) + time.sleep(poll_interval) + + if not node_ip: + self._client.stop(job_id=job_id) + raise TimeoutError( + f"ClusterX job {job_id} did not start within {timeout}s" + ) + + url = f"http://{node_ip}:{self.port}" + logger.info("Job %s running on %s", job_id, url) + + # Wait for SandboxServer health check + client = SandboxClient(url) + server_deadline = time.monotonic() + 60 # extra 60s for server startup + while time.monotonic() < server_deadline: + health = client.health_check() + if health.get("ok"): + break + time.sleep(2) + else: + self._client.stop(job_id=job_id) + raise TimeoutError( + f"SandboxServer at {url} did not become healthy within 60s" + ) + + self._jobs[job_id] = {"url": url, "node_ip": node_ip} + logger.info("ClusterX sandbox ready: job_id=%s, url=%s", job_id, url) + return client, job_id + + def delete(self, job_id: str) -> None: + """Stop a ClusterX job and release resources.""" + try: + self._client.stop(job_id=job_id) + logger.info("Stopped ClusterX job: %s", job_id) + except Exception as e: + logger.warning("Failed to stop job %s: %s", job_id, e) + self._jobs.pop(job_id, None) + + def get(self, job_id: str) -> dict: + """Get job status.""" + info = self._client.get_job_info(job_id) + return { + "job_id": info.job_id, + "status": str(info.status), + "nodes_ip": info.nodes_ip, + } + + def list(self): + """List tracked sandbox jobs.""" + return [ + {"job_id": jid, **info} + for jid, info in self._jobs.items() + ] diff --git a/lagent/serving/sandbox/providers/gateway.py b/lagent/serving/sandbox/providers/gateway.py new file mode 100644 index 00000000..37634a23 --- /dev/null +++ b/lagent/serving/sandbox/providers/gateway.py @@ -0,0 +1,103 @@ +"""k8s Gateway sandbox provider. + +Wraps the EnvGateway HTTP API for sandbox lifecycle management. +The Gateway already provides sandboxes with an HTTP API (``/exec``, +``/upload``, ``/download``), so we just return a standard +:class:`SandboxClient` pointing to the sandbox URL. +""" + +from __future__ import annotations + +import logging +from typing import Tuple + +import requests + +from lagent.serving.sandbox.providers.base import SandboxClient + +logger = logging.getLogger(__name__) + + +class GatewayProvider: + """Manages sandboxes via the EnvGateway HTTP API. + + Usage:: + + provider = GatewayProvider("http://env-gateway.ailab.ailab.ai") + client, env_id = provider.create(image_tag="hb_3d-scan-calc") + # client is a SandboxClient pointing to the sandbox URL + result = client.execute("echo hello") + provider.delete(env_id) + + Parameters + ---------- + gateway_url : str + Base URL of the EnvGateway service. + """ + + def __init__(self, gateway_url: str): + self.gateway_url = gateway_url.rstrip("/") + self.session = requests.Session() + self.session.headers.update({"Content-Type": "application/json"}) + + def create( + self, + image_tag: str, + ttl_seconds: int = 1800, + **kwargs, + ) -> Tuple[SandboxClient, str]: + """Create a new sandbox environment. + + Parameters + ---------- + image_tag : str + Docker image tag for the sandbox. + ttl_seconds : int + Time-to-live in seconds (default 30 min). + + Returns + ------- + client : SandboxClient + Client connected to the sandbox. + env_id : str + Environment ID for lifecycle management. + """ + resp = self.session.post( + f"{self.gateway_url}/envs", + json={"image_tag": image_tag, "ttl_seconds": ttl_seconds, **kwargs}, + timeout=120, + ) + resp.raise_for_status() + ret = resp.json() + if not ret.get("ok"): + raise RuntimeError(f"Failed to create sandbox: {ret.get('error', ret)}") + url = ret["env"]["url"] + env_id = ret["env"]["env_id"] + logger.info("Created sandbox: url=%s, env_id=%s", url, env_id) + client = SandboxClient(url) + for _ in range(300): + health_json = client.health_check() + if health_json['ok']: + return client, env_id + import time + time.sleep(2) + raise Exception + + + def delete(self, env_id: str) -> None: + """Delete a sandbox environment.""" + resp = self.session.delete( + f"{self.gateway_url}/envs/{env_id}", + timeout=30, + ) + resp.raise_for_status() + logger.info("Deleted sandbox: env_id=%s", env_id) + + def get(self, env_id: str) -> dict: + """Get sandbox status.""" + resp = self.session.get( + f"{self.gateway_url}/envs/{env_id}", + timeout=15, + ) + resp.raise_for_status() + return resp.json() diff --git a/lagent/serving/sandbox/providers/local.py b/lagent/serving/sandbox/providers/local.py new file mode 100644 index 00000000..fdb7f400 --- /dev/null +++ b/lagent/serving/sandbox/providers/local.py @@ -0,0 +1,115 @@ +"""Local sandbox provider — subprocess-based, no network. + +Used for testing and local development. Implements the same interface +as SandboxClient but executes commands via ``subprocess.run()`` and +copies files directly on the local filesystem. +""" + +from __future__ import annotations + +import os +import shutil +import subprocess +from typing import Tuple + +from .base import SandboxClient + + +class LocalClient: + """Sandbox client that runs commands locally via subprocess. + + Implements the same interface as :class:`SandboxClient` so it can + be used as a drop-in replacement. + """ + + def __init__(self, working_dir: str = "/tmp/lagent_sandbox"): + self.working_dir = working_dir + os.makedirs(working_dir, exist_ok=True) + + def execute(self, command: str, cwd: str = None, timeout_sec: int = 60) -> dict: + cwd = cwd or self.working_dir + try: + # Background commands (ending with &): use Popen so the + # background process survives after the shell exits. + if command.rstrip().endswith("&"): + subprocess.Popen( + ["bash", "-c", command], + cwd=cwd, start_new_session=True, + stdin=subprocess.DEVNULL, + ) + return { + "ok": True, + "stdout": "", + "stderr": "", + "return_code": 0, + } + result = subprocess.run( + command, shell=True, capture_output=True, text=True, + cwd=cwd, timeout=timeout_sec, + ) + return { + "ok": result.returncode == 0, + "stdout": result.stdout, + "stderr": result.stderr, + "return_code": result.returncode, + } + except subprocess.TimeoutExpired: + return { + "ok": False, + "stdout": "", + "stderr": f"Command timed out after {timeout_sec}s", + "return_code": 124, + } + + def upload_file(self, local_path: str, remote_path: str) -> dict: + os.makedirs(os.path.dirname(remote_path) or ".", exist_ok=True) + shutil.copy2(local_path, remote_path) + return {"ok": True, "target_path": remote_path} + + def download_file(self, remote_path: str) -> bytes: + with open(remote_path, "rb") as f: + return f.read() + + def health_check(self) -> dict: + return {"ok": True} + + def close(self): + pass + + def __repr__(self): + return f"LocalClient(working_dir={self.working_dir!r})" + + +class LocalProvider: + """Creates local sandbox environments (just a working directory). + + Usage:: + + provider = LocalProvider() + client, sandbox_id = provider.create() + # client is a LocalClient — same interface as SandboxClient + """ + + def __init__(self, base_dir: str = "/tmp/lagent_sandboxes"): + self.base_dir = base_dir + self._sandboxes = {} + self._counter = 0 + + def create(self, working_dir: str = None, **kwargs) -> Tuple[LocalClient, str]: + self._counter += 1 + sandbox_id = f"local-{self._counter}" + working_dir = working_dir or os.path.join(self.base_dir, sandbox_id) + client = LocalClient(working_dir=working_dir) + self._sandboxes[sandbox_id] = client + return client, sandbox_id + + def delete(self, sandbox_id: str) -> None: + client = self._sandboxes.pop(sandbox_id, None) + if client and os.path.exists(client.working_dir): + shutil.rmtree(client.working_dir, ignore_errors=True) + + def list(self): + return [ + {"sandbox_id": sid, "working_dir": c.working_dir} + for sid, c in self._sandboxes.items() + ] diff --git a/lagent/serving/sandbox/server.py b/lagent/serving/sandbox/server.py new file mode 100644 index 00000000..6391308e --- /dev/null +++ b/lagent/serving/sandbox/server.py @@ -0,0 +1,235 @@ +"""SandboxServer — lightweight HTTP API that turns any machine into a sandbox. + +Provides ``/exec``, ``/upload``, ``/download``, ``/health`` endpoints. +Supports two backends: FastAPI (if available) or stdlib http.server (fallback). + +Usage:: + + # Start server (auto-detects backend) + python -m lagent.serving.sandbox.server --port 8080 + + # Force stdlib backend (zero deps) + python -m lagent.serving.sandbox.server --port 8080 --backend stdlib + + # Or run the file directly (no package imports needed) + python /path/to/lagent/serving/sandbox/server.py --port 8080 +""" + +from __future__ import annotations + +import argparse +import base64 +import json +import logging +import os +import subprocess +from pathlib import Path + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# FastAPI backend +# --------------------------------------------------------------------------- + +def create_fastapi_app(): + """Create a FastAPI application.""" + from fastapi import FastAPI + from pydantic import BaseModel + + app = FastAPI(title="Lagent SandboxServer") + + class ExecRequest(BaseModel): + command: str + cwd: str = "/root" + timeout_sec: int = 60 + + class UploadRequest(BaseModel): + target_path: str + content_b64: str + + class DownloadRequest(BaseModel): + source_path: str + + @app.post("/exec") + def execute(req: ExecRequest): + try: + result = subprocess.run( + req.command, shell=True, capture_output=True, text=True, + cwd=req.cwd, timeout=req.timeout_sec, + ) + return { + "ok": result.returncode == 0, + "stdout": result.stdout, + "stderr": result.stderr, + "return_code": result.returncode, + } + except subprocess.TimeoutExpired: + return { + "ok": False, "stdout": "", + "stderr": f"Command timed out after {req.timeout_sec} seconds", + "return_code": 124, + } + except Exception as e: + return {"ok": False, "stdout": "", "stderr": str(e), "return_code": 1} + + @app.post("/upload") + def upload(req: UploadRequest): + try: + target = Path(req.target_path) + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(base64.b64decode(req.content_b64)) + return {"ok": True, "target_path": req.target_path, "size": target.stat().st_size} + except Exception as e: + return {"ok": False, "error": str(e)} + + @app.post("/download") + def download(req: DownloadRequest): + try: + data = Path(req.source_path).read_bytes() + return {"ok": True, "content_b64": base64.b64encode(data).decode("utf-8")} + except Exception as e: + return {"ok": False, "error": str(e)} + + @app.get("/health") + def health(): + return {"ok": True} + + return app + + +# --------------------------------------------------------------------------- +# Stdlib backend (zero deps fallback) +# --------------------------------------------------------------------------- + +def create_stdlib_server(host: str, port: int): + """Create an http.server based server (no third-party deps).""" + from http.server import HTTPServer, BaseHTTPRequestHandler + + class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/health": + self._respond({"ok": True}) + else: + self._respond({"error": "Not found"}, 404) + + def do_POST(self): + body = self._read_body() + if body is None: + return + handlers = { + "/exec": self._handle_exec, + "/upload": self._handle_upload, + "/download": self._handle_download, + } + handler = handlers.get(self.path) + if handler: + handler(body) + else: + self._respond({"error": "Not found"}, 404) + + def _handle_exec(self, body): + command = body.get("command", "") + cwd = body.get("cwd", "/root") + timeout_sec = body.get("timeout_sec", 60) + try: + result = subprocess.run( + command, shell=True, capture_output=True, text=True, + cwd=cwd, timeout=timeout_sec, + ) + self._respond({ + "ok": result.returncode == 0, + "stdout": result.stdout, + "stderr": result.stderr, + "return_code": result.returncode, + }) + except subprocess.TimeoutExpired: + self._respond({ + "ok": False, "stdout": "", + "stderr": f"Command timed out after {timeout_sec}s", + "return_code": 124, + }) + except Exception as e: + self._respond({"ok": False, "stdout": "", "stderr": str(e), "return_code": 1}) + + def _handle_upload(self, body): + try: + target = Path(body["target_path"]) + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(base64.b64decode(body["content_b64"])) + self._respond({"ok": True, "target_path": str(target), "size": target.stat().st_size}) + except Exception as e: + self._respond({"ok": False, "error": str(e)}) + + def _handle_download(self, body): + try: + data = Path(body["source_path"]).read_bytes() + self._respond({"ok": True, "content_b64": base64.b64encode(data).decode("utf-8")}) + except Exception as e: + self._respond({"ok": False, "error": str(e)}) + + def _read_body(self): + try: + length = int(self.headers.get("Content-Length", 0)) + raw = self.rfile.read(length) + return json.loads(raw) if raw else {} + except Exception as e: + self._respond({"error": f"Bad request: {e}"}, 400) + return None + + def _respond(self, data, status=200): + body = json.dumps(data, ensure_ascii=False).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, format, *args): + logger.debug("%s %s", self.address_string(), format % args) + + return HTTPServer((host, port), Handler) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + prog="lagent.serving.sandbox.server", + description="SandboxServer: HTTP API for sandbox interaction", + ) + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8080) + parser.add_argument( + "--backend", choices=["auto", "fastapi", "stdlib"], default="auto", + help="Server backend: fastapi (uvicorn), stdlib (http.server), or auto-detect", + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s") + + backend = args.backend + if backend == "auto": + try: + import fastapi, uvicorn # noqa: F401 + backend = "fastapi" + except ImportError: + backend = "stdlib" + + if backend == "fastapi": + import uvicorn + logger.info("Starting SandboxServer (fastapi) on %s:%d", args.host, args.port) + uvicorn.run(create_fastapi_app(), host=args.host, port=args.port) + else: + server = create_stdlib_server(args.host, args.port) + logger.info("Starting SandboxServer (stdlib) on %s:%d", args.host, args.port) + try: + server.serve_forever() + except KeyboardInterrupt: + server.shutdown() + + +if __name__ == "__main__": + main() diff --git a/lagent/skills/__init__.py b/lagent/skills/__init__.py new file mode 100644 index 00000000..c6cc2873 --- /dev/null +++ b/lagent/skills/__init__.py @@ -0,0 +1,3 @@ +from .skills import SkillsLoader, SandboxSkillsBackend, FilesystemSkillsBackend + +__all____all__ = ["SkillsLoader", "SandboxSkillsBackend", "FilesystemSkillsBackend"] \ No newline at end of file diff --git a/lagent/skills/sandbox_skills.py b/lagent/skills/sandbox_skills.py new file mode 100644 index 00000000..75418358 --- /dev/null +++ b/lagent/skills/sandbox_skills.py @@ -0,0 +1,180 @@ +"""SandboxSkillsLoader — drop-in replacement for ``SkillsLoader`` that +routes all skill operations to a :class:`SkillsDaemon` running inside +a sandbox. + +Usage:: + + from lagent.skills.sandbox_skills import SandboxSkillsLoader + + skills = SandboxSkillsLoader( + sandbox_client=client, + workspace="/root/workspace", + ) + await skills.connect() + + summary = await skills.build_skills_summary() + always = await skills.get_always_skills() +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import logging +from typing import List, Optional + +logger = logging.getLogger(__name__) + + +class SandboxSkillsLoader: + """Drop-in for ``SkillsLoader`` that proxies to a ``SkillsDaemon``. + + Implements the same async interface as ``SkillsLoader``. + Manages its own daemon lifecycle (separate from ActionDaemon). + + Parameters + ---------- + sandbox_client + Object with ``execute(command, ...)`` method. + workspace : str + Workspace path inside the sandbox. + sock_path : str + Unix socket path for the SkillsDaemon inside sandbox. + daemon_module : str + Python module path for the daemon. + """ + + def __init__( + self, + sandbox_client, + workspace: str = "/root/workspace", + sock_path: str = "/tmp/lagent_skills.sock", + daemon_module: str = "lagent.serving.sandbox.daemon", + ): + self.sandbox_client = sandbox_client + self.workspace = workspace + self.sock_path = sock_path + self.daemon_module = daemon_module + self._connected = False + self._lock = asyncio.Lock() + + async def _exec(self, command: str, **kwargs) -> str: + kwargs.setdefault("cwd", "/tmp") + execute_fn = self.sandbox_client.execute + if inspect.iscoroutinefunction(execute_fn): + result = await execute_fn(command, **kwargs) + else: + result = await asyncio.to_thread(execute_fn, command, **kwargs) + if isinstance(result, dict): + return result.get("stdout", "") + return result + + async def _daemon_call(self, request: dict) -> dict: + request_json = json.dumps(request, ensure_ascii=False) + escaped = request_json.replace("'", "'\\''") + output = await self._exec( + f"python -m {self.daemon_module} call " + f"--sock {self.sock_path} " + f"'{escaped}'" + ) + return json.loads(output.strip()) + + async def connect(self) -> None: + """Start the SkillsDaemon inside sandbox (idempotent).""" + async with self._lock: + if self._connected: + return + + # Write a small Python script that starts SkillsDaemon + # with a SkillsLoader for the workspace + start_script = ( + f"from lagent.skills.skills import SkillsLoader; " + f"from lagent.serving.sandbox.daemon import SkillsDaemon; " + f"import asyncio; " + f"loader = SkillsLoader('{self.workspace}'); " + f"daemon = SkillsDaemon(skills_loader=loader, sock_path='{self.sock_path}'); " + f"asyncio.run(daemon.start())" + ) + escaped_script = start_script.replace("'", "'\\''") + await self._exec( + f"echo '{escaped_script}' > /tmp/lagent_start_skills_daemon.py" + ) + + # Check if already running + check = await self._exec( + f"pgrep -f '[l]agent_start_skills_daemon' " + f"> /dev/null 2>&1 && echo 'running' || echo 'stopped'" + ) + if "stopped" in check: + await self._exec( + f"nohup python /tmp/lagent_start_skills_daemon.py " + f"> /tmp/lagent_skills_daemon.log 2>&1 &" + ) + + # Wait for socket + for _ in range(30): + try: + output = await self._exec( + f"test -S {self.sock_path} && echo 'ready' || echo 'waiting'" + ) + if "ready" in output: + break + except Exception: + pass + await asyncio.sleep(0.5) + else: + raise TimeoutError( + f"SkillsDaemon did not start within 15s. " + f"Check /tmp/lagent_skills_daemon.log inside sandbox." + ) + + self._connected = True + logger.info("SandboxSkillsLoader connected") + + async def close(self) -> None: + if not self._connected: + return + try: + await self._daemon_call({"cmd": "shutdown"}) + except Exception: + pass + self._connected = False + + # -- SkillsLoader-compatible interface -- + + async def list_skills(self, filter_unavailable: bool = True) -> list: + if not self._connected: + await self.connect() + result = await self._daemon_call({ + "cmd": "list_skills", + "filter_unavailable": filter_unavailable, + }) + return result.get("skills", []) + + async def build_skills_summary(self) -> str: + if not self._connected: + await self.connect() + result = await self._daemon_call({"cmd": "skills_summary"}) + return result.get("summary", "") + + async def load_skill(self, name: str) -> Optional[str]: + if not self._connected: + await self.connect() + result = await self._daemon_call({"cmd": "load_skill", "name": name}) + return result.get("content") + + async def load_skills_for_context(self, skill_names: List[str]) -> str: + if not self._connected: + await self.connect() + result = await self._daemon_call({ + "cmd": "load_skills_for_context", + "names": skill_names, + }) + return result.get("content", "") + + async def get_always_skills(self) -> List[str]: + if not self._connected: + await self.connect() + result = await self._daemon_call({"cmd": "get_always_skills"}) + return result.get("skills", []) diff --git a/lagent/skills/skills.py b/lagent/skills/skills.py new file mode 100644 index 00000000..32a1c406 --- /dev/null +++ b/lagent/skills/skills.py @@ -0,0 +1,403 @@ +"""Skills loader for agent capabilities.""" + +import json +import os +import re +import shutil +from pathlib import Path +from typing import Any, Callable +import asyncio + +# Default builtin skills directory (relative to this file) +BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills" + + +class BaseSkillsBackend: + """Abstract backend for skill discovery and loading.""" + + async def list_skill_entries(self) -> list[dict[str, str]]: + raise NotImplementedError + + async def read_skill(self, name: str) -> str | None: + raise NotImplementedError + + +class FilesystemSkillsBackend(BaseSkillsBackend): + """Filesystem-backed skill backend.""" + + def __init__(self, workspace_skills: Path, builtin_skills: Path | None = None): + self.workspace_skills = workspace_skills + self.builtin_skills = builtin_skills + + async def list_skill_entries(self) -> list[dict[str, str]]: + skills: list[dict[str, str]] = [] + + if self.workspace_skills.exists(): + for skill_dir in self.workspace_skills.iterdir(): + if skill_dir.is_dir(): + skill_file = skill_dir / "SKILL.md" + if skill_file.exists(): + skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"}) + + if self.builtin_skills and self.builtin_skills.exists(): + for skill_dir in self.builtin_skills.iterdir(): + if skill_dir.is_dir(): + skill_file = skill_dir / "SKILL.md" + if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills): + skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"}) + + return skills + + async def read_skill(self, name: str) -> str | None: + workspace_skill = self.workspace_skills / name / "SKILL.md" + if workspace_skill.exists(): + return await asyncio.to_thread(workspace_skill.read_text, encoding="utf-8") + + if self.builtin_skills: + builtin_skill = self.builtin_skills / name / "SKILL.md" + if builtin_skill.exists(): + return await asyncio.to_thread(builtin_skill.read_text, encoding="utf-8") + + return None + + +class SandboxSkillsBackend(BaseSkillsBackend): + """Sandbox-backed skills storage that executes commands in a remote execution environment.""" + + def __init__( + self, + action: Any, + *, + workspace_root: str = ".", + builtin_skills: Path | None = None, + command_builder: Callable[[str, str], str] | None = None, + ): + self.action = action + self.workspace_root = workspace_root.rstrip("/") or "." + self.builtin_skills = builtin_skills + self.command_builder = command_builder or self._default_command_builder + + @staticmethod + def _b64_python(script: str) -> str: + """Encode a Python script as base64 and return a shell command that decodes and executes it. + + This avoids all quoting/escaping issues when running Python via a remote shell. + """ + import base64 + encoded = base64.b64encode(script.encode()).decode() + return f"echo {encoded} | base64 -d | python3" + + def _default_command_builder(self, op: str, target: str) -> str: + skills_root = f"{self.workspace_root}/skills" + if op == "list": + script = ( + "import json\n" + "from pathlib import Path\n" + f"root = Path('{skills_root}')\n" + "items = [\n" + " {'name': d.name, 'path': str(d / 'SKILL.md'), 'source': 'workspace'}\n" + " for d in (root.iterdir() if root.exists() else [])\n" + " if d.is_dir() and (d / 'SKILL.md').exists()\n" + "]\n" + "print(json.dumps(items, ensure_ascii=False))\n" + ) + return self._b64_python(script) + if op == "read": + skill_file = f"{skills_root}/{target}/SKILL.md" + return f"cat '{skill_file}'" + raise ValueError(f"Unsupported operation: {op}") + + async def _run(self, command: str) -> str | None: + """异步执行命令并解析输出""" + import json + from lagent.schema import ActionStatusCode + + result = await self.action.run(command=command) + if result.state != ActionStatusCode.SUCCESS: + return None + + try: + # 尝试解析 MCP 格式的返回 + if isinstance(result.result, list) and len(result.result) > 0: + content_str = result.result[0].get('content', '') + content_dict = json.loads(content_str) + if content_dict.get('exit_code') == 0: + return content_dict.get('stdout', '').strip() + else: + return None + except Exception: + pass + + return None + + async def list_skill_entries(self) -> list[dict[str, str]]: + """异步获取 skill 列表""" + import json + raw = await self._run(self.command_builder("list", "")) + skills: list[dict[str, str]] = [] + if raw: + # stdout 可能包含多行,尝试从每行中找到 JSON 数组 + for line in raw.splitlines(): + line = line.strip() + if not line or not line.startswith('['): + continue + try: + parsed = json.loads(line) + if isinstance(parsed, list): + skills = [ + s for s in parsed + if isinstance(s, dict) and 'name' in s + ] + break + except json.JSONDecodeError: + continue + # 合并内置 skills + if self.builtin_skills and self.builtin_skills.exists(): + for skill_dir in self.builtin_skills.iterdir(): + if skill_dir.is_dir(): + skill_file = skill_dir / "SKILL.md" + if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills): + skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"}) + return skills + + async def read_skill(self, name: str) -> str | None: + """异步读取 skill 内容""" + content = await self._run(self.command_builder("read", name)) + if content: + return content + if self.builtin_skills: + builtin_skill = self.builtin_skills / name / "SKILL.md" + if builtin_skill.exists(): + return await asyncio.to_thread(builtin_skill.read_text, encoding="utf-8") + return None + + +class SkillsLoader: + """ + Loader for agent skills. + + Skills are markdown files (SKILL.md) that teach the agent how to use + specific tools or perform certain tasks. + """ + + def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None): + self.workspace = Path(workspace) + self.workspace_skills = self.workspace / "skills" + self.builtin_skills = Path(builtin_skills_dir) if builtin_skills_dir else BUILTIN_SKILLS_DIR + self.backend: BaseSkillsBackend = FilesystemSkillsBackend(self.workspace_skills, self.builtin_skills) + + def bind_backend(self, backend: BaseSkillsBackend) -> None: + """Replace the default filesystem backend with a custom backend.""" + self.backend = backend + + async def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]: + """ + List all available skills. + + Args: + filter_unavailable: If True, filter out skills with unmet requirements. + + Returns: + List of skill info dicts with 'name', 'path', 'source'. + """ + skills = await self.backend.list_skill_entries() + + # Filter by requirements + if filter_unavailable: + filtered_skills = [] + for s in skills: + meta = await self._get_skill_meta(s["name"]) + if self._check_requirements(meta): + filtered_skills.append(s) + return filtered_skills + return skills + + async def load_skill(self, name: str) -> str | None: + """ + Load a skill by name. + + Args: + name: Skill name (directory name). + + Returns: + Skill content or None if not found. + """ + return await self.backend.read_skill(name) + + async def load_skills_for_context(self, skill_names: list[str]) -> str: + """ + Load specific skills for inclusion in agent context. + + Args: + skill_names: List of skill names to load. + + Returns: + Formatted skills content. + """ + parts = [] + for name in skill_names: + content = await self.load_skill(name) + if content: + content = self._strip_frontmatter(content) + parts.append(f"### Skill: {name}\n\n{content}") + + return "\n\n---\n\n".join(parts) if parts else "" + + async def build_skills_summary(self) -> str: + """ + Build a summary of all skills (name, description, path, availability). + + This is used for progressive loading - the agent can read the full + skill content using read_file when needed. + + Returns: + XML-formatted skills summary. + """ + all_skills = await self.list_skills(filter_unavailable=False) + if not all_skills: + return "" + + def escape_xml(s: str) -> str: + return s.replace("&", "&").replace("<", "<").replace(">", ">") + + lines = [""] + for s in all_skills: + name = escape_xml(s["name"]) + path = s["path"] + desc = escape_xml(await self._get_skill_description(s["name"])) + skill_meta = await self._get_skill_meta(s["name"]) + available = self._check_requirements(skill_meta) + + lines.append(f" ") + lines.append(f" {name}") + lines.append(f" {desc}") + lines.append(f" {path}") + + # Show missing requirements for unavailable skills + if not available: + missing = self._get_missing_requirements(skill_meta) + if missing: + lines.append(f" {escape_xml(missing)}") + + lines.append(" ") + lines.append("") + + return "\n".join(lines) + + def _get_missing_requirements(self, skill_meta: dict) -> str: + """Get a description of missing requirements.""" + missing = [] + requires = skill_meta.get("requires", {}) + for b in requires.get("bins", []): + if not shutil.which(b): + missing.append(f"CLI: {b}") + for env in requires.get("env", []): + if not os.environ.get(env): + missing.append(f"ENV: {env}") + return ", ".join(missing) + + async def _get_skill_description(self, name: str) -> str: + """Get the description of a skill from its frontmatter.""" + meta = await self.get_skill_metadata(name) + if meta and meta.get("description"): + return meta["description"] + return name # Fallback to skill name + + def _strip_frontmatter(self, content: str) -> str: + """Remove YAML frontmatter from markdown content.""" + if content.startswith("---"): + match = re.match(r"^---\n.*?\n---\n", content, re.DOTALL) + if match: + return content[match.end():].strip() + return content + + def _parse_internclaw_metadata(self, raw: str) -> dict: + """Parse skill metadata JSON from frontmatter (supports internclaw and openclaw keys).""" + try: + data = json.loads(raw) + return data.get("internclaw", data.get("openclaw", {})) if isinstance(data, dict) else {} + except (json.JSONDecodeError, TypeError): + return {} + + def _check_requirements(self, skill_meta: dict) -> bool: + """Check if skill requirements are met (bins, env vars).""" + requires = skill_meta.get("requires", {}) + for b in requires.get("bins", []): + if not shutil.which(b): + return False + for env in requires.get("env", []): + if not os.environ.get(env): + return False + return True + + async def _get_skill_meta(self, name: str) -> dict: + """Get internclaw metadata for a skill (cached in frontmatter).""" + meta = await self.get_skill_metadata(name) or {} + return self._parse_internclaw_metadata(meta.get("metadata", "")) + + async def get_always_skills(self) -> list[str]: + """Get skills marked as always=true that meet requirements.""" + result = [] + for s in await self.list_skills(filter_unavailable=True): + meta = await self.get_skill_metadata(s["name"]) or {} + skill_meta = self._parse_internclaw_metadata(meta.get("metadata", "")) + if skill_meta.get("always") or meta.get("always"): + result.append(s["name"]) + return result + + async def get_skill_metadata(self, name: str) -> dict | None: + """ + Get metadata from a skill's frontmatter. + + Args: + name: Skill name. + + Returns: + Metadata dict or None. + """ + content = await self.load_skill(name) + if not content: + return None + + if content.startswith("---"): + match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) + if match: + # Simple YAML parsing + metadata = {} + for line in match.group(1).split("\n"): + if ":" in line: + key, value = line.split(":", 1) + metadata[key.strip()] = value.strip().strip('"\'') + return metadata + + return None + + +if __name__ == '__main__': + from lagent.actions.mcp_client import AsyncMCPClientSandbox + import asyncio + import json + from pathlib import Path + init_dir = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent/workspace" + + async def main(): + shell_action = AsyncMCPClientSandbox('http', url='http://simple-shell.ailab.ailab.ai/mcp', init_dir=init_dir) + home_path = await shell_action.run(command='ls -la') + home_path = json.loads(home_path.result[0]['content'])['cwd'] + + skill_loader = SkillsLoader(Path(home_path)) + backend = SandboxSkillsBackend(shell_action, workspace_root=os.path.join(home_path, 'workspace')) + skill_loader.bind_backend(backend) + + try: + skills = await skill_loader.list_skills(filter_unavailable=False) + print("Skills:") + for s in skills: + print(s) + except Exception as e: + print("Failed to parse skills:", e) + + print(await skill_loader.build_skills_summary()) + print(await skill_loader.load_skills_for_context(['weather'])) + + asyncio.run(main()) diff --git a/lagent/utils/config.py b/lagent/utils/config.py new file mode 100644 index 00000000..06137955 --- /dev/null +++ b/lagent/utils/config.py @@ -0,0 +1,74 @@ +import collections.abc +import importlib.util +import os +import types + +from benedict import benedict + + +class ConfigDict(dict): + + def __getattr__(self, item): + if item in self: + return self[item] + raise AttributeError(f"'ConfigDict' object has no attribute '{item}'") + + def __setattr__(self, key, value): + self[key] = value + + +class Config: + + @staticmethod + def fromfile(file_path): + config_dict = ConfigDict() + if not os.path.isfile(file_path): + raise FileNotFoundError(f'Config file not found: {file_path}') + + # Load the configuration file as a module + spec = importlib.util.spec_from_file_location('config_module', file_path) + config_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config_module) + + # Function to convert nested dictionaries to ConfigDict recursively + def convert_to_config_dict(d): + if isinstance(d, dict): + + config_dict = ConfigDict() + for key, value in d.items(): + if isinstance(value, dict): + config_dict[key] = convert_to_config_dict(value) + else: + config_dict[key] = value + return config_dict + else: + return d + + # Retrieve all attributes (variables) from the module + for attribute_name in dir(config_module): + if not attribute_name.startswith('__'): + config_dict[attribute_name] = convert_to_config_dict(getattr(config_module, attribute_name)) + for key, value in list(config_dict.items()): + if isinstance(value, (types.FunctionType, types.ModuleType)): + config_dict.pop(key) + return config_dict + + +def to_native_types(obj: benedict): + """ + 递归地将任何嵌套的自定义对象(如 benedict)和集合 + 转换为 Python 的原生类型(dict, list, tuple, set 等)。 + """ + # 1. 最优先处理类字典对象 + if isinstance(obj, dict): + return {key: to_native_types(value) for key, value in obj.items()} + + # 2. 处理非字符串的可迭代对象 (list, tuple, set 等) + # 必须先排除 str 和 bytes,因为它们也是可迭代的,但我们不希望遍历其字符。 + if isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)): + # type(obj) -> 获取原始容器的类型 (如 list, tuple, 或 set) + # (...) -> 使用该类型的构造函数,重新创建一个新的、已转换内容的容器。 + return type(obj)(to_native_types(item) for item in obj) + + # 3. 如果不是以上任何情况(如 int, str, bool, None 等),直接返回值 + return obj diff --git a/project_template/README.md b/project_template/README.md new file mode 100644 index 00000000..a407de3f --- /dev/null +++ b/project_template/README.md @@ -0,0 +1,57 @@ +# Agent Project Template + +一个标准的 Lagent Agent 项目的参考结构。 + +## 目录结构 + +``` +my-agent/ +├── config.py # 唯一入口:完整的 PyConfig(必须) +├── agent.py # 自定义 Agent 子类(可选) +├── actions.py # 自定义 Actions(可选) +├── memory.py # 自定义 Memory(可选) +├── prompts.py # System prompt 和常量(可选) +└── skills/ # 技能文件目录(可选) + └── review.md +``` + +## 核心原则 + +1. **config.py 是唯一的完整表达** — 所有 Agent 配置最终都是一个 + `agent_config = dict(type=..., ...)` 的 PyConfig dict +2. **没有隐式假设** — 模型、Actions、Memory 的参数全部显式声明 +3. **create_object() 递归展开** — 嵌套 dict 中的每个 `type=...` + 都会被 `create_object()` 自动实例化 + +## 使用方式 + +```python +# 方式 1: AgentLoader 自动发现 +# 把项目放在 workspace/agents/my-agent/ 下即可 + +# 方式 2: 手动注册 +from lagent.interclaw.services.agent import AgentService +service = AgentService() +service.register("my-agent", agent_config=config.agent_config) + +# 方式 3: 直接实例化 +from lagent.utils import create_object +agent = create_object(config.agent_config) +response = await agent("Hello") +``` + +## 简化形式 + +如果你的 Agent 只需要换个 prompt 和 model,可以用 AGENT.md 代替整个项目: + +```markdown +--- +name: my-agent +description: A simple agent +model: gpt-4 +max_turns: 50 +--- +You are a helpful assistant. +``` + +但 AGENT.md 是语法糖,需要 preset/defaults 层展开为完整 PyConfig。 diff --git a/project_template/agents/__init__.py b/project_template/agents/__init__.py new file mode 100644 index 00000000..bffd3790 --- /dev/null +++ b/project_template/agents/__init__.py @@ -0,0 +1,75 @@ +"""Agent Project 入口 — build() 工厂函数。 + +当 Agent 有运行时依赖(AgentService、CompactAgent 等)时, +纯声明式的 config.py 无法完整表达。此时用 build() 来处理组装逻辑。 + +AgentLoader 发现 __init__.py 中有 build() 时,会将其存入 +AgentSpec.build,由 default_agent_factory 在 spawn 时调用。 + +签名: + async def build(spec: AgentSpec, task: str) -> Agent +""" + +from lagent.utils import create_object +from lagent.actions.subagent import AgentAction + +# 从同目录的 config.py 导入基础配置 +from .config import ( + agent_config, + name, + description, + max_turns, + background, +) + + +async def build(spec, task): + """组装完整的 InternClaw Agent,处理运行时依赖注入。 + + 调用链: + AgentService.spawn("internclaw-standard", task) + → default_agent_factory(spec, task) + → spec.build(spec, task) ← 就是这个函数 + + Parameters + ---------- + spec : AgentSpec + 由 AgentLoader 从 config.py 解析出的 spec。 + spec.agent_config 就是 config.py 中的 agent_config dict。 + task : str + 用户任务描述。 + + Returns + ------- + InternClawAgent + 完全组装好的 Agent 实例,可以直接 await agent(task)。 + """ + import copy + + # 1. 深拷贝 config,避免污染原始模板 + cfg = copy.deepcopy(spec.agent_config or agent_config) + + # 2. 先创建核心 Agent(不含运行时依赖的 actions) + agent = create_object(cfg) + + # 3. 注入运行时依赖 + # 这些组件需要 agent 实例或 agent_service 才能创建 + # + # 示例: 注入 CompactAgent + LongTermMemory + # + # from lagent.agents.compact_agent import AsyncCompactAgent + # from lagent.memory.long_term import FilesystemLongTermMemory + # from lagent.actions.save_memory import SaveMemoryAction + # + # ltm = FilesystemLongTermMemory(workspace_path) + # compact = AsyncCompactAgent(llm=model) + # agent.compact_agent = compact + # agent.long_term_memory = ltm + # agent.env_agent.long_term_memory = ltm + # + # # For consolidation, create a standard InternClawAgent + # # with SaveMemoryAction + consolidation prompt + # save_action = SaveMemoryAction(ltm) + # # ... configure consolidate_agent as InternClawAgent instance + + return agent diff --git a/project_template/agents/actions.py b/project_template/agents/actions.py new file mode 100644 index 00000000..e52103ad --- /dev/null +++ b/project_template/agents/actions.py @@ -0,0 +1,105 @@ +"""自定义 Actions 示例。 + +Action 是 Agent 与外部世界交互的接口。每个 @tool_api 方法 +都会被自动转换为 LLM 可调用的工具。 + +两种模式: + - 单工具 Action: 只有一个 run() 方法 + - 工具箱 (Toolkit): 多个 @tool_api 方法,LLM 看到的工具名 + 是 "ClassName.method_name" 格式 + +注意: 不要在 Action 文件中使用 `from __future__ import annotations`, +否则会破坏 @tool_api 的类型解析。 +""" + +from typing import Annotated, Optional + +from lagent.actions.base_action import BaseAction, tool_api +from lagent.schema import ActionReturn, ActionStatusCode + + +class FileAnalyzer(BaseAction): + """示例:文件分析工具箱(Toolkit 模式)。 + + LLM 看到的工具列表: + - FileAnalyzer.count_lines + - FileAnalyzer.search_pattern + """ + + @tool_api + def count_lines( + self, + file_path: Annotated[str, "要分析的文件路径"], + ) -> ActionReturn: + """统计文件行数。""" + try: + with open(file_path) as f: + count = sum(1 for _ in f) + return ActionReturn( + result=[{"type": "text", "content": f"{count} lines"}], + state=ActionStatusCode.SUCCESS, + ) + except Exception as e: + return ActionReturn( + errmsg=str(e), + state=ActionStatusCode.API_ERROR, + ) + + @tool_api + def search_pattern( + self, + file_path: Annotated[str, "要搜索的文件路径"], + pattern: Annotated[str, "正则表达式模式"], + max_results: Annotated[int, "最大返回数量"] = 10, + ) -> ActionReturn: + """在文件中搜索匹配正则表达式的行。""" + import re + + try: + matches = [] + with open(file_path) as f: + for i, line in enumerate(f, 1): + if re.search(pattern, line): + matches.append(f"L{i}: {line.rstrip()}") + if len(matches) >= max_results: + break + result = "\n".join(matches) if matches else "No matches found" + return ActionReturn( + result=[{"type": "text", "content": result}], + state=ActionStatusCode.SUCCESS, + ) + except Exception as e: + return ActionReturn( + errmsg=str(e), + state=ActionStatusCode.API_ERROR, + ) + + +class HealthCheck(BaseAction): + """示例:单工具 Action(非 Toolkit 模式)。 + + LLM 看到的工具名就是 "HealthCheck"。 + """ + + @tool_api + def run( + self, + service_url: Annotated[str, "要检查的服务 URL"], + timeout: Annotated[int, "超时秒数"] = 5, + ) -> ActionReturn: + """检查一个 HTTP 服务是否可达。""" + import urllib.request + + try: + req = urllib.request.Request(service_url, method="HEAD") + with urllib.request.urlopen(req, timeout=timeout) as resp: + status = resp.status + return ActionReturn( + result=[{"type": "text", "content": f"OK (HTTP {status})"}], + state=ActionStatusCode.SUCCESS, + ) + except Exception as e: + return ActionReturn( + result=[{"type": "text", "content": f"FAIL: {e}"}], + state=ActionStatusCode.SUCCESS, + ) diff --git a/project_template/agents/agent.py b/project_template/agents/agent.py new file mode 100644 index 00000000..37062b2e --- /dev/null +++ b/project_template/agents/agent.py @@ -0,0 +1,41 @@ +"""自定义 Agent 子类示例。 + +大多数情况下不需要自定义 Agent 子类 — 直接用 InternClawAgent 配合 +不同的 config 就够了。只有在需要修改核心循环逻辑时才继承。 + +典型场景: + - 自定义 finish_condition + - 在 Policy-Env 循环中插入额外逻辑(如 reward 计算) + - 修改 forward() 的消息路由方式 +""" + +from typing import Optional + +from lagent.agents.internclaw_agent import InternClawAgent +from lagent.schema import AgentMessage + + +class MyAgent(InternClawAgent): + """示例:带自定义终止条件的 InternClaw Agent。""" + + def __init__( + self, + *args, + stop_words: Optional[list] = None, + **kwargs, + ): + # 覆盖 finish_condition + kwargs.setdefault("finish_condition", self._should_stop) + super().__init__(*args, **kwargs) + self.stop_words = stop_words or ["TASK_COMPLETE"] + + def _should_stop(self, policy_msg, env_msg) -> bool: + """当 Policy 不再调用工具,或输出包含 stop word 时终止。""" + if policy_msg is None: + return False + # 无工具调用 → 结束 + if not policy_msg.tool_calls: + return True + # 包含 stop word → 结束 + content = policy_msg.content or "" + return any(w in content for w in self.stop_words) diff --git a/project_template/agents/config.py b/project_template/agents/config.py new file mode 100644 index 00000000..b2b09ca0 --- /dev/null +++ b/project_template/agents/config.py @@ -0,0 +1,85 @@ +"""Agent Project Config — PyConfig entry point. + +The only required file for an agent project. +AgentLoader reads ``agent_config`` and calls ``build(agent_config)`` +to create the Agent instance. + +Required exports +---------------- +agent_config : dict + A PyConfig dict recognized by ``create_object()``. + +Optional exports +---------------- +name : str — Agent type name (defaults to directory name). +description : str — One-line description. +max_turns : int — Max interaction turns (default 500). +background : bool — Run in background (default False). +build : callable — Custom build function: ``(config_dict) -> Agent``. + When unset, defaults to ``create_object(agent_config)``. +""" + +from lagent.agents.internclaw_agent import ( + AsyncEnvAgent, + AsyncPolicyAgent, + InternClawAgent, +) +from lagent.llms.model import AsyncAPIClient +from lagent.agents.aggregator.context import InternClawContextBuilder +from lagent.actions.mcp_client import AsyncMCPClientSandbox + +# ── Metadata ────────────────────────────────────────────────────────── + +name = "internclaw-standard" +description = "Standard InternClaw Agent: Policy-Env dual loop with sandbox" +background = False + +# ── Sub-component configs ───────────────────────────────────────────── + +llm = dict( + type=AsyncAPIClient, + model=dict( + model="gpt-5.4", + base_url="http://35.220.164.252:3888/v1", + api_key=" ", + proxy="http://100.100.72.89:8899", + ), + sample_params=dict(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=50, +) + +sandbox_action = dict( + type=AsyncMCPClientSandbox, + server_type='http', + url="http://simple-shell.ailab.ailab.ai/mcp", +) + +# ── Full Agent Config ───────────────────────────────────────────────── + +agent_config = dict( + type=InternClawAgent, + policy_agent=dict( + type=AsyncPolicyAgent, + llm=llm, + aggregator=dict(type=InternClawContextBuilder), + name="policy", + ), + env_agent=dict( + type=AsyncEnvAgent, + actions=[sandbox_action], + name="env", + ), + max_turn=500, + # workspace 默认 cwd,运行时可覆盖 + # finish_condition 用默认值(无 tool_calls 时停止) +) + +# build = None (default: create_object) +# To add runtime dependencies like CompactAction, define a custom build: +# +# async def build(config): +# from lagent.utils import create_object +# agent = create_object(config) +# # ... inject runtime deps ... +# return agent diff --git a/project_template/agents/memory.py b/project_template/agents/memory.py new file mode 100644 index 00000000..8080f654 --- /dev/null +++ b/project_template/agents/memory.py @@ -0,0 +1,57 @@ +"""自定义 LongTermMemory 示例。 + +Memory 在 Lagent 中有清晰的分层: + +1. **Agent.memory (Memory)** — 短期记忆,存储当前对话的消息列表。 + 这是 Agent 基类自带的,通常不需要自定义。 + +2. **LongTermMemory** — 长期记忆存储服务(纯数据,不依赖 LLM)。 + 由 EnvAgent 持有用于读取(get_info → env_info['memory']), + 由 orchestrator 编排写入(通过 SaveMemoryAction 或直接调用)。 + +3. **CompactAgent** — 上下文压缩 Agent,token 超阈值时压缩对话。 + +4. **ConsolidateAgent** — 就是一个配了 consolidation prompt + + SaveMemoryAction 的 InternClawAgent 实例,不需要单独的类。 + +内置实现: + - FilesystemLongTermMemory: 本地 MEMORY.md + HISTORY.md + - SandboxLongTermMemory: 远程沙箱存储 + - (可扩展) MemoryOS / OpenViking 等外部库包装 +""" + +from typing import Any + +from lagent.memory.long_term import LongTermMemory + + +class ProjectLongTermMemory(LongTermMemory): + """示例:带项目上下文的 LongTermMemory。 + + 在标准的持久化记忆之外,额外注入一个 "项目知识" 摘要, + 让 Agent 在长对话中始终记得项目的关键信息。 + """ + + def __init__( + self, + backend_ltm: LongTermMemory, + project_context: str = "", + ): + self._inner = backend_ltm + self._project_context = project_context + + async def get_info(self) -> dict[str, Any]: + """返回注入 system prompt 的记忆信息。""" + info = await self._inner.get_info() + if self._project_context: + info["project_context"] = self._project_context + return info + + async def read(self) -> str: + return await self._inner.read() + + async def write(self, content: str) -> None: + await self._inner.write(content) + + async def append_history(self, entry: str) -> None: + await self._inner.append_history(entry) diff --git a/project_template/agents/prompts.py b/project_template/agents/prompts.py new file mode 100644 index 00000000..0404b997 --- /dev/null +++ b/project_template/agents/prompts.py @@ -0,0 +1,31 @@ +"""Prompt 常量和模板。 + +将 system prompt 独立成文件,方便修改和版本管理。 +config.py 中通过 `from .prompts import SYSTEM_PROMPT` 引用。 +""" + +SYSTEM_PROMPT = """\ +You are an expert software engineer. You have access to tools for reading, +writing, and executing code in a sandboxed environment. + +## Guidelines + +1. **Think step by step** before taking action. +2. **Read before writing** — understand existing code before modifying. +3. **Test your changes** — run the code after modifications. +4. **Be concise** — explain what you did, not what you're about to do. + +## Project Context + +{project_context} +""" + +COMPACT_PROMPT = """\ +Summarize the conversation so far. Focus on: +1. What task was requested +2. What has been accomplished +3. What files were modified +4. What remains to be done + +Be concise but preserve all actionable context. +""" diff --git a/project_template/agents/skills/code-review/code-review.md b/project_template/agents/skills/code-review/code-review.md new file mode 100644 index 00000000..f4782571 --- /dev/null +++ b/project_template/agents/skills/code-review/code-review.md @@ -0,0 +1,24 @@ +--- +name: code-review +description: Code review best practices and checklist +--- + +# Code Review Skill + +When reviewing code, check the following: + +## Security +- No hardcoded secrets or API keys +- Input validation on all user-facing endpoints +- SQL injection / XSS prevention + +## Quality +- Functions are < 50 lines +- No duplicated logic (DRY) +- Meaningful variable names +- Error handling for all I/O operations + +## Testing +- New functions have corresponding tests +- Edge cases are covered +- Test names describe the expected behavior diff --git a/project_template/tasks/instruction.md b/project_template/tasks/instruction.md new file mode 100644 index 00000000..e69de29b diff --git a/tests/data/agents/cli-echo/config.py b/tests/data/agents/cli-echo/config.py new file mode 100644 index 00000000..11eaa10e --- /dev/null +++ b/tests/data/agents/cli-echo/config.py @@ -0,0 +1,14 @@ +"""Example CLI agent config that wraps ``echo`` for testing.""" +from lagent.adapters.cli_adapter import CLIAgentAdapter + +name = "cli-echo" +description = "Echo agent for testing — returns the task as-is" +background = False + +agent_config = dict( + type=CLIAgentAdapter, + name="cli-echo", + description="Echo agent for testing", + command_template="echo '{task}'", + timeout=10, +) diff --git a/tests/data/agents/e2e-agent/config.py b/tests/data/agents/e2e-agent/config.py new file mode 100644 index 00000000..8c7256bf --- /dev/null +++ b/tests/data/agents/e2e-agent/config.py @@ -0,0 +1,50 @@ +"""E2E test agent — real InternClawAgent with LLM. + +Used by integration tests that need a real LLM call. +Requires network access to the API endpoint. +""" + +from pathlib import Path + +from lagent.agents.aggregator.context import InternClawContextBuilder +from lagent.agents.internclaw_agent import ( + AsyncEnvAgent, + AsyncPolicyAgent, + InternClawAgent, +) +from lagent.llms.model import AsyncAPIClient + +name = "e2e-agent" +description = "E2E test agent with real LLM" +background = False + +llm = dict( + type=AsyncAPIClient, + model=dict( + model="gpt-5.4", + base_url="http://35.220.164.252:3888/v1", + api_key=" ", + proxy="http://100.100.72.89:8899", + ), + sample_params=dict(temperature=0.1), + timeout=60, + max_retry=3, + sleep_interval=1, +) + +agent_config = dict( + type=InternClawAgent, + policy_agent=dict( + type=AsyncPolicyAgent, + llm=llm, + aggregator=dict(type=InternClawContextBuilder, workspace=Path("/tmp")), + name="policy", + ), + env_agent=dict( + type=AsyncEnvAgent, + actions=[], + name="env", + ), + max_turn=3, + finish_condition=lambda m, _: m is not None and not m.tool_calls, +) diff --git a/tests/data/agents/simple-agent/config.py b/tests/data/agents/simple-agent/config.py new file mode 100644 index 00000000..fc9a5c0e --- /dev/null +++ b/tests/data/agents/simple-agent/config.py @@ -0,0 +1,32 @@ +"""Simple test agent — no LLM, no network. + +A minimal AsyncAgent that echoes back the input. +Used by unit tests that don't need real LLM calls. +""" + +from lagent.agents.agent import AsyncAgent +from lagent.schema import AgentMessage + + +class EchoAgent(AsyncAgent): + """Agent that echoes input — no LLM needed.""" + + async def forward(self, *message, **kwargs): + text = " ".join( + m.content if isinstance(m, AgentMessage) else str(m) + for m in message + ) + return AgentMessage( + sender=self.name, + content=f"echo: {text}", + ) + + +name = "simple-agent" +description = "Echo agent for unit tests (no LLM)" +background = False + +agent_config = dict( + type=EchoAgent, + name="simple-agent", +) diff --git a/tests/run_e2e.py b/tests/run_e2e.py new file mode 100644 index 00000000..57777bdd --- /dev/null +++ b/tests/run_e2e.py @@ -0,0 +1,70 @@ +"""E2E: Load agent from project_template/config.py and run a task.""" + +import asyncio +import sys +from pathlib import Path + +# Ensure lagent is importable +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) + +from lagent.utils import create_object +from lagent.utils.config import Config + + +async def main(): + # ── Step 1: Load config ────────────────────────────────────────── + config_path = Path(__file__).parent / "project_template" / "config.py" + print(f"Loading config from: {config_path}") + + cfg = Config.fromfile(str(config_path)) + agent_config = cfg.agent_config + build_fn = getattr(cfg, 'build', None) + + print(f" name: {cfg.name}") + print(f" agent type: {agent_config['type'].__name__}") + print(f" build: {build_fn or 'create_object (default)'}") + + # ── Step 2: Build agent ────────────────────────────────────────── + # Inject workspace path into aggregator config + import copy + agent_config = copy.deepcopy(agent_config) + workspace = Path(__file__).resolve().parents[2] / "workspace" + agent_config['policy_agent']['aggregator']['workspace'] = workspace + + if build_fn and callable(build_fn): + agent = await build_fn(agent_config) + else: + agent = create_object(agent_config) + + print(f" agent created: {agent.__class__.__name__}") + print(f" policy: {agent.policy_agent.__class__.__name__}") + print(f" env: {agent.env_agent.__class__.__name__}") + print(f" actions: {list(agent.env_agent.actions.actions.keys())}") + + # ── Step 3: Run a simple task ──────────────────────────────────── + task = "List the files in the current directory using ls -la" + print(f"\n{'='*60}") + print(f"Task: {task}") + print(f"{'='*60}\n") + + try: + response = await agent(task) + print(f"\n{'='*60}") + print(f"Response: {response.content[:500] if response.content else 'empty'}") + print(f"Finish reason: {getattr(response, 'finish_reason', 'N/A')}") + print(f"{'='*60}") + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + finally: + # Cleanup MCP connections + for action in agent.env_agent.actions.actions.values(): + if hasattr(action, 'close'): + await action.close() + + print("\n✅ E2E complete") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_actions/debug_agent_daemon.py b/tests/test_actions/debug_agent_daemon.py new file mode 100644 index 00000000..c72d6ec7 --- /dev/null +++ b/tests/test_actions/debug_agent_daemon.py @@ -0,0 +1,300 @@ +"""Manual debug script for AgentDaemon E2E on Gateway sandbox. + +Usage:: + + /mnt/llm-ai-infra/miniconda3/envs/train/bin/python tests/test_actions/debug_agent_daemon.py + + # Custom gateway/image + /mnt/llm-ai-infra/miniconda3/envs/train/bin/python tests/test_actions/debug_agent_daemon.py \\ + --gateway http://env-gateway.ailab.ailab.ai \\ + --image hb_3d-scan-calc \\ + --ttl 600 +""" + +import argparse +import asyncio +import base64 +import io +import json +import os +import sys +import tarfile +import time + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from lagent.serving.sandbox.providers.gateway import GatewayProvider +from lagent.serving.sandbox.agent import SandboxAgent +from workspace.agents.default_agent.config import agent_config + +LAGENT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + +# Minimal __init__.py to avoid optional dep issues +MINIMAL_ACTIONS_INIT = ( + "from .action_executor import ActionExecutor, AsyncActionExecutor\n" + "from .base_action import AsyncActionMixin, BaseAction, tool_api\n" + "from .builtin_actions import FinishAction, InvalidAction, NoAction\n" + "from .parser import BaseParser, JsonParser, TupleParser\n" +) +MINIMAL_HOOKS_INIT = "from .hook import Hook, RemovableHandle\n" + +# All deps needed by lagent in the sandbox +ALL_DEPS = ( + "griffe termcolor asyncer func_timeout openai jinja2 tiktoken " + "aiohttp tenacity pydantic requests json5 jsonschema timeout-decorator" +) + + +def upload_text(client, path, content): + b64 = base64.b64encode(content.encode()).decode() + client.session.post(f"{client.url}/upload", json={ + "target_path": path, "content_b64": b64 + }) + + +def upload_lagent_source(client): + """Tar and upload lagent source + workspace to sandbox. + + - lagent source → /tmp/lagent/ (for PYTHONPATH) + - workspace → /root/workspace/ (matches agent config) + """ + # 1. Upload lagent source code → /tmp/ + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for d in ["lagent", "requirements"]: + src = os.path.join(LAGENT_PATH, d) + if not os.path.exists(src): + continue + for root, dirs, files in os.walk(src): + dirs[:] = [x for x in dirs if x != "__pycache__"] + for fname in files: + if fname.endswith(".pyc"): + continue + full = os.path.join(root, fname) + tar.add(full, arcname=os.path.relpath(full, LAGENT_PATH)) + content_b64 = base64.b64encode(buf.getvalue()).decode() + client.session.post(f"{client.url}/upload", json={ + "target_path": "/tmp/lagent_src.tar.gz", "content_b64": content_b64 + }) + print(f" Source: {len(buf.getvalue()) / 1024:.0f} KB → /tmp/") + client.execute("cd /tmp && tar xzf lagent_src.tar.gz") + + # 2. Upload workspace → /root/workspace/ (matches config) + ws_src = os.path.join(LAGENT_PATH, "workspace") + if os.path.exists(ws_src): + buf2 = io.BytesIO() + with tarfile.open(fileobj=buf2, mode="w:gz") as tar: + for root, dirs, files in os.walk(ws_src): + dirs[:] = [x for x in dirs if x != "__pycache__"] + for fname in files: + full = os.path.join(root, fname) + # arcname: workspace/skills/... → skills/... + arcname = os.path.relpath(full, ws_src) + tar.add(full, arcname=arcname) + content_b64 = base64.b64encode(buf2.getvalue()).decode() + client.session.post(f"{client.url}/upload", json={ + "target_path": "/root/workspace.tar.gz", "content_b64": content_b64 + }) + client.execute("mkdir -p /root/workspace && cd /root/workspace && tar xzf /root/workspace.tar.gz") + print(f" Workspace: {len(buf2.getvalue()) / 1024:.0f} KB → /root/workspace/") + + +def setup_sandbox(client): + """Upload lagent, fix __init__.py, install deps.""" + print("[1/3] Uploading lagent source + workspace...") + upload_lagent_source(client) + + print("[2/3] Fixing __init__.py (minimal imports)...") + upload_text(client, "/tmp/lagent/actions/__init__.py", MINIMAL_ACTIONS_INIT) + upload_text(client, "/tmp/lagent/hooks/__init__.py", MINIMAL_HOOKS_INIT) + + print("[3/3] Installing deps...") + r = client.execute(f"/mnt/llm-ai-infra/miniconda3/envs/train/bin/python -m pip list 2>&1 ", timeout_sec=300) + print(f" {r.get('stdout', '').strip()[-120:]}") + + # Ensure workspace dirs exist (in case workspace upload didn't have them) + client.execute("mkdir -p /root/workspace/memory /root/workspace/skills") + + +def verify_agent_creation(client): + """Test that the agent config can be instantiated inside sandbox.""" + print("\n--- Verify agent creation ---") + upload_text(client, "/tmp/agent_config.json", json.dumps(agent_config, ensure_ascii=False)) + upload_text(client, "/tmp/test_create.py", """ +import json, traceback +from lagent.utils import create_object +config = json.load(open("/tmp/agent_config.json")) +try: + agent = create_object(config) + print("OK:", type(agent).__name__) + print("PolicyAgent:", type(agent.policy_agent).__name__) + print("EnvAgent:", type(agent.env_agent).__name__) + if hasattr(agent.env_agent, 'actions'): + print("Tools:", list(agent.env_agent.actions.actions.keys())) +except: + traceback.print_exc() +""") + r = client.execute( + "PYTHONPATH=/tmp:$PYTHONPATH /mnt/llm-ai-infra/miniconda3/envs/train/bin/python /tmp/test_create.py 2>&1", + timeout_sec=60, + ) + # Read from file if stdout is truncated + client.execute( + "PYTHONPATH=/tmp:$PYTHONPATH /mnt/llm-ai-infra/miniconda3/envs/train/bin/python /tmp/test_create.py > /tmp/create_result.txt 2>&1", + timeout_sec=60, + ) + data = client.download_file("/tmp/create_result.txt") + result = data.decode() + print(result) + return "OK:" in result + + +def start_agent_daemon(client, sock_path="/tmp/lagent_agent.sock"): + """Start the AgentDaemon inside the sandbox.""" + print("\n--- Starting AgentDaemon ---") + client.execute( + f"PYTHONPATH=/tmp:$PYTHONPATH nohup /mnt/llm-ai-infra/miniconda3/envs/train/bin/python -m lagent.serving.sandbox.daemon start " + f"--mode agent --config /tmp/agent_config.json " + f"--sock {sock_path} " + f"> /tmp/lagent_agent.log 2>&1 &" + ) + print(" Waiting for socket...") + for i in range(20): + time.sleep(2) + r = client.execute(f"test -S {sock_path} && echo 'ready' || echo 'waiting'") + status = r.get("stdout", "").strip() + if "ready" in status: + print(f" Socket ready! ({(i+1)*2}s)") + return True + print(f" [{(i+1)*2}s] {status}") + + print(" FAILED. Daemon log:") + data = client.download_file("/tmp/lagent_agent.log") + print(data.decode()[:1000]) + return False + + +async def test_agent(client, sock_path="/tmp/lagent_agent.sock"): + """Run interactive tests against the AgentDaemon.""" + # Patch client for PYTHONPATH + original_exec = client.execute + def patched(command, **kw): + return original_exec(f"PYTHONPATH=/tmp:$PYTHONPATH {command}", **kw) + client.execute = patched + + agent = SandboxAgent( + sandbox_client=client, + agent_config=agent_config, + sock_path=sock_path, + ) + agent._connected = True + + # Ping + print("\n--- Ping ---") + r = await agent._daemon_call({"cmd": "ping"}) + print(f" {r}") + + # List tools + print("\n--- List tools ---") + r = await agent._daemon_call({"cmd": "list_tools"}) + tools = [t["name"] for t in r.get("tools", [])] + print(f" {tools}") + + # Chat + print("\n--- Chat ---") + print(" Sending: '请执行 echo hello world 并告诉我结果'") + response = await agent("请执行 echo hello world 并告诉我结果") + print(f" Response type: {type(response).__name__}") + print(f" Content: {str(response.content)[:500]}") + + # State dict + print("\n--- State dict ---") + state = await agent.get_state_dict() + print(f" Keys: {list(state.keys())[:10]}") + for k, v in state.items(): + if isinstance(v, list): + print(f" {k}: {len(v)} items") + else: + print(f" {k}: {type(v).__name__}") + + # Interactive mode + print("\n--- Interactive mode (type 'quit' to exit) ---") + while True: + try: + user_input = input("\n[You] > ").strip() + except (EOFError, KeyboardInterrupt): + break + if not user_input or user_input.lower() in ("quit", "exit", "q"): + break + if user_input == "state": + state = await agent.get_state_dict() + print(f"State: {json.dumps({k: type(v).__name__ for k, v in state.items()}, indent=2)}") + continue + if user_input == "reset": + await agent.reset() + print("Reset done.") + continue + if user_input == "tools": + r = await agent._daemon_call({"cmd": "list_tools"}) + print(json.dumps([t["name"] for t in r.get("tools", [])], indent=2)) + continue + + response = await agent(user_input) + print(f"\n[Agent] {str(response.content)[:1000]}") + + +def main(): + parser = argparse.ArgumentParser(description="Debug AgentDaemon E2E on Gateway sandbox") + parser.add_argument("--gateway", default="http://env-gateway.ailab.ailab.ai") + parser.add_argument("--image", default="hb_3d-scan-calc") + parser.add_argument("--ttl", type=int, default=600, help="Sandbox TTL in seconds") + parser.add_argument("--sock", default="/tmp/lagent_agent.sock") + parser.add_argument("--skip-setup", action="store_true", help="Skip upload/install (reuse existing sandbox)") + args = parser.parse_args() + + print("=" * 60) + print(" AgentDaemon Debug — Gateway Sandbox") + print("=" * 60) + + provider = GatewayProvider(args.gateway) + print(f"\nCreating sandbox (image={args.image}, ttl={args.ttl}s)...") + client, env_id = provider.create(image_tag=args.image, ttl_seconds=args.ttl) + print(f"url: {client.url}") + print(f"env_id: {env_id}") + + try: + if not args.skip_setup: + print(f"\n{'='*60}") + print(" Setup") + print(f"{'='*60}\n") + setup_sandbox(client) + + print(f"\n{'='*60}") + print(" Verify") + print(f"{'='*60}") + if not verify_agent_creation(client): + print("\nAgent creation failed. Debug with:") + print(f" curl {client.url}/exec -X POST -H 'Content-Type: application/json' \\") + print(f" -d '{{\"command\": \"PYTHONPATH=/tmp:$PYTHONPATH /mnt/llm-ai-infra/miniconda3/envs/train/bin/python /tmp/test_create.py 2>&1\"}}'") + return + + print(f"\n{'='*60}") + print(" Daemon") + print(f"{'='*60}") + if not start_agent_daemon(client, args.sock): + return + + print(f"\n{'='*60}") + print(" Test") + print(f"{'='*60}") + asyncio.run(test_agent(client, args.sock)) + + finally: + print(f"\n{'='*60}") + print(f"Cleaning up sandbox {env_id}...") + provider.delete(env_id) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/tests/test_actions/test_cron_action.py b/tests/test_actions/test_cron_action.py new file mode 100644 index 00000000..167bed74 --- /dev/null +++ b/tests/test_actions/test_cron_action.py @@ -0,0 +1,236 @@ +"""Unit tests for CronAction (lagent/actions/cron.py).""" + +import os +import sys +import tempfile +import types +from datetime import datetime, timezone, timedelta +from pathlib import Path + +import pytest + +# --- bypass circular import in lagent.services.__init__.py --- +_here = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _here not in sys.path: + sys.path.insert(0, _here) +if "lagent.services" not in sys.modules: + _pkg = types.ModuleType("lagent.services") + _pkg.__path__ = [os.path.join(_here, "lagent", "services")] + _pkg.__package__ = "lagent.services" + sys.modules["lagent.services"] = _pkg + +from lagent.actions.cron import CronAction, AsyncCronAction +from lagent.schema import ActionStatusCode +from lagent.services.cron import CronService + + +def _make_action(tmpdir) -> CronAction: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path) + return CronAction(cron_service=svc, channel="test", chat_id="c1") + + +# ── add ────────────────────────────────────────────────────────────── + +class TestCronActionAdd: + def test_add_every(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.add( + name="repeat", + message="do stuff", + schedule_kind="every", + every_seconds=30.0, + ) + assert result.result is not None + assert "repeat" in result.result[0]["content"] + + def test_add_at(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + result = action.add( + name="once", + message="remind", + schedule_kind="at", + at=future, + ) + assert result.result is not None + assert "once" in result.result[0]["content"] + + def test_add_cron(self): + pytest.importorskip("croniter") + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.add( + name="daily", + message="report", + schedule_kind="cron", + cron_expr="0 9 * * *", + timezone="UTC", + ) + assert result.result is not None + assert "daily" in result.result[0]["content"] + + def test_add_invalid_kind_returns_args_error(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.add( + name="bad", + message="x", + schedule_kind="invalid", + ) + assert result.state == ActionStatusCode.ARGS_ERROR + assert "Invalid schedule_kind" in result.errmsg + + def test_add_at_without_datetime_returns_api_error(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.add( + name="bad", + message="x", + schedule_kind="at", + ) + assert result.state == ActionStatusCode.API_ERROR + + def test_add_every_without_seconds_returns_api_error(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.add( + name="bad", + message="x", + schedule_kind="every", + ) + assert result.state == ActionStatusCode.API_ERROR + + def test_add_cron_without_expr_returns_api_error(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.add( + name="bad", + message="x", + schedule_kind="cron", + ) + assert result.state == ActionStatusCode.API_ERROR + + def test_add_populates_channel_and_chat_id(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + action.add( + name="j1", + message="hello", + schedule_kind="every", + every_seconds=60.0, + ) + job = action._cron.list_jobs(include_disabled=True)[0] + assert job.payload["channel"] == "test" + assert job.payload["chat_id"] == "c1" + + def test_add_at_sets_delete_after_run(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + action.add( + name="once", + message="x", + schedule_kind="at", + at=future, + ) + job = action._cron.list_jobs(include_disabled=True)[0] + assert job.delete_after_run is True + + def test_add_every_does_not_set_delete_after_run(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + action.add( + name="repeat", + message="x", + schedule_kind="every", + every_seconds=10.0, + ) + job = action._cron.list_jobs(include_disabled=True)[0] + assert job.delete_after_run is False + + +# ── list ───────────────────────────────────────────────────────────── + +class TestCronActionList: + def test_list_empty(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.list() + assert "No active" in result.result[0]["content"] + + def test_list_with_jobs(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + action.add(name="j1", message="a", schedule_kind="every", every_seconds=10.0) + action.add(name="j2", message="b", schedule_kind="every", every_seconds=20.0) + result = action.list() + content = result.result[0]["content"] + assert "j1" in content + assert "j2" in content + + def test_list_hides_disabled(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + action.add(name="j1", message="a", schedule_kind="every", every_seconds=10.0) + # Manually disable the job + action._cron.list_jobs(include_disabled=True)[0].enabled = False + result = action.list() + assert "No active" in result.result[0]["content"] + + def test_list_shows_schedule_description(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + action.add(name="j1", message="a", schedule_kind="every", every_seconds=30.0) + result = action.list() + content = result.result[0]["content"] + assert "every 30.0s" in content + + +# ── remove ─────────────────────────────────────────────────────────── + +class TestCronActionRemove: + def test_remove_existing(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + action.add(name="j1", message="a", schedule_kind="every", every_seconds=10.0) + job_id = action._cron.list_jobs(include_disabled=True)[0].id + result = action.remove(job_id=job_id) + assert "removed" in result.result[0]["content"] + assert len(action._cron.list_jobs(include_disabled=True)) == 0 + + def test_remove_nonexistent(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + result = action.remove(job_id="nope") + assert result.state == ActionStatusCode.API_ERROR + assert "not found" in result.errmsg + + +# ── toolkit metadata ───────────────────────────────────────────────── + +class TestCronActionMeta: + def test_is_toolkit(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + assert action.is_toolkit + + def test_has_three_apis(self): + with tempfile.TemporaryDirectory() as tmpdir: + action = _make_action(tmpdir) + desc = action.description + api_names = {api["name"] for api in desc["api_list"]} + assert api_names == {"add", "list", "remove"} + + +# ── async variant ──────────────────────────────────────────────────── + +class TestAsyncCronAction: + def test_instantiation(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path) + action = AsyncCronAction(cron_service=svc) + assert action._cron is svc diff --git a/tests/test_actions/test_sandbox_executor.py b/tests/test_actions/test_sandbox_executor.py new file mode 100644 index 00000000..9d7fe451 --- /dev/null +++ b/tests/test_actions/test_sandbox_executor.py @@ -0,0 +1,1023 @@ +"""Tests for ActionDaemon, SandboxActionExecutor and HybridActionExecutor. + +Run with:: + + # Local unit tests (no sandbox needed, uses local daemon subprocess) + pytest tests/test_actions/test_sandbox_executor.py -v + + # Real sandbox E2E (requires gateway access) + RUN_E2E=1 pytest tests/test_actions/test_sandbox_executor.py -v -k "E2E" +""" + +from __future__ import annotations + +import asyncio +import base64 +import io +import json +import os +import subprocess +import tarfile +import time + +import pytest +import requests + +from lagent.serving.sandbox.daemon import ActionDaemon, async_lagent_call, lagent_call +from lagent.actions.hybrid_executor import HybridActionExecutor +from lagent.actions.sandbox_executor import ( + SandboxActionExecutor, + _ToolDescriptionProxy, + _ToolDescriptionStub, + _deserialize_action_return, +) +from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode, AgentMessage, FunctionCall + +# E2E tests are skipped unless RUN_E2E=1 environment variable is set +e2e = pytest.mark.skipif( + not os.environ.get("RUN_E2E"), reason="Set RUN_E2E=1 to run real sandbox E2E tests" +) + + +# ===================================================================== +# Part 1: Local unit tests (daemon subprocess, no real sandbox) +# ===================================================================== + +SOCK_PATH = "/tmp/lagent_test_daemon.sock" +CONFIG_PATH = "/tmp/lagent_test_actions.json" +ACTIONS_CONFIG = [ + {"type": "lagent.actions.python_interpreter.PythonInterpreter"}, +] + + +@pytest.fixture(scope="module") +def daemon_process(): + """Start an ActionDaemon in a subprocess for the entire test module.""" + with open(CONFIG_PATH, "w") as f: + json.dump(ACTIONS_CONFIG, f) + + proc = subprocess.Popen( + [ + "python", "-m", "lagent.serving.sandbox.daemon", "start", + "--sock", SOCK_PATH, + "--actions-config", CONFIG_PATH, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + for _ in range(30): + if os.path.exists(SOCK_PATH): + break + time.sleep(0.2) + else: + proc.kill() + raise RuntimeError("Daemon did not start") + + yield proc + + try: + lagent_call(SOCK_PATH, '{"cmd":"shutdown"}') + except Exception: + pass + time.sleep(0.5) + proc.terminate() + proc.wait(timeout=5) + if os.path.exists(SOCK_PATH): + os.unlink(SOCK_PATH) + + +class LocalSandboxClient: + """Mock sandbox client that runs bash locally (async).""" + + async def execute(self, command: str, cwd: str = "/tmp", timeout_sec: int = 30) -> dict: + proc = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + ) + stdout, stderr = await proc.communicate() + return { + "ok": proc.returncode == 0, + "stdout": stdout.decode(), + "stderr": stderr.decode(), + "exit_code": proc.returncode, + } + + +class SyncLocalSandboxClient: + """Sync mock sandbox client (like the user's real SandboxClient).""" + + def execute(self, command: str, cwd: str = "/tmp", timeout_sec: int = 30) -> dict: + import subprocess as sp + result = sp.run(command, shell=True, capture_output=True, text=True, cwd=cwd) + return { + "ok": result.returncode == 0, + "stdout": result.stdout, + "stderr": result.stderr, + "exit_code": result.returncode, + } + + +# --------------------------------------------------------------------------- +# ActionDaemon tests +# --------------------------------------------------------------------------- + + +class TestActionDaemon: + + def test_ping(self, daemon_process): + result = json.loads(lagent_call(SOCK_PATH, '{"cmd":"ping"}')) + assert result["status"] == "ok" + assert result["type"] == "action" + + def test_list_tools(self, daemon_process): + result = json.loads(lagent_call(SOCK_PATH, '{"cmd":"list_tools"}')) + assert "tools" in result + tools = result["tools"] + assert len(tools) >= 1 + assert tools[0]["name"] == "PythonInterpreter" + assert "parameters" in tools[0] + + def test_action_call_success(self, daemon_process): + request = json.dumps({ + "name": "PythonInterpreter", + "parameters": {"command": "def solution():\n return 42"}, + }) + result = json.loads(lagent_call(SOCK_PATH, request)) + assert result["state"] == ActionStatusCode.SUCCESS + assert result["result"][0]["content"] == "42" + + def test_action_call_error(self, daemon_process): + request = json.dumps({ + "name": "PythonInterpreter", + "parameters": {"command": "invalid code without solution()"}, + }) + result = json.loads(lagent_call(SOCK_PATH, request)) + assert result["state"] == ActionStatusCode.API_ERROR + + def test_missing_name(self, daemon_process): + result = json.loads(lagent_call(SOCK_PATH, '{"parameters":{}}')) + assert result["state"] == ActionStatusCode.ARGS_ERROR + + def test_unknown_action(self, daemon_process): + request = json.dumps({ + "name": "NonExistentAction", + "parameters": {}, + }) + result = json.loads(lagent_call(SOCK_PATH, request)) + assert result["state"] in (ActionStatusCode.ARGS_ERROR, ActionStatusCode.API_ERROR) + + @pytest.mark.asyncio + async def test_async_call(self, daemon_process): + result = await async_lagent_call(SOCK_PATH, '{"cmd":"ping"}') + parsed = json.loads(result) + assert parsed["status"] == "ok" + + +# --------------------------------------------------------------------------- +# Deserialization tests +# --------------------------------------------------------------------------- + + +class TestDeserialization: + + def test_deserialize_success(self): + text = json.dumps({ + "args": {"command": "ls"}, + "type": "ShellAction", + "result": [{"type": "text", "content": "file.txt"}], + "state": 0, + "valid": 0, + "errmsg": None, + "url": None, + "thought": None, + }) + ar = _deserialize_action_return(text) + assert isinstance(ar, ActionReturn) + assert ar.state == ActionStatusCode.SUCCESS + assert ar.valid == ActionValidCode.OPEN + assert ar.result[0]["content"] == "file.txt" + + def test_deserialize_error(self): + text = json.dumps({ + "args": {}, + "state": -1002, + "valid": 0, + "errmsg": "something broke", + }) + ar = _deserialize_action_return(text) + assert ar.state == ActionStatusCode.API_ERROR + assert ar.errmsg == "something broke" + + +# --------------------------------------------------------------------------- +# ToolDescriptionProxy tests +# --------------------------------------------------------------------------- + + +class TestToolDescriptionProxy: + + def test_proxy_contains(self): + proxy = _ToolDescriptionProxy({"shell": {"name": "shell"}}) + assert "shell" in proxy + assert "missing" not in proxy + + def test_proxy_values(self): + proxy = _ToolDescriptionProxy({ + "a": {"name": "a", "description": "A"}, + "b": {"name": "b", "description": "B"}, + }) + stubs = proxy.values() + assert len(stubs) == 2 + assert all(isinstance(s, _ToolDescriptionStub) for s in stubs) + assert stubs[0].name == "a" + assert stubs[0].is_toolkit is False + + def test_proxy_keys(self): + proxy = _ToolDescriptionProxy({"x": {"name": "x"}, "y": {"name": "y"}}) + assert proxy.keys() == ["x", "y"] + + +# --------------------------------------------------------------------------- +# SandboxActionExecutor tests (async mock client) +# --------------------------------------------------------------------------- + + +class TestSandboxActionExecutorAsync: + + @pytest.fixture() + def executor(self, daemon_process): + client = LocalSandboxClient() + ex = SandboxActionExecutor( + sandbox_client=client, + actions_config=ACTIONS_CONFIG, + sock_path=SOCK_PATH, + cwd="/tmp", + ) + ex._connected = True + ex._tool_descriptions = { + "PythonInterpreter": { + "name": "PythonInterpreter", + "description": "...", + "parameters": [{"name": "command", "type": "STRING", "description": ""}], + "required": ["command"], + } + } + return ex + + @pytest.mark.asyncio + async def test_forward_success(self, executor): + result = await executor.forward( + "PythonInterpreter", + {"command": "def solution():\n return 99"}, + ) + assert result.state == ActionStatusCode.SUCCESS + assert result.result[0]["content"] == "99" + assert result.valid == ActionValidCode.OPEN + + @pytest.mark.asyncio + async def test_forward_unknown_action(self, executor): + result = await executor.forward("NonExistent", {"x": 1}) + assert result.valid == ActionValidCode.INVALID + + @pytest.mark.asyncio + async def test_contains(self, executor): + assert "PythonInterpreter" in executor + assert "NonExistent" not in executor + + @pytest.mark.asyncio + async def test_description(self, executor): + descs = executor.description() + assert len(descs) == 1 + assert descs[0]["name"] == "PythonInterpreter" + + @pytest.mark.asyncio + async def test_call_with_agent_message(self, executor): + msg = AgentMessage( + sender="test", + content=FunctionCall( + name="PythonInterpreter", + parameters={"command": "def solution():\n return 7"}, + ), + ) + response = await executor(msg) + assert isinstance(response, AgentMessage) + ar = response.content + assert isinstance(ar, ActionReturn) + assert ar.result[0]["content"] == "7" + + +# --------------------------------------------------------------------------- +# SandboxActionExecutor with sync client +# --------------------------------------------------------------------------- + + +class TestSandboxActionExecutorSync: + + @pytest.fixture() + def executor(self, daemon_process): + client = SyncLocalSandboxClient() + ex = SandboxActionExecutor( + sandbox_client=client, + actions_config=ACTIONS_CONFIG, + sock_path=SOCK_PATH, + cwd="/tmp", + ) + ex._connected = True + ex._tool_descriptions = { + "PythonInterpreter": { + "name": "PythonInterpreter", + "description": "...", + "parameters": [{"name": "command", "type": "STRING", "description": ""}], + "required": ["command"], + } + } + return ex + + @pytest.mark.asyncio + async def test_sync_client_forward(self, executor): + """Verify sync client is wrapped with to_thread and works.""" + result = await executor.forward( + "PythonInterpreter", + {"command": "def solution():\n return 55"}, + ) + assert result.state == ActionStatusCode.SUCCESS + assert result.result[0]["content"] == "55" + + +# --------------------------------------------------------------------------- +# HybridActionExecutor tests +# --------------------------------------------------------------------------- + + +class TestHybridActionExecutor: + + @pytest.fixture() + def hybrid(self, daemon_process): + client = LocalSandboxClient() + sandbox = SandboxActionExecutor( + sandbox_client=client, + actions_config=ACTIONS_CONFIG, + sock_path=SOCK_PATH, + cwd="/tmp", + ) + sandbox._connected = True + sandbox._tool_descriptions = { + "PythonInterpreter": { + "name": "PythonInterpreter", + "description": "...", + "parameters": [{"name": "command", "type": "STRING", "description": ""}], + "required": ["command"], + } + } + return HybridActionExecutor( + local_actions=[], + sandbox_executor=sandbox, + ) + + @pytest.mark.asyncio + async def test_route_to_sandbox(self, hybrid): + result = await hybrid.forward( + "PythonInterpreter", + {"command": "def solution():\n return 123"}, + ) + assert result.state == ActionStatusCode.SUCCESS + assert result.result[0]["content"] == "123" + + @pytest.mark.asyncio + async def test_route_unknown_to_invalid(self, hybrid): + result = await hybrid.forward("DoesNotExist", {}) + assert result.valid == ActionValidCode.INVALID + + def test_description_merges(self, hybrid): + descs = hybrid.description() + names = [d["name"] for d in descs] + assert "PythonInterpreter" in names + + def test_contains_sandbox_tools(self, hybrid): + assert "PythonInterpreter" in hybrid + + def test_keys_merged(self, hybrid): + assert "PythonInterpreter" in hybrid.keys() + + +# ===================================================================== +# Part 2: Real sandbox E2E tests +# +# Run with: pytest tests/test_actions/test_sandbox_executor.py -v -k "E2E" --run-e2e +# ===================================================================== + +GATEWAY_URL = "http://env-gateway.ailab.ailab.ai" +IMAGE_TAG = "hb_3d-scan-calc" +E2E_ACTIONS_CONFIG = [ + {"type": "lagent.actions.shell.ShellAction", "working_dir": "/root"}, + {"type": "lagent.actions.filesystem.ReadFileAction", "workspace": "/root"}, + {"type": "lagent.actions.filesystem.WriteFileAction", "workspace": "/root"}, + {"type": "lagent.actions.filesystem.EditFileAction", "workspace": "/root"}, + {"type": "lagent.actions.python_interpreter.PythonInterpreter"}, +] +# Minimal __init__.py — avoids importing actions with heavy optional deps +_MINIMAL_ACTIONS_INIT = ( + "from .action_executor import ActionExecutor, AsyncActionExecutor\n" + "from .base_action import AsyncActionMixin, BaseAction, tool_api\n" + "from .builtin_actions import FinishAction, InvalidAction, NoAction\n" + "from .parser import BaseParser, JsonParser, TupleParser\n" +) +_MINIMAL_HOOKS_INIT = "from .hook import Hook, RemovableHandle\n" +_REQUIRED_PIPS = "griffe termcolor asyncer func_timeout" + + +class RealSandboxClient: + """Wraps the real EnvGateway SandboxClient for use with SandboxActionExecutor. + + Prepends PYTHONPATH so the uploaded lagent source is importable. + """ + + def __init__(self, base_url: str, pythonpath: str = "/tmp"): + self.base_url = base_url + self.pythonpath = pythonpath + self.session = requests.Session() + self.session.headers.update({ + "Connection": "keep-alive", + "Content-Type": "application/json", + }) + + def execute(self, command: str, cwd: str = "/root", timeout_sec: int = 60) -> dict: + command = f"PYTHONPATH={self.pythonpath}:$PYTHONPATH {command}" + resp = self.session.post( + f"{self.base_url}/exec", + json={"command": command, "cwd": cwd, "timeout_sec": timeout_sec}, + ) + return resp.json() + + +@pytest.fixture(scope="module") +def sandbox_env(): + """Create a real sandbox, upload lagent, install deps, start daemon. + + Yields (sandbox_url, env_id, client). + Cleans up the sandbox after all E2E tests finish. + """ + gw = requests.Session() + gw.headers.update({"Content-Type": "application/json"}) + + # 1. Create sandbox + resp = gw.post( + f"{GATEWAY_URL}/envs", + json={"image_tag": IMAGE_TAG, "ttl_seconds": 600}, + timeout=120, + ) + ret = resp.json() + assert ret["ok"], f"Failed to create sandbox: {ret}" + sandbox_url = ret["env"]["url"] + env_id = ret["env"]["env_id"] + + session = requests.Session() + session.headers.update({"Connection": "keep-alive", "Content-Type": "application/json"}) + + def _exec(cmd, timeout=120): + r = session.post( + f"{sandbox_url}/exec", + json={"command": cmd, "cwd": "/root", "timeout_sec": timeout}, + ) + return r.json() + + # 2. Upload lagent source + buf = io.BytesIO() + lagent_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + lagent_pkg = os.path.join(lagent_root, "lagent") + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for root, dirs, files in os.walk(lagent_pkg): + dirs[:] = [d for d in dirs if d != "__pycache__"] + for fname in files: + if fname.endswith(".pyc"): + continue + full_path = os.path.join(root, fname) + tar.add(full_path, arcname=os.path.relpath(full_path, lagent_root)) + content_b64 = base64.b64encode(buf.getvalue()).decode() + resp = session.post( + f"{sandbox_url}/upload", + json={"target_path": "/tmp/lagent_src.tar.gz", "content_b64": content_b64}, + ) + assert resp.json()["ok"], "Upload failed" + _exec("cd /tmp && tar xzf lagent_src.tar.gz") + + # 3. Write minimal __init__.py (avoid optional deps) + _exec(f"cat > /tmp/lagent/actions/__init__.py << 'EOF'\n{_MINIMAL_ACTIONS_INIT}EOF") + _exec(f"cat > /tmp/lagent/hooks/__init__.py << 'EOF'\n{_MINIMAL_HOOKS_INIT}EOF") + + # 4. Install required deps + _exec(f"pip install {_REQUIRED_PIPS} 2>&1 | tail -3", timeout=120) + + # 5. Start daemon + config_json = json.dumps(E2E_ACTIONS_CONFIG) + _exec(f"echo '{config_json}' > /tmp/lagent_actions_config.json") + _exec( + "PYTHONPATH=/tmp:$PYTHONPATH nohup python -m lagent.serving.sandbox.daemon start " + "--sock /tmp/lagent_action.sock " + "--actions-config /tmp/lagent_actions_config.json " + "> /tmp/lagent_daemon.log 2>&1 &" + ) + time.sleep(3) + r = _exec("test -S /tmp/lagent_action.sock && echo 'ready'") + assert "ready" in r.get("stdout", ""), ( + f"Daemon did not start: {_exec('cat /tmp/lagent_daemon.log')}" + ) + + client = RealSandboxClient(sandbox_url) + + yield sandbox_url, env_id, client + + # Cleanup + try: + gw.delete(f"{GATEWAY_URL}/envs/{env_id}", timeout=30) + except Exception: + pass + + +@pytest.fixture() +def e2e_executor(sandbox_env): + """Create a connected SandboxActionExecutor for E2E tests.""" + sandbox_url, env_id, client = sandbox_env + + async def _setup(): + executor = SandboxActionExecutor( + sandbox_client=client, + actions_config=E2E_ACTIONS_CONFIG, + sock_path="/tmp/lagent_action.sock", + ) + executor._connected = True + result = await executor._daemon_call({"cmd": "list_tools"}) + executor._tool_descriptions = {t["name"]: t for t in result["tools"]} + return executor + + loop = asyncio.new_event_loop() + executor = loop.run_until_complete(_setup()) + loop.close() + return executor + + +# --------------------------------------------------------------------------- +# E2E test class +# --------------------------------------------------------------------------- + + +@e2e +class TestSandboxE2E: + """Real sandbox E2E tests — shell, file ops, python, cross-action state.""" + + def test_tools_registered(self, e2e_executor): + tools = e2e_executor.keys() + assert "ShellAction" in tools + assert "ReadFileAction" in tools + assert "WriteFileAction" in tools + assert "EditFileAction" in tools + assert "PythonInterpreter" in tools + + @pytest.mark.asyncio + async def test_shell_echo(self, e2e_executor): + ar = await e2e_executor.forward("ShellAction", {"command": "echo hello"}) + assert ar.state == ActionStatusCode.SUCCESS + assert "hello" in ar.result[0]["content"] + + @pytest.mark.asyncio + async def test_shell_ls(self, e2e_executor): + ar = await e2e_executor.forward("ShellAction", {"command": "ls /"}) + assert ar.state == ActionStatusCode.SUCCESS + assert "root" in ar.result[0]["content"] or "tmp" in ar.result[0]["content"] + + @pytest.mark.asyncio + async def test_write_file(self, e2e_executor): + ar = await e2e_executor.forward("WriteFileAction", { + "path": "e2e_test.txt", + "content": "Line 1: Hello\nLine 2: World\nLine 3: End\n", + }) + assert ar.state == ActionStatusCode.SUCCESS + + @pytest.mark.asyncio + async def test_read_file(self, e2e_executor): + # Ensure file exists + await e2e_executor.forward("WriteFileAction", { + "path": "e2e_test.txt", + "content": "Line 1: Hello\nLine 2: World\nLine 3: End\n", + }) + ar = await e2e_executor.forward("ReadFileAction", {"path": "e2e_test.txt"}) + assert ar.state == ActionStatusCode.SUCCESS + assert "Hello" in ar.result[0]["content"] + assert "Line 2: World" in ar.result[0]["content"] + + @pytest.mark.asyncio + async def test_edit_file(self, e2e_executor): + # Write → Edit → Read-back + await e2e_executor.forward("WriteFileAction", { + "path": "e2e_edit.txt", + "content": "before edit\n", + }) + ar = await e2e_executor.forward("EditFileAction", { + "path": "e2e_edit.txt", + "search": "before edit", + "replace": "AFTER EDIT", + }) + assert ar.state == ActionStatusCode.SUCCESS + + ar = await e2e_executor.forward("ReadFileAction", {"path": "e2e_edit.txt"}) + assert "AFTER EDIT" in ar.result[0]["content"] + + @pytest.mark.asyncio + async def test_python_interpreter(self, e2e_executor): + ar = await e2e_executor.forward("PythonInterpreter", { + "command": "def solution():\n return 6 * 7", + }) + assert ar.state == ActionStatusCode.SUCCESS + assert ar.result[0]["content"] == "42" + + @pytest.mark.asyncio + async def test_cross_action_state(self, e2e_executor): + """File created by WriteFileAction should be visible to ShellAction.""" + await e2e_executor.forward("WriteFileAction", { + "path": "cross_test.txt", + "content": "CROSS_ACTION_OK\n", + }) + ar = await e2e_executor.forward("ShellAction", {"command": "cat /root/cross_test.txt"}) + assert ar.state == ActionStatusCode.SUCCESS + assert "CROSS_ACTION_OK" in ar.result[0]["content"] + + @pytest.mark.asyncio + async def test_hybrid_routing(self, e2e_executor): + """HybridActionExecutor routes sandbox tools correctly.""" + hybrid = HybridActionExecutor(local_actions=[], sandbox_executor=e2e_executor) + + ar = await hybrid.forward("ShellAction", {"command": "echo hybrid_ok"}) + assert ar.state == ActionStatusCode.SUCCESS + assert "hybrid_ok" in ar.result[0]["content"] + + ar = await hybrid.forward("DoesNotExist", {}) + assert ar.valid == ActionValidCode.INVALID + + @pytest.mark.asyncio + async def test_unknown_action(self, e2e_executor): + ar = await e2e_executor.forward("FakeAction", {}) + assert ar.valid == ActionValidCode.INVALID + + +# ===================================================================== +# Direct invocation entry point +# ===================================================================== + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Sandbox E2E tests") + parser.add_argument("--gateway", default=GATEWAY_URL, help="Gateway URL") + parser.add_argument("--image", default=IMAGE_TAG, help="Sandbox image tag") + args = parser.parse_args() + + GATEWAY_URL = args.gateway + IMAGE_TAG = args.image + + async def main(): + # --- Setup --- + print("=" * 60) + print(" Sandbox E2E Test") + print("=" * 60) + + gw = requests.Session() + gw.headers.update({"Content-Type": "application/json"}) + + print(f"\n[1/5] Creating sandbox (image={IMAGE_TAG})...") + resp = gw.post( + f"{GATEWAY_URL}/envs", + json={"image_tag": IMAGE_TAG, "ttl_seconds": 600}, + timeout=120, + ) + ret = resp.json() + assert ret["ok"], f"Failed: {ret}" + sandbox_url = ret["env"]["url"] + env_id = ret["env"]["env_id"] + print(f" url={sandbox_url}") + + session = requests.Session() + session.headers.update({"Connection": "keep-alive", "Content-Type": "application/json"}) + + def _exec(cmd, timeout=120): + r = session.post( + f"{sandbox_url}/exec", + json={"command": cmd, "cwd": "/root", "timeout_sec": timeout}, + ) + return r.json() + + try: + # Upload lagent + print("[2/5] Uploading lagent source...") + buf = io.BytesIO() + lagent_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + lagent_pkg = os.path.join(lagent_root, "lagent") + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for root, dirs, files in os.walk(lagent_pkg): + dirs[:] = [d for d in dirs if d != "__pycache__"] + for fname in files: + if fname.endswith(".pyc"): + continue + full_path = os.path.join(root, fname) + tar.add(full_path, arcname=os.path.relpath(full_path, lagent_root)) + content_b64 = base64.b64encode(buf.getvalue()).decode() + resp = session.post( + f"{sandbox_url}/upload", + json={"target_path": "/tmp/lagent_src.tar.gz", "content_b64": content_b64}, + ) + assert resp.json()["ok"] + _exec("cd /tmp && tar xzf lagent_src.tar.gz") + _exec(f"cat > /tmp/lagent/actions/__init__.py << 'EOF'\n{_MINIMAL_ACTIONS_INIT}EOF") + _exec(f"cat > /tmp/lagent/hooks/__init__.py << 'EOF'\n{_MINIMAL_HOOKS_INIT}EOF") + + # Install deps + print("[3/5] Installing dependencies...") + _exec(f"pip install {_REQUIRED_PIPS} 2>&1 | tail -3", timeout=120) + + # Start daemon + print("[4/5] Starting daemon...") + config_json = json.dumps(E2E_ACTIONS_CONFIG) + _exec(f"echo '{config_json}' > /tmp/lagent_actions_config.json") + _exec( + "PYTHONPATH=/tmp:$PYTHONPATH nohup python -m lagent.serving.sandbox.daemon start " + "--sock /tmp/lagent_action.sock " + "--actions-config /tmp/lagent_actions_config.json " + "> /tmp/lagent_daemon.log 2>&1 &" + ) + time.sleep(3) + r = _exec("test -S /tmp/lagent_action.sock && echo 'ready'") + assert "ready" in r.get("stdout", ""), f"Daemon failed: {_exec('cat /tmp/lagent_daemon.log')}" + + # Run tests + print("[5/5] Running tests...\n") + client = RealSandboxClient(sandbox_url) + executor = SandboxActionExecutor( + sandbox_client=client, + actions_config=E2E_ACTIONS_CONFIG, + sock_path="/tmp/lagent_action.sock", + ) + executor._connected = True + result = await executor._daemon_call({"cmd": "list_tools"}) + executor._tool_descriptions = {t["name"]: t for t in result["tools"]} + hybrid = HybridActionExecutor(local_actions=[], sandbox_executor=executor) + + tools = executor.keys() + print(f" Tools: {tools}") + + passed = 0 + failed = 0 + + async def check(name, coro, assertion): + nonlocal passed, failed + try: + ar = await coro + assertion(ar) + print(f" PASS {name}") + passed += 1 + except Exception as e: + print(f" FAIL {name}: {e}") + failed += 1 + + await check( + "shell(echo)", + executor.forward("ShellAction", {"command": "echo hello"}), + lambda ar: ( + assert_eq(ar.state, ActionStatusCode.SUCCESS), + assert_in("hello", ar.result[0]["content"]), + ), + ) + await check( + "shell(ls /)", + executor.forward("ShellAction", {"command": "ls /"}), + lambda ar: assert_eq(ar.state, ActionStatusCode.SUCCESS), + ) + await check( + "write_file", + executor.forward("WriteFileAction", {"path": "e2e.txt", "content": "HELLO\nWORLD\n"}), + lambda ar: assert_eq(ar.state, ActionStatusCode.SUCCESS), + ) + await check( + "read_file", + executor.forward("ReadFileAction", {"path": "e2e.txt"}), + lambda ar: ( + assert_eq(ar.state, ActionStatusCode.SUCCESS), + assert_in("HELLO", ar.result[0]["content"]), + ), + ) + await check( + "edit_file", + executor.forward("EditFileAction", {"path": "e2e.txt", "search": "WORLD", "replace": "EDITED"}), + lambda ar: assert_eq(ar.state, ActionStatusCode.SUCCESS), + ) + + # Verify edit + ar = await executor.forward("ReadFileAction", {"path": "e2e.txt"}) + await check( + "verify_edit", + asyncio.coroutine(lambda: ar)(), + lambda ar: assert_in("EDITED", ar.result[0]["content"]), + ) if False else None # skip the coroutine trick + assert "EDITED" in ar.result[0]["content"], "Edit verification failed" + print(" PASS verify_edit") + passed += 1 + + await check( + "python(6*7)", + executor.forward("PythonInterpreter", {"command": "def solution():\n return 6*7"}), + lambda ar: ( + assert_eq(ar.state, ActionStatusCode.SUCCESS), + assert_eq(ar.result[0]["content"], "42"), + ), + ) + await check( + "cross_action_state", + hybrid.forward("ShellAction", {"command": "cat /root/e2e.txt"}), + lambda ar: assert_in("EDITED", ar.result[0]["content"]), + ) + await check( + "hybrid_routing", + hybrid.forward("ShellAction", {"command": "echo ok"}), + lambda ar: assert_in("ok", ar.result[0]["content"]), + ) + await check( + "unknown_action", + hybrid.forward("FakeAction", {}), + lambda ar: assert_eq(ar.valid, ActionValidCode.INVALID), + ) + + print(f"\n{'=' * 60}") + print(f" Results: {passed} passed, {failed} failed") + print(f"{'=' * 60}") + if failed: + exit(1) + + finally: + print("\nCleaning up sandbox...") + try: + gw.delete(f"{GATEWAY_URL}/envs/{env_id}", timeout=30) + print(" Done.") + except Exception as e: + print(f" Cleanup failed: {e}") + + def assert_eq(a, b): + assert a == b, f"{a} != {b}" + + def assert_in(needle, haystack): + assert needle in haystack, f"{needle!r} not in {haystack!r}" + + asyncio.run(main()) + + +# ===================================================================== +# Part 3: AgentDaemon E2E tests +# +# Tests the full InternClawAgent running inside a sandbox via AgentDaemon. +# Requires: RUN_AGENT_E2E=1 + LLM server accessible from sandbox. +# +# Run with: +# RUN_AGENT_E2E=1 python tests/test_actions/test_sandbox_executor.py --agent-e2e +# or +# RUN_AGENT_E2E=1 pytest tests/test_actions/test_sandbox_executor.py -v -k "AgentDaemon" +# ===================================================================== + +agent_e2e = pytest.mark.skipif( + not os.environ.get("RUN_AGENT_E2E"), + reason="Set RUN_AGENT_E2E=1 to run AgentDaemon E2E tests", +) + +# ClusterX config (shared storage, deps pre-installed) +CLUSTERX_PARTITION = "llmit_proxy" +CLUSTERX_CONDA_ENV = "xtuner_dev" +CLUSTERX_CONDA_ACTIVATE = "/mnt/shared-storage-user/liukuikun/miniconda3/bin/activate" +LAGENT_PATH = "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent" + + +@pytest.fixture(scope="module") +def agent_sandbox_env(): + """Create a ClusterX sandbox with SandboxServer for AgentDaemon tests. + + Uses ClusterX because the shared storage already has lagent + deps. + """ + import sys + sys.path.insert(0, LAGENT_PATH) + + try: + from lagent.serving.sandbox.providers.clusterx import ClusterXProvider + except ImportError: + pytest.skip("clusterx not available") + + provider = ClusterXProvider( + partition=CLUSTERX_PARTITION, + conda_env=CLUSTERX_CONDA_ENV, + conda_activate_path=CLUSTERX_CONDA_ACTIVATE, + python_path=LAGENT_PATH, + port=19876, + extra_run_kwargs={ + "cpus_per_task": 4, + "memory_per_task": 10, + "no_env": True, + }, + ) + + client, job_id = provider.create(timeout=300) + + yield client, job_id, provider + + try: + provider.delete(job_id) + except Exception: + pass + + +@pytest.fixture() +def agent_daemon_client(agent_sandbox_env): + """Start AgentDaemon in the sandbox and return a connected SandboxAgent.""" + import json as _json + from lagent.serving.sandbox.agent import SandboxAgent + + client, job_id, provider = agent_sandbox_env + + # Patch client for PYTHONPATH + conda + original_exec = client.execute + prefix = ( + f"source {CLUSTERX_CONDA_ACTIVATE} {CLUSTERX_CONDA_ENV} && " + f"PYTHONPATH={LAGENT_PATH}:$PYTHONPATH " + ) + + def patched_execute(command, **kw): + return original_exec(f"{prefix}{command}", **kw) + + client.execute = patched_execute + + # Write agent config + from workspace.agents.default_agent.config import agent_config + config_json = _json.dumps(agent_config, ensure_ascii=False) + escaped = config_json.replace("'", "'\\''") + client.execute(f"echo '{escaped}' > /tmp/agent_config.json") + client.execute("mkdir -p /root/workspace/memory /root/workspace/skills") + + # Start daemon + sock_path = "/tmp/lagent_agent_e2e.sock" + client.execute( + f"nohup python -m lagent.serving.sandbox.daemon start " + f"--mode agent --config /tmp/agent_config.json " + f"--sock {sock_path} " + f"> /tmp/lagent_agent_e2e.log 2>&1 &" + ) + + import time + time.sleep(8) + + r = client.execute(f"test -S {sock_path} && echo 'ready' || echo 'not ready'") + if "ready" not in r.get("stdout", ""): + r = client.execute(f"tail -30 /tmp/lagent_agent_e2e.log") + pytest.fail(f"AgentDaemon failed to start:\n{r.get('stdout', '')}") + + agent = SandboxAgent( + sandbox_client=client, + agent_config=agent_config, + sock_path=sock_path, + ) + agent._connected = True + return agent + + +@agent_e2e +class TestAgentDaemonE2E: + """Full InternClawAgent running inside a sandbox via AgentDaemon.""" + + @pytest.mark.asyncio + async def test_ping(self, agent_daemon_client): + r = await agent_daemon_client._daemon_call({"cmd": "ping"}) + assert r["status"] == "ok" + assert r["type"] == "agent" + + @pytest.mark.asyncio + async def test_list_tools(self, agent_daemon_client): + r = await agent_daemon_client._daemon_call({"cmd": "list_tools"}) + tools = r.get("tools", []) + tool_names = [t["name"] for t in tools] + assert len(tools) > 0 + # Should have at least shell and file actions + assert any("Shell" in n or "shell" in n for n in tool_names) + + @pytest.mark.asyncio + async def test_chat(self, agent_daemon_client): + response = await agent_daemon_client("请执行 echo hello world 并告诉我结果") + assert response.content is not None + assert len(str(response.content)) > 0 + + @pytest.mark.asyncio + async def test_state_dict(self, agent_daemon_client): + # Chat first to have some state + await agent_daemon_client("执行 echo test") + state = await agent_daemon_client.get_state_dict() + assert isinstance(state, dict) + + @pytest.mark.asyncio + async def test_reset(self, agent_daemon_client): + await agent_daemon_client.reset() + # After reset, should still be able to chat + r = await agent_daemon_client._daemon_call({"cmd": "ping"}) + assert r["status"] == "ok" diff --git a/tests/test_actions/test_send_message_action.py b/tests/test_actions/test_send_message_action.py new file mode 100644 index 00000000..274af1a8 --- /dev/null +++ b/tests/test_actions/test_send_message_action.py @@ -0,0 +1,153 @@ +"""Unit tests for SendMessageAction (lagent/actions/send_message.py).""" + +import os +import sys +import types + +import pytest + +# --- bypass circular import --- +_here = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _here not in sys.path: + sys.path.insert(0, _here) +if "lagent.services" not in sys.modules: + _pkg = types.ModuleType("lagent.services") + _pkg.__path__ = [os.path.join(_here, "lagent", "services")] + _pkg.__package__ = "lagent.services" + sys.modules["lagent.services"] = _pkg + +from lagent.actions.send_message import SendMessageAction, AsyncSendMessageAction +from lagent.schema import ActionStatusCode +from lagent.services.mailbox import Mailbox + + +def _make_action( + agent_name: str = "agent-A", + allowed_receivers=None, +) -> tuple[SendMessageAction, Mailbox]: + mb = Mailbox() + mb.register("agent-A") + mb.register("agent-B") + mb.register("lead") + action = SendMessageAction( + mailbox=mb, + agent_name=agent_name, + allowed_receivers=allowed_receivers, + ) + return action, mb + + +# ═══════════════════════════════════════════════════════════════════════ +# SEND +# ═══════════════════════════════════════════════════════════════════════ + +class TestSend: + def test_send_basic(self): + action, mb = _make_action() + result = action.send(to="agent-B", content="hello") + assert result.result is not None + assert "sent to @agent-B" in result.result[0]["content"] + assert mb.has_messages("agent-B") + + def test_send_content_preserved(self): + action, mb = _make_action() + action.send(to="agent-B", content="specific message") + msgs = mb.drain("agent-B") + assert len(msgs) == 1 + assert msgs[0].content == "specific message" + assert msgs[0].sender == "agent-A" + assert msgs[0].receiver == "agent-B" + + def test_send_broadcast(self): + action, mb = _make_action() + result = action.send(to="*", content="hello all") + assert "Broadcast" in result.result[0]["content"] + assert mb.has_messages("agent-B") + assert mb.has_messages("lead") + assert not mb.has_messages("agent-A") # sender excluded + + def test_send_allowed_receivers_permits(self): + action, mb = _make_action(allowed_receivers=["lead"]) + result = action.send(to="lead", content="report") + assert result.result is not None + assert mb.has_messages("lead") + + def test_send_allowed_receivers_blocks(self): + action, mb = _make_action(allowed_receivers=["lead"]) + result = action.send(to="agent-B", content="hello") + assert result.state == ActionStatusCode.API_ERROR + assert "Cannot send" in result.errmsg + assert not mb.has_messages("agent-B") + + def test_send_broadcast_ignores_allowed(self): + action, mb = _make_action(allowed_receivers=["lead"]) + result = action.send(to="*", content="hello all") + # Broadcast is always allowed + assert result.result is not None + + def test_send_no_restrictions(self): + action, mb = _make_action(allowed_receivers=None) + result = action.send(to="agent-B", content="hello") + assert result.result is not None + + +# ═══════════════════════════════════════════════════════════════════════ +# CHECK_INBOX +# ═══════════════════════════════════════════════════════════════════════ + +class TestCheckInbox: + def test_empty_inbox(self): + action, mb = _make_action() + result = action.check_inbox() + assert "No new messages" in result.result[0]["content"] + + def test_with_messages(self): + action, mb = _make_action() + # Another agent sends to agent-A + from lagent.schema import AgentMessage + mb.send(AgentMessage(sender="agent-B", receiver="agent-A", content="hey there")) + result = action.check_inbox() + content = result.result[0]["content"] + assert "1 pending" in content + assert "@agent-B" in content + assert "hey there" in content + + def test_check_inbox_does_not_consume(self): + action, mb = _make_action() + from lagent.schema import AgentMessage + mb.send(AgentMessage(sender="agent-B", receiver="agent-A", content="hey")) + action.check_inbox() + assert mb.has_messages("agent-A") # still there + + +# ═══════════════════════════════════════════════════════════════════════ +# TOOLKIT METADATA +# ═══════════════════════════════════════════════════════════════════════ + +class TestMeta: + def test_is_toolkit(self): + action, _ = _make_action() + assert action.is_toolkit + + def test_has_two_apis(self): + action, _ = _make_action() + desc = action.description + api_names = {api["name"] for api in desc["api_list"]} + assert api_names == {"send", "check_inbox"} + + def test_auto_registers_agent(self): + mb = Mailbox() + action = SendMessageAction(mailbox=mb, agent_name="new-agent") + assert "new-agent" in mb.agents + + +# ═══════════════════════════════════════════════════════════════════════ +# ASYNC VARIANT +# ═══════════════════════════════════════════════════════════════════════ + +class TestAsyncSendMessageAction: + def test_instantiation(self): + mb = Mailbox() + action = AsyncSendMessageAction(mailbox=mb, agent_name="async-agent") + assert action._mailbox is mb + assert action._agent_name == "async-agent" diff --git a/tests/test_actions/test_subagent.py b/tests/test_actions/test_subagent.py new file mode 100644 index 00000000..32ed8024 --- /dev/null +++ b/tests/test_actions/test_subagent.py @@ -0,0 +1,359 @@ +"""Unit tests for AsyncAgentAction (subagent toolkit). + +Run: + pytest tests/test_actions/test_subagent.py -v +""" + +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +ROOT = Path(__file__).resolve().parent.parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from lagent.actions.subagent import AsyncAgentAction +from lagent.schema import ActionReturn, ActionStatusCode +from lagent.services.agent import AgentEntry, AgentService, AgentStatus + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _make_entry(**overrides) -> AgentEntry: + defaults = dict( + id="abc12345", + agent_type="worker", + label="test task", + task="do something", + status=AgentStatus.STOPPED, + result="task done", + ) + defaults.update(overrides) + return AgentEntry(**defaults) + + +def _make_action(service=None) -> AsyncAgentAction: + if service is None: + service = MagicMock(spec=AgentService) + return AsyncAgentAction(agent_service=service) + + +def _result_text(ret: ActionReturn) -> str: + """Extract text content from ActionReturn.""" + if ret.result and len(ret.result) > 0: + return ret.result[0].get("content", "") + return "" + + +# ── spawn ──────────────────────────────────────────────────────────── + + +class TestSpawn: + + @pytest.mark.asyncio + async def test_spawn_sync_success(self): + entry = _make_entry() + service = MagicMock(spec=AgentService) + service.spawn = AsyncMock(return_value=entry) + + action = _make_action(service) + ret = await action.spawn( + agent_type="worker", task="do something", + label="test", mode="sync", + ) + + assert ret.state != ActionStatusCode.API_ERROR + assert "completed" in _result_text(ret) + assert entry.id in _result_text(ret) + service.spawn.assert_awaited_once_with( + agent_type="worker", task="do something", + label="test", mode="sync", + ) + + @pytest.mark.asyncio + async def test_spawn_sync_failed_agent(self): + entry = _make_entry(status=AgentStatus.FAILED, error="out of memory") + service = MagicMock(spec=AgentService) + service.spawn = AsyncMock(return_value=entry) + + action = _make_action(service) + ret = await action.spawn( + agent_type="worker", task="task", mode="sync", + ) + + assert ret.state == ActionStatusCode.API_ERROR + assert "out of memory" in ret.errmsg + + @pytest.mark.asyncio + async def test_spawn_sync_no_output(self): + entry = _make_entry(result=None) + service = MagicMock(spec=AgentService) + service.spawn = AsyncMock(return_value=entry) + + action = _make_action(service) + ret = await action.spawn( + agent_type="worker", task="task", mode="sync", + ) + + assert "(no output)" in _result_text(ret) + + @pytest.mark.asyncio + async def test_spawn_async_success(self): + entry = _make_entry(status=AgentStatus.RUNNING, result=None) + service = MagicMock(spec=AgentService) + service.spawn = AsyncMock(return_value=entry) + + action = _make_action(service) + ret = await action.spawn( + agent_type="worker", task="background job", mode="async", + ) + + assert ret.state != ActionStatusCode.API_ERROR + text = _result_text(ret) + assert "background" in text.lower() or "spawned" in text.lower() + assert entry.id in text + assert entry.agent_type in text + + @pytest.mark.asyncio + async def test_spawn_value_error(self): + service = MagicMock(spec=AgentService) + service.spawn = AsyncMock(side_effect=ValueError("Unknown agent type")) + + action = _make_action(service) + ret = await action.spawn(agent_type="bad", task="task") + + assert ret.state == ActionStatusCode.API_ERROR + assert "Unknown agent type" in ret.errmsg + + @pytest.mark.asyncio + async def test_spawn_generic_error(self): + service = MagicMock(spec=AgentService) + service.spawn = AsyncMock(side_effect=RuntimeError("connection lost")) + + action = _make_action(service) + ret = await action.spawn(agent_type="worker", task="task") + + assert ret.state == ActionStatusCode.API_ERROR + assert "connection lost" in ret.errmsg + + +# ── list_agents ────────────────────────────────────────────────────── + + +class TestListAgents: + + @pytest.mark.asyncio + async def test_list_with_entries(self): + entries = [ + _make_entry(id="aaa", label="task A", status=AgentStatus.STOPPED), + _make_entry(id="bbb", label="task B", status=AgentStatus.RUNNING), + ] + service = MagicMock(spec=AgentService) + service.list = MagicMock(return_value=entries) + + action = _make_action(service) + ret = await action.list_agents() + + text = _result_text(ret) + assert "task A" in text + assert "task B" in text + assert "aaa" in text + assert "bbb" in text + + @pytest.mark.asyncio + async def test_list_empty(self): + service = MagicMock(spec=AgentService) + service.list = MagicMock(return_value=[]) + + action = _make_action(service) + ret = await action.list_agents() + + assert "No sub-agents" in _result_text(ret) + + @pytest.mark.asyncio + async def test_list_empty_with_filter(self): + service = MagicMock(spec=AgentService) + service.list = MagicMock(return_value=[]) + + action = _make_action(service) + ret = await action.list_agents(status="running") + + text = _result_text(ret) + assert "No sub-agents" in text + assert "running" in text + + @pytest.mark.asyncio + async def test_list_status_icons(self): + entries = [ + _make_entry(id="1", label="p", status=AgentStatus.PENDING), + _make_entry(id="2", label="r", status=AgentStatus.RUNNING), + _make_entry(id="3", label="s", status=AgentStatus.STOPPED), + _make_entry(id="4", label="f", status=AgentStatus.FAILED), + ] + service = MagicMock(spec=AgentService) + service.list = MagicMock(return_value=entries) + + action = _make_action(service) + ret = await action.list_agents() + text = _result_text(ret) + + # Each status should have its icon + for entry in entries: + assert entry.id in text + + +# ── query_agent ────────────────────────────────────────────────────── + + +class TestQueryAgent: + + @pytest.mark.asyncio + async def test_query_found(self): + entry = _make_entry(result="42", error=None) + service = MagicMock(spec=AgentService) + service.get = MagicMock(return_value=entry) + + action = _make_action(service) + ret = await action.query_agent(agent_id=entry.id) + + text = _result_text(ret) + assert entry.id in text + assert entry.agent_type in text + assert "42" in text + + @pytest.mark.asyncio + async def test_query_with_error(self): + entry = _make_entry( + status=AgentStatus.FAILED, result=None, error="kaboom", + ) + service = MagicMock(spec=AgentService) + service.get = MagicMock(return_value=entry) + + action = _make_action(service) + ret = await action.query_agent(agent_id=entry.id) + + text = _result_text(ret) + assert "kaboom" in text + + @pytest.mark.asyncio + async def test_query_not_found(self): + service = MagicMock(spec=AgentService) + service.get = MagicMock(return_value=None) + + action = _make_action(service) + ret = await action.query_agent(agent_id="nonexistent") + + assert ret.state == ActionStatusCode.API_ERROR + assert "not found" in ret.errmsg + + @pytest.mark.asyncio + async def test_query_no_result_no_error(self): + entry = _make_entry(result=None, error=None) + service = MagicMock(spec=AgentService) + service.get = MagicMock(return_value=entry) + + action = _make_action(service) + ret = await action.query_agent(agent_id=entry.id) + + text = _result_text(ret) + assert entry.id in text + assert "Result" not in text + assert "Error" not in text + + +# ── resume_agent ───────────────────────────────────────────────────── + + +class TestResumeAgent: + + @pytest.mark.asyncio + async def test_resume_success(self): + entry = _make_entry(result="continued") + service = MagicMock(spec=AgentService) + service.resume = AsyncMock(return_value=entry) + + action = _make_action(service) + ret = await action.resume_agent(agent_id="abc12345", message="go on") + + text = _result_text(ret) + assert "resumed" in text + assert "continued" in text + service.resume.assert_awaited_once_with("abc12345", "go on") + + @pytest.mark.asyncio + async def test_resume_no_output(self): + entry = _make_entry(result=None) + service = MagicMock(spec=AgentService) + service.resume = AsyncMock(return_value=entry) + + action = _make_action(service) + ret = await action.resume_agent(agent_id="abc12345", message="go") + + assert "(no output)" in _result_text(ret) + + @pytest.mark.asyncio + async def test_resume_failed(self): + entry = _make_entry(status=AgentStatus.FAILED, error="crash") + service = MagicMock(spec=AgentService) + service.resume = AsyncMock(return_value=entry) + + action = _make_action(service) + ret = await action.resume_agent(agent_id="abc12345", message="retry") + + assert ret.state == ActionStatusCode.API_ERROR + assert "crash" in ret.errmsg + + @pytest.mark.asyncio + async def test_resume_value_error(self): + service = MagicMock(spec=AgentService) + service.resume = AsyncMock( + side_effect=ValueError("Agent 'x' not found"), + ) + + action = _make_action(service) + ret = await action.resume_agent(agent_id="x", message="hi") + + assert ret.state == ActionStatusCode.API_ERROR + assert "not found" in ret.errmsg + + @pytest.mark.asyncio + async def test_resume_generic_error(self): + service = MagicMock(spec=AgentService) + service.resume = AsyncMock(side_effect=RuntimeError("timeout")) + + action = _make_action(service) + ret = await action.resume_agent(agent_id="x", message="hi") + + assert ret.state == ActionStatusCode.API_ERROR + assert "timeout" in ret.errmsg + + +# ── stop_agent ─────────────────────────────────────────────────────── + + +class TestStopAgent: + + @pytest.mark.asyncio + async def test_stop_success(self): + service = MagicMock(spec=AgentService) + service.stop = AsyncMock(return_value=True) + + action = _make_action(service) + ret = await action.stop_agent(agent_id="abc12345") + + assert ret.state != ActionStatusCode.API_ERROR + assert "stopped" in _result_text(ret).lower() + + @pytest.mark.asyncio + async def test_stop_not_found(self): + service = MagicMock(spec=AgentService) + service.stop = AsyncMock(return_value=False) + + action = _make_action(service) + ret = await action.stop_agent(agent_id="nonexistent") + + assert ret.state == ActionStatusCode.API_ERROR + assert "not found" in ret.errmsg.lower() or "not running" in ret.errmsg.lower() diff --git a/tests/test_actions/test_task_action.py b/tests/test_actions/test_task_action.py new file mode 100644 index 00000000..fb84d05c --- /dev/null +++ b/tests/test_actions/test_task_action.py @@ -0,0 +1,170 @@ +"""Unit tests for TaskAction (lagent/actions/task.py). + +Bypasses the circular import in lagent packages by importing the +task modules directly before any lagent __init__.py gets triggered. +""" + +import os +import sys +import types + +# --- bypass circular import in lagent.services.__init__.py --- +_here = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _here not in sys.path: + sys.path.insert(0, _here) +if "lagent.services" not in sys.modules: + _pkg = types.ModuleType("lagent.services") + _pkg.__path__ = [os.path.join(_here, "lagent", "services")] + _pkg.__package__ = "lagent.services" + sys.modules["lagent.services"] = _pkg + +import pytest + +from lagent.actions.task import TaskAction, AsyncTaskAction +from lagent.schema import ActionStatusCode +from lagent.services.task import TaskBoard + + +def make_action() -> TaskAction: + board = TaskBoard() + return TaskAction(task_board=board) + + +class TestTaskActionCreate: + def test_create_via_method(self): + action = make_action() + result = action.create(subject="Fix bug", description="Fix it") + assert "Created task #1" in result.result[0]["content"] + + def test_create_with_blocked_by(self): + action = make_action() + action.create(subject="A", description="desc") + result = action.create(subject="B", description="desc", blocked_by="1") + assert "blocked by #1" in result.result[0]["content"] + + def test_create_with_metadata(self): + action = make_action() + result = action.create( + subject="Task", description="desc", + metadata='{"priority": "high"}', + ) + assert result.result is not None + task = action._board.get("1") + assert task.metadata == {"priority": "high"} + + def test_create_with_active_form(self): + action = make_action() + action.create(subject="Run tests", description="desc", active_form="Running tests") + task = action._board.get("1") + assert task.active_form == "Running tests" + + +class TestTaskActionUpdate: + def test_update_status(self): + action = make_action() + action.create(subject="Task", description="desc") + result = action.update(task_id="1", status="in_progress") + assert "in_progress" in result.result[0]["content"] + + def test_update_nonexistent(self): + action = make_action() + result = action.update(task_id="999", status="completed") + assert result.state == ActionStatusCode.API_ERROR + + def test_update_deleted(self): + action = make_action() + action.create(subject="Task", description="desc") + result = action.update(task_id="1", status="deleted") + assert "deleted" in result.result[0]["content"] + assert action._board.get("1") is None + + def test_update_no_fields(self): + action = make_action() + action.create(subject="Task", description="desc") + result = action.update(task_id="1") + assert result.state == ActionStatusCode.ARGS_ERROR + + def test_update_add_blocked_by(self): + action = make_action() + action.create(subject="A", description="desc") + action.create(subject="B", description="desc") + action.update(task_id="2", add_blocked_by="1") + t2 = action._board.get("2") + assert "1" in t2.blocked_by + + def test_update_metadata_merge(self): + action = make_action() + action.create(subject="Task", description="desc", metadata='{"a": 1}') + action.update(task_id="1", metadata='{"b": 2}') + task = action._board.get("1") + assert task.metadata == {"a": 1, "b": 2} + + +class TestTaskActionGet: + def test_get_existing(self): + action = make_action() + action.create(subject="Fix bug", description="Fix the login bug") + result = action.get(task_id="1") + assert "Fix bug" in result.result[0]["content"] + assert "Fix the login bug" in result.result[0]["content"] + + def test_get_nonexistent(self): + action = make_action() + result = action.get(task_id="999") + assert result.state == ActionStatusCode.API_ERROR + + +class TestTaskActionList: + def test_list_all(self): + action = make_action() + action.create(subject="A", description="desc") + action.create(subject="B", description="desc") + result = action.list() + content = result.result[0]["content"] + assert "#1" in content + assert "#2" in content + + def test_list_filtered(self): + action = make_action() + action.create(subject="A", description="desc") + action.create(subject="B", description="desc") + action.update(task_id="1", status="completed") + result = action.list(status="completed") + content = result.result[0]["content"] + assert "#1" in content + assert "#2" not in content + + def test_list_empty(self): + action = make_action() + result = action.list() + assert "No tasks" in result.result[0]["content"] + + def test_list_shows_blockers(self): + action = make_action() + action.create(subject="A", description="desc") + action.create(subject="B", description="desc", blocked_by="1") + result = action.list() + content = result.result[0]["content"] + assert "blocked by #1" in content + + +class TestAsyncTaskAction: + def test_instantiation(self): + board = TaskBoard() + action = AsyncTaskAction(task_board=board) + assert action._board is board + + +class TestToolDescription: + def test_is_toolkit(self): + action = make_action() + assert action.is_toolkit + + def test_has_four_apis(self): + action = make_action() + desc = action.description + api_names = {api["name"] for api in desc["api_list"]} + assert "create" in api_names + assert "update" in api_names + assert "get" in api_names + assert "list" in api_names diff --git a/tests/test_adapters/test_chain_rebuild.py b/tests/test_adapters/test_chain_rebuild.py new file mode 100644 index 00000000..bdb3668d --- /dev/null +++ b/tests/test_adapters/test_chain_rebuild.py @@ -0,0 +1,341 @@ +"""Test rebuild_chains and to_training_samples with complex scenarios. + +Simulates sub-agents, context compression, and mixed cases using +mock proxy records — no real LLM calls needed. + +Run: + python tests/test_adapters/test_chain_rebuild.py + python -m pytest tests/test_adapters/test_chain_rebuild.py -v -s +""" +import asyncio +import sys +from lagent.adapters.proxy import LLMProxyRecorder + + +def make_chain_records(turns, model="claude-opus-4-6", system="default_system", + prefix="main", input_tokens=100, output_tokens=50): + """Build a realistic chain of LLM call records. + + Each turn adds a user message and the previous assistant response + to the messages list, mimicking real multi-turn behavior where + each request carries the full conversation history. + + Args: + turns: List of (user_text, assistant_response) tuples. + model: Model name for all records. + system: System prompt for all records. + prefix: Prefix for timestamps. + + Returns: + List of record dicts. + """ + records = [] + messages = [] + for i, (user_text, response_text) in enumerate(turns): + messages.append({"role": "user", "content": user_text}) + record = { + "timestamp": f"2026-01-01T{prefix}:{i:02d}", + "request": { + "model": model, + "system": system, + "messages": list(messages), # copy + "tools": [{"name": "tool1"}], + }, + "response": { + "content": [{"type": "text", "text": response_text}], + "usage": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, + }, + "model": model, + }, + "path": "/v1/messages", + "method": "POST", + "stream": False, + } + records.append(record) + messages.append({"role": "assistant", "content": response_text}) + return records + + +class TestRebuildChains: + """Test chain rebuilding with various scenarios.""" + + def setup_method(self): + self.proxy = LLMProxyRecorder.__new__(LLMProxyRecorder) + self.proxy._records = {} + + def _inject(self, session_id, records): + self.proxy._records[session_id] = records + + def test_simple_multiturn(self): + """Simple case: 3-turn conversation → 1 chain.""" + print("\n--- Simple multi-turn (3 turns) ---") + records = make_chain_records([ + ("Hello", "Hi there!"), + ("What is 2+2?", "4"), + ("Thanks", "You're welcome!"), + ]) + self._inject("s1", records) + + chains = self.proxy.rebuild_chains("s1") + print(f"Chains: {len(chains)}") + for i, c in enumerate(chains): + print(f" Chain {i+1}: {len(c)} records, msgs={[len(r['messages']) for r in c]}") + assert len(chains) == 1 + assert len(chains[0]) == 3 + + samples = self.proxy.to_training_samples("s1") + print(f"Samples: {len(samples)}, messages in sample: {len(samples[0]['messages'])}") + assert len(samples) == 1 + assert len(samples[0]['messages']) == 7 # system + 3 user + 3 assistant + assert samples[0]['meta']['num_calls'] == 3 + + def test_context_compression(self): + """Context compression: history grows, then resets with summary.""" + print("\n--- Context compression ---") + pre = make_chain_records([ + ("Hello", "Hi!"), + ("Tell me about Python", "Python is great."), + ("More details", "It has many libraries."), + ], prefix="pre") + + # Post-compression: fresh start with summary (no previous response) + post = make_chain_records([ + ("Summary: we discussed Python. Now explain decorators.", "Decorators wrap functions."), + ("Example?", "@decorator syntax."), + ], prefix="post") + + self._inject("s1", pre + post) + + chains = self.proxy.rebuild_chains("s1") + print(f"Chains: {len(chains)}") + for i, c in enumerate(chains): + print(f" Chain {i+1}: {len(c)} records, msgs={[len(r['messages']) for r in c]}") + assert len(chains) == 2 + assert len(chains[0]) == 3 # pre-compression + assert len(chains[1]) == 2 # post-compression + + samples = self.proxy.to_training_samples("s1") + print(f"Samples: {len(samples)}") + assert len(samples) == 2 + + def test_subagent_different_model(self): + """Sub-agent uses a different model → separate chain.""" + print("\n--- Sub-agent (different model) ---") + main1 = make_chain_records([ + ("Fix the bug", "Let me look at the code."), + ("Found it?", "Yes, line 42."), + ], model="opus", prefix="m1") + + sub = make_chain_records([ + ("Search for similar bugs", "Found 3 results."), + ], model="haiku", prefix="sub") + + main2 = make_chain_records([ + ("Apply the fix from the search", "Done, fixed."), + ], model="opus", prefix="m2") + + self._inject("s1", main1 + sub + main2) + + chains = self.proxy.rebuild_chains("s1") + print(f"Chains: {len(chains)}") + for i, c in enumerate(chains): + models = set(r['meta']['model'] for r in c) + print(f" Chain {i+1}: {len(c)} records, model={models}") + assert len(chains) == 3 + + def test_subagent_same_model(self): + """Sub-agent uses same model — detected by response discontinuity.""" + print("\n--- Sub-agent (same model) ---") + main = make_chain_records([ + ("Main task", "Working on it."), + ("Continue", "Need to search first."), + ], model="opus", prefix="main") + + # Sub-agent: fresh messages, previous response NOT in history + sub = make_chain_records([ + ("Sub-task: search code", "Found relevant code."), + ("Details?", "Function foo() at line 10."), + ], model="opus", prefix="sub") + + self._inject("s1", main + sub) + + chains = self.proxy.rebuild_chains("s1") + print(f"Chains: {len(chains)}") + for i, c in enumerate(chains): + print(f" Chain {i+1}: {len(c)} records, msgs={[len(r['messages']) for r in c]}") + assert len(chains) == 2 + assert len(chains[0]) == 2 # main + assert len(chains[1]) == 2 # sub + + def test_compression_plus_subagent(self): + """Complex: main agent compresses, then spawns sub-agent.""" + print("\n--- Compression + Sub-agent (same model) ---") + + main_pre = make_chain_records([ + ("Fix bug", "Looking..."), + ("Status?", "Found issue in auth.py."), + ("Fix it", "Applying patch..."), + ], model="opus", prefix="pre") + + main_post = make_chain_records([ + ("Summary: fixed auth bug. Now add tests.", "Writing tests..."), + ], model="opus", prefix="post") + + sub = make_chain_records([ + ("Find test examples", "Found 5 examples."), + ("Show best one", "test_auth.py is the best."), + ], model="opus", prefix="sub") + + main_resume = make_chain_records([ + ("Write test based on example", "Test written."), + ("Run tests", "All 12 tests passed."), + ], model="opus", prefix="resume") + + self._inject("s1", main_pre + main_post + sub + main_resume) + + chains = self.proxy.rebuild_chains("s1") + print(f"Chains: {len(chains)}") + for i, c in enumerate(chains): + print(f" Chain {i+1}: {len(c)} records, msgs={[len(r['messages']) for r in c]}") + + # main_pre(3) → main_post(1): response "Applying patch..." not in post → break + # main_post(1) → sub(1): response "Writing tests..." not in sub → break + # sub(2) → main_resume(1): response "test_auth.py..." not in resume → break + assert len(chains) == 4 + assert len(chains[0]) == 3 + assert len(chains[1]) == 1 + assert len(chains[2]) == 2 + assert len(chains[3]) == 2 + + def test_single_call(self): + """Edge case: only one LLM call.""" + print("\n--- Single call ---") + records = make_chain_records([("Hello", "Hi")]) + self._inject("s1", records) + + chains = self.proxy.rebuild_chains("s1") + assert len(chains) == 1 + samples = self.proxy.to_training_samples("s1") + assert len(samples) == 1 + assert samples[0]['meta']['num_calls'] == 1 + print(f"Chains: {len(chains)}, Samples: {len(samples)} ✓") + + def test_empty_session(self): + """Edge case: no records.""" + print("\n--- Empty session ---") + chains = self.proxy.rebuild_chains("nonexistent") + assert len(chains) == 0 + samples = self.proxy.to_training_samples("nonexistent") + assert len(samples) == 0 + print(f"Chains: {len(chains)}, Samples: {len(samples)} ✓") + + def test_usage_aggregation(self): + """Verify training sample format: messages with meta + overall meta.""" + print("\n--- Training sample format ---") + records = make_chain_records( + [("Q1", "A1"), ("Q2", "A2"), ("Q3", "A3")], + input_tokens=100, output_tokens=20, + ) + self._inject("s1", records) + + samples = self.proxy.to_training_samples("s1") + assert len(samples) == 1 + s = samples[0] + + # Top-level structure + assert 'messages' in s + assert 'meta' in s + print(f"messages: {len(s['messages'])}") + print(f"meta keys: {sorted(s['meta'].keys())}") + + # Messages: system + 3 user + 3 assistant (including final response) + msgs = s['messages'] + roles = [m['role'] for m in msgs] + print(f"roles: {roles}") + # system, user, assistant(A1), user, assistant(A2), user, assistant(A3) + assert roles[0] == 'system' + assert roles[-1] == 'assistant' + + # Assistant messages should have extra_info + asst_msgs = [m for m in msgs if m['role'] == 'assistant'] + print(f"assistant messages: {len(asst_msgs)}") + for i, m in enumerate(asst_msgs): + has_extra = 'extra_info' in m + print(f" asst[{i}]: content='{m['content'][:30]}', has_extra_info={has_extra}") + + # The last assistant message (from final response) should have extra_info + assert 'extra_info' in asst_msgs[-1] + + # Overall meta + meta = s['meta'] + print(f"num_calls: {meta['num_calls']}") + print(f"total_usage: {meta['total_usage']}") + assert meta['num_calls'] == 3 + assert meta['total_usage']['total_input_tokens'] == 300 + assert meta['total_usage']['total_output_tokens'] == 60 + print("✓") + + def test_openai_response_format(self): + """Chain detection works with OpenAI response format (choices).""" + print("\n--- OpenAI response format ---") + records = [] + messages = [] + for i, (user, asst) in enumerate([("Hi", "Hello!"), ("How?", "Fine!")]): + messages.append({"role": "user", "content": user}) + records.append({ + "timestamp": f"T{i}", + "request": {"model": "gpt-4o", "messages": list(messages)}, + "response": { + "choices": [{"message": {"content": asst}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + }, + "path": "/v1/chat/completions", + "method": "POST", + "stream": False, + }) + messages.append({"role": "assistant", "content": asst}) + + self._inject("s1", records) + chains = self.proxy.rebuild_chains("s1") + print(f"Chains: {len(chains)}") + assert len(chains) == 1 + assert len(chains[0]) == 2 + print("✓") + + +# ── F5 Runner ──────────────────────────────────────────────────── + +async def _run_test(test_cls, method_name): + obj = test_cls() + obj.setup_method() + method = getattr(obj, method_name) + print(f"\n{'='*60}") + print(f" {test_cls.__name__}.{method_name}") + print(f"{'='*60}") + try: + method() + print(f" ✅ PASSED") + except AssertionError as e: + print(f" ❌ FAILED: {e}") + import traceback + traceback.print_exc() + + +async def run_all(): + for name in sorted(dir(TestRebuildChains)): + if name.startswith('test_'): + await _run_test(TestRebuildChains, name) + print(f"\n{'='*60}") + print(" Done!") + print(f"{'='*60}") + + +if __name__ == '__main__': + if len(sys.argv) > 1: + asyncio.run(_run_test(TestRebuildChains, sys.argv[1])) + else: + asyncio.run(run_all()) diff --git a/tests/test_adapters/test_claude_code_integration.py b/tests/test_adapters/test_claude_code_integration.py new file mode 100644 index 00000000..e3f023c6 --- /dev/null +++ b/tests/test_adapters/test_claude_code_integration.py @@ -0,0 +1,496 @@ +"""Integration tests for Claude Code adapters (CLI + SDK + Proxy). + +These tests hit the real Claude Code CLI / SDK, so they require: + - claude CLI on PATH + - ANTHROPIC_AUTH_TOKEN and ANTHROPIC_BASE_URL set + - claude-agent-sdk installed + +Run: + python -m pytest tests/test_adapters/test_claude_code_integration.py -v -s + +Debug (F5 in IDE): + python tests/test_adapters/test_claude_code_integration.py + python tests/test_adapters/test_claude_code_integration.py TestCLIAdapter.test_single_turn +""" +import asyncio +import json +import os +import sys + +import pytest + +from lagent.adapters.claude_code import ClaudeCodeAdapter +from lagent.adapters.claude_code_sdk import ClaudeCodeSDKAdapter +from lagent.adapters.proxy import LLMProxyRecorder +from lagent.actions.external_agent import ExternalAgentAction +from lagent.schema import AgentMessage + +WORKDIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Skip all tests if env not configured (only for pytest) +pytestmark = pytest.mark.skipif( + not os.environ.get('ANTHROPIC_AUTH_TOKEN'), + reason="ANTHROPIC_AUTH_TOKEN not set", +) + + +# ── Helpers ────────────────────────────────────────────────────── + +def make_proxy(): + return LLMProxyRecorder( + real_api_key=os.environ.get('ANTHROPIC_AUTH_TOKEN', ''), + real_base_url=os.environ.get('ANTHROPIC_BASE_URL', ''), + ) + + +def log(msg): + print(f" → {msg}") + + +# ── CLI Adapter Tests ──────────────────────────────────────────── + +class TestCLIAdapter: + + @pytest.mark.asyncio + async def test_single_turn(self): + """CLI adapter returns correct result for a simple question.""" + agent = ClaudeCodeAdapter( + max_turns=3, timeout=60, working_dir=WORKDIR, + ) + result = await agent("What is 3+4? Answer with just the number.") + log(f"result: {result.content}") + assert isinstance(result, AgentMessage) + assert '7' in result.content + + @pytest.mark.asyncio + async def test_single_turn_with_proxy(self): + """CLI adapter + Proxy captures LLM trace in state_dict.""" + proxy = make_proxy() + await proxy.start() + try: + agent = ClaudeCodeAdapter( + max_turns=3, timeout=60, proxy=proxy, working_dir=WORKDIR, + ) + result = await agent("What is 5+5? Answer with just the number.") + log(f"result: {result.content}") + + state = agent.state_dict() + log(f"state keys: {list(state.keys())}") + assert 'memory' in state + assert 'llm_trace' in state + + trace = state['llm_trace'] + log(f"llm_trace records: {len(trace)}") + assert len(trace) >= 1 + + # Verify trace structure + rec = trace[0] + assert 'timestamp' in rec + assert 'request' in rec + assert 'response' in rec + assert rec['request'].get('model') is not None + assert rec['request'].get('messages') is not None + assert rec['response'].get('content') is not None + assert rec['response'].get('usage') is not None + log(f"usage: {json.dumps(rec['response']['usage'])}") + finally: + await proxy.stop() + + @pytest.mark.asyncio + async def test_multiturn_continue(self): + """CLI adapter uses --continue for real multi-turn.""" + agent = ClaudeCodeAdapter( + max_turns=3, timeout=60, working_dir=WORKDIR, + ) + r1 = await agent("Remember the number 99. Just say 'OK'.") + log(f"turn 1: {r1.content}") + assert agent._call_count == 1 + + r2 = await agent("What number did I say? Just the number.") + log(f"turn 2: {r2.content}") + assert agent._call_count == 2 + assert '99' in r2.content + + @pytest.mark.asyncio + async def test_multiturn_with_proxy_trace(self): + """CLI multi-turn + Proxy shows message history growing.""" + proxy = make_proxy() + await proxy.start() + try: + agent = ClaudeCodeAdapter( + max_turns=3, timeout=60, proxy=proxy, working_dir=WORKDIR, + ) + r1 = await agent("Remember: color=red. Just say 'OK'.") + log(f"turn 1: {r1.content}") + r2 = await agent("What color? Just the color.") + log(f"turn 2: {r2.content}") + + trace = agent.state_dict()['llm_trace'] + log(f"llm_trace: {len(trace)} calls") + assert len(trace) >= 2 + + # Second call should have more messages (history) + msgs_first = len(trace[0]['request']['messages']) + msgs_last = len(trace[-1]['request']['messages']) + log(f"messages: first={msgs_first} last={msgs_last}") + assert msgs_last > msgs_first + finally: + await proxy.stop() + + @pytest.mark.asyncio + async def test_memory_accumulates(self): + """Memory stores input+output for each turn.""" + agent = ClaudeCodeAdapter( + max_turns=3, timeout=60, working_dir=WORKDIR, + ) + r1 = await agent("Say 'hello'.") + log(f"turn 1: {r1.content}") + r2 = await agent("Say 'world'.") + log(f"turn 2: {r2.content}") + + memory = agent.memory.get_memory() + log(f"memory: {len(memory)} entries") + assert len(memory) == 4 # 2 turns × (input + output) + assert memory[0].sender == 'user' + assert memory[1].sender == 'claude-code' + assert memory[2].sender == 'user' + assert memory[3].sender == 'claude-code' + + +# ── SDK Adapter Tests ──────────────────────────────────────────── + +class TestSDKAdapter: + + @pytest.mark.asyncio + async def test_single_turn(self): + """SDK adapter returns correct result.""" + agent = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, cwd=WORKDIR, + ) + result = await agent("What is 6*7? Answer with just the number.") + log(f"result: {result.content}") + assert isinstance(result, AgentMessage) + assert '42' in result.content + + @pytest.mark.asyncio + async def test_sdk_trace_structure(self): + """SDK trace contains structured events.""" + agent = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, cwd=WORKDIR, + ) + result = await agent("Say 'test'. Just the word.") + log(f"result: {result.content}") + + state = agent.state_dict() + assert 'sdk_trace' in state + assert 'claude_session_id' in state + log(f"session_id: {state['claude_session_id']}") + + trace = state['sdk_trace'] + types = [e['type'] for e in trace] + log(f"event types: {types}") + assert 'AssistantMessage' in types + assert 'ResultMessage' in types + + # Check ResultMessage has cost info + result_events = [e for e in trace if e['type'] == 'ResultMessage'] + assert len(result_events) >= 1 + log(f"cost: ${result_events[0].get('total_cost_usd')}") + assert result_events[0].get('total_cost_usd') is not None + assert result_events[0].get('session_id') is not None + + @pytest.mark.asyncio + async def test_multiturn_resume(self): + """SDK adapter uses session_id for real multi-turn.""" + agent = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, cwd=WORKDIR, + ) + r1 = await agent("Remember the word 'banana'. Just say 'OK'.") + log(f"turn 1: {r1.content}") + + # session_id should be captured (from ResultMessage or AssistantMessage) + assert agent._session_id is not None, ( + f"session_id not captured. sdk_trace types: " + f"{[e.get('type') for e in agent._sdk_trace]}" + ) + session_id = agent._session_id + log(f"session_id: {session_id}") + + r2 = await agent("What word did I say? Just the word.") + log(f"turn 2: {r2.content}") + assert 'banana' in r2.content.lower() + # Same session + assert agent._session_id == session_id + + @pytest.mark.asyncio + async def test_sdk_with_proxy(self): + """SDK + Proxy gives both sdk_trace and llm_trace.""" + proxy = make_proxy() + await proxy.start() + try: + agent = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, proxy=proxy, cwd=WORKDIR, + ) + result = await agent("What is 2+2? Just the number.") + log(f"result: {result.content}") + + state = agent.state_dict() + log(f"state keys: {list(state.keys())}") + assert 'memory' in state + assert 'sdk_trace' in state + assert 'llm_trace' in state + assert 'claude_session_id' in state + + # Both traces populated + log(f"sdk_trace: {len(state['sdk_trace'])} events") + log(f"llm_trace: {len(state['llm_trace'])} calls") + assert len(state['sdk_trace']) >= 1 + assert len(state['llm_trace']) >= 1 + + # LLM trace has full request/response + rec = state['llm_trace'][0] + assert 'messages' in rec['request'] + assert 'content' in rec['response'] + assert 'usage' in rec['response'] + finally: + await proxy.stop() + + @pytest.mark.asyncio + async def test_sdk_proxy_multiturn(self): + """SDK + Proxy multi-turn: both traces grow, messages accumulate.""" + proxy = make_proxy() + await proxy.start() + try: + agent = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, proxy=proxy, cwd=WORKDIR, + ) + r1 = await agent("Remember: animal=cat. Just say 'OK'.") + log(f"turn 1: {r1.content}") + r2 = await agent("What animal? Just the animal.") + log(f"turn 2: {r2.content}") + + state = agent.state_dict() + sdk_trace = state['sdk_trace'] + llm_trace = state['llm_trace'] + + # SDK: 2 turns → at least 2 ResultMessages + result_msgs = [e for e in sdk_trace if e['type'] == 'ResultMessage'] + log(f"ResultMessages: {len(result_msgs)}") + assert len(result_msgs) >= 2 + assert result_msgs[0]['call_index'] == 0 + assert result_msgs[1]['call_index'] == 1 + + # Proxy: messages grow across turns + log(f"llm_trace calls: {len(llm_trace)}") + assert len(llm_trace) >= 2 + msgs_first = len(llm_trace[0]['request']['messages']) + msgs_last = len(llm_trace[-1]['request']['messages']) + log(f"messages: first={msgs_first} last={msgs_last}") + assert msgs_last > msgs_first + + # Training samples: multi-turn aggregated into one sample + samples = proxy.to_training_samples(agent.session_id) + log(f"training samples: {len(samples)}") + assert len(samples) >= 1 + + sample = samples[0] + log(f"sample messages: {len(sample['messages'])}") + log(f"sample response blocks: {len(sample['response'])}") + log(f"sample num_calls: {sample['meta']['num_calls']}") + log(f"sample total_usage: {sample['meta']['total_usage']}") + + # The sample's messages should be the full conversation + assert len(sample['messages']) >= msgs_last + # Should have aggregated usage + assert sample['meta']['total_usage']['total_input_tokens'] > 0 + assert sample['meta']['total_usage']['total_output_tokens'] > 0 + assert sample['meta']['num_calls'] >= 2 + finally: + await proxy.stop() + +class TestExternalAgentAction: + + @pytest.mark.asyncio + async def test_action_with_cli(self): + """ExternalAgentAction wrapping CLI adapter.""" + adapter = ClaudeCodeAdapter( + max_turns=3, timeout=60, working_dir=WORKDIR, + ) + action = ExternalAgentAction(adapters={"claude": adapter}) + + result = await action( + '{"agent_name": "claude", "task": "What is 8+9? Just the number."}', + 'run_agent', + ) + log(f"result: {result.result[0]['content'][:80] if result.result else result.errmsg}") + assert result.state == 0 + assert '17' in result.result[0]['content'] + + @pytest.mark.asyncio + async def test_action_with_sdk(self): + """ExternalAgentAction wrapping SDK adapter.""" + adapter = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, cwd=WORKDIR, + ) + action = ExternalAgentAction(adapters={"claude-sdk": adapter}) + + result = await action( + '{"agent_name": "claude-sdk", "task": "What is 3*5? Just the number."}', + 'run_agent', + ) + log(f"result: {result.result[0]['content'][:80] if result.result else result.errmsg}") + assert result.state == 0 + assert '15' in result.result[0]['content'] + + @pytest.mark.asyncio + async def test_action_list_agents(self): + """list_agents returns registered adapters.""" + cli = ClaudeCodeAdapter(max_turns=3, timeout=60, working_dir=WORKDIR) + sdk = ClaudeCodeSDKAdapter(max_turns=3, timeout=60, cwd=WORKDIR) + action = ExternalAgentAction(adapters={"cli": cli, "sdk": sdk}) + + result = await action('{}', 'list_agents') + log(f"agents: {result.result[0]['content']}") + assert result.state == 0 + assert 'cli' in result.result[0]['content'] + assert 'sdk' in result.result[0]['content'] + + +# ── State Dict Comparison ──────────────────────────────────────── + +class TestStateDictComparison: + + @pytest.mark.asyncio + async def test_cli_and_sdk_memory_structure_matches(self): + """Both adapters produce same memory structure.""" + cli_agent = ClaudeCodeAdapter( + max_turns=3, timeout=60, working_dir=WORKDIR, + ) + sdk_agent = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, cwd=WORKDIR, + ) + + r1 = await cli_agent("Say 'ping'.") + log(f"CLI: {r1.content}") + r2 = await sdk_agent("Say 'ping'.") + log(f"SDK: {r2.content}") + + cli_state = cli_agent.state_dict() + sdk_state = sdk_agent.state_dict() + + # Both have memory with same structure + assert 'memory' in cli_state + assert 'memory' in sdk_state + log(f"CLI memory: {len(cli_state['memory'])} entries") + log(f"SDK memory: {len(sdk_state['memory'])} entries") + assert len(cli_state['memory']) == 2 # input + output + assert len(sdk_state['memory']) == 2 + + # Both have sender fields + assert cli_state['memory'][0]['sender'] == 'user' + assert sdk_state['memory'][0]['sender'] == 'user' + + @pytest.mark.asyncio + async def test_proxy_trace_structure_same_for_both(self): + """When both use Proxy, llm_trace has same structure.""" + proxy_cli = make_proxy() + proxy_sdk = make_proxy() + await proxy_cli.start() + await proxy_sdk.start() + try: + cli_agent = ClaudeCodeAdapter( + max_turns=3, timeout=60, proxy=proxy_cli, working_dir=WORKDIR, + ) + sdk_agent = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, proxy=proxy_sdk, cwd=WORKDIR, + ) + + r1 = await cli_agent("What is 1+1? Just the number.") + log(f"CLI: {r1.content}") + r2 = await sdk_agent("What is 1+1? Just the number.") + log(f"SDK: {r2.content}") + + cli_trace = cli_agent.state_dict()['llm_trace'] + sdk_trace = sdk_agent.state_dict()['llm_trace'] + + # Both have at least 1 record + log(f"CLI trace: {len(cli_trace)} records") + log(f"SDK trace: {len(sdk_trace)} records") + assert len(cli_trace) >= 1 + assert len(sdk_trace) >= 1 + + # Same keys in each record + cli_keys = set(cli_trace[0].keys()) + sdk_keys = set(sdk_trace[0].keys()) + log(f"CLI keys: {sorted(cli_keys)}") + log(f"SDK keys: {sorted(sdk_keys)}") + assert cli_keys == sdk_keys + + # Both have request.messages and response.content + for label, trace in [("CLI", cli_trace), ("SDK", sdk_trace)]: + assert 'messages' in trace[0]['request'] + assert 'content' in trace[0]['response'] + assert 'usage' in trace[0]['response'] + log(f"{label} usage: {json.dumps(trace[0]['response']['usage'])}") + finally: + await proxy_cli.stop() + await proxy_sdk.stop() + + +# ── F5 Debug Runner ────────────────────────────────────────────── + +async def _run_test(test_cls, method_name): + """Run a single test method with output.""" + obj = test_cls() + method = getattr(obj, method_name) + print(f"\n{'='*60}") + print(f" {test_cls.__name__}.{method_name}") + print(f" {method.__doc__}") + print(f"{'='*60}") + try: + await method() + print(f" ✅ PASSED") + except Exception as e: + print(f" ❌ FAILED: {e}") + import traceback + traceback.print_exc() + + +async def run_all(): + """Run all tests sequentially for F5 debugging.""" + test_classes = [ + TestCLIAdapter, + TestSDKAdapter, + TestExternalAgentAction, + TestStateDictComparison, + ] + for cls in test_classes: + for name in sorted(dir(cls)): + if name.startswith('test_'): + await _run_test(cls, name) + + print(f"\n{'='*60}") + print(" All tests completed!") + print(f"{'='*60}") + + +async def run_one(spec: str): + """Run a single test: 'TestCLIAdapter.test_single_turn'.""" + cls_name, method_name = spec.split('.') + cls = { + 'TestCLIAdapter': TestCLIAdapter, + 'TestSDKAdapter': TestSDKAdapter, + 'TestExternalAgentAction': TestExternalAgentAction, + 'TestStateDictComparison': TestStateDictComparison, + }[cls_name] + await _run_test(cls, method_name) + + +if __name__ == '__main__': + # if len(sys.argv) > 1: + # # python test_claude_code_integration.py TestCLIAdapter.test_single_turn + # asyncio.run(run_one(sys.argv[1])) + # else: + # # python test_claude_code_integration.py → run all + # asyncio.run(run_all()) + asyncio.run(run_one('TestSDKAdapter')) \ No newline at end of file diff --git a/tests/test_adapters/test_external_agents.py b/tests/test_adapters/test_external_agents.py new file mode 100644 index 00000000..efd197b9 --- /dev/null +++ b/tests/test_adapters/test_external_agents.py @@ -0,0 +1,207 @@ +"""Tests for external agent adapters.""" +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from lagent.adapters.base import AsyncExternalAgent, BaseExternalAgent +from lagent.adapters.cli_adapter import CLIAgentAdapter +from lagent.adapters.sdk_adapter import SDKAgentAdapter +from lagent.actions.external_agent import ExternalAgentAction +from lagent.schema import AgentMessage + + +# ── Base adapter tests ────────────────────────────────────────── + + +class ConcreteExternalAgent(AsyncExternalAgent): + """Minimal concrete implementation for testing.""" + + def setup(self): + pass + + async def run_external_async(self, task, **kwargs): + return f"result: {task}" + + +class TestBaseExternalAgent: + + def test_init_assigns_session_id(self): + agent = ConcreteExternalAgent(name="test") + assert agent.session_id + assert len(agent.session_id) == 8 + + def test_init_no_llm(self): + agent = ConcreteExternalAgent(name="test") + assert agent.llm is None + + @pytest.mark.asyncio + async def test_forward_returns_agent_message(self): + agent = ConcreteExternalAgent(name="test-agent") + result = await agent("hello world") + assert isinstance(result, AgentMessage) + assert result.content == "result: hello world" + assert result.sender == "test-agent" + + @pytest.mark.asyncio + async def test_forward_stores_in_memory(self): + agent = ConcreteExternalAgent(name="test") + await agent("task1") + msgs = agent.memory.get_memory() + assert len(msgs) == 2 # input + output + + def test_state_dict_without_proxy(self): + agent = ConcreteExternalAgent(name="test") + state = agent.state_dict() + assert 'memory' in state + assert 'llm_trace' not in state + + def test_state_dict_with_proxy(self): + mock_proxy = MagicMock() + mock_proxy.get_records.return_value = [{"test": "record"}] + agent = ConcreteExternalAgent(name="test", proxy=mock_proxy) + state = agent.state_dict() + assert 'llm_trace' in state + assert state['llm_trace'] == [{"test": "record"}] + + def test_build_env_without_proxy(self): + agent = ConcreteExternalAgent( + name="test", + env_vars={"MY_VAR": "123"}, + ) + env = agent._build_env() + assert env["MY_VAR"] == "123" + assert "OPENAI_BASE_URL" not in env or "sk-proxy" not in env.get("OPENAI_API_KEY", "") + + def test_build_env_with_proxy(self): + mock_proxy = MagicMock() + mock_proxy.url = "http://127.0.0.1:9999" + agent = ConcreteExternalAgent(name="test", proxy=mock_proxy) + env = agent._build_env() + assert env["OPENAI_BASE_URL"] == "http://127.0.0.1:9999" + assert env["OPENAI_API_KEY"].startswith("sk-proxy-") + assert env["ANTHROPIC_BASE_URL"] == "http://127.0.0.1:9999" + + +# ── CLI adapter tests ─────────────────────────────────────────── + + +class TestCLIAgentAdapter: + + def test_setup_finds_echo(self): + adapter = CLIAgentAdapter( + name="test-echo", + command_template="echo '{task}'", + ) + adapter.setup() # should not raise + + def test_setup_raises_for_missing_binary(self): + adapter = CLIAgentAdapter( + name="test-missing", + command_template="nonexistent_binary_xyz '{task}'", + ) + with pytest.raises(RuntimeError, match="not found on PATH"): + adapter.setup() + + @pytest.mark.asyncio + async def test_run_echo(self): + adapter = CLIAgentAdapter( + name="echo-agent", + command_template="echo '{task}'", + timeout=5, + ) + result = await adapter("hello") + assert isinstance(result, AgentMessage) + assert "hello" in result.content + + @pytest.mark.asyncio + async def test_timeout(self): + adapter = CLIAgentAdapter( + name="slow-agent", + command_template="sleep 10 && echo '{task}'", + timeout=0.5, + ) + result = await adapter("test") + assert "failed" in result.content.lower() or "timed out" in result.content.lower() + + @pytest.mark.asyncio + async def test_custom_parse_output(self): + def my_parser(stdout, stderr): + return f"PARSED: {stdout.strip()}" + + adapter = CLIAgentAdapter( + name="parse-test", + command_template="echo '{task}'", + parse_output=my_parser, + ) + result = await adapter("data") + assert "PARSED:" in result.content + + +# ── SDK adapter tests ─────────────────────────────────────────── + + +class MockSDKAdapter(SDKAgentAdapter): + """Mock SDK adapter for testing.""" + + def create_sdk_agent(self, config): + return {"model": config.get("model", "test")} + + def invoke_sdk_agent(self, agent, task, **kwargs): + return f"SDK({agent['model']}): {task}" + + +class TestSDKAgentAdapter: + + @pytest.mark.asyncio + async def test_basic_invoke(self): + adapter = MockSDKAdapter( + name="mock-sdk", + sdk_config={"model": "gpt-4"}, + ) + result = await adapter("test task") + assert isinstance(result, AgentMessage) + assert "SDK(gpt-4): test task" in result.content + + def test_setup_with_invalid_module(self): + adapter = MockSDKAdapter( + name="bad-sdk", + sdk_module="nonexistent_module_xyz", + ) + with pytest.raises(RuntimeError, match="not importable"): + adapter.setup() + + +# ── ExternalAgentAction tests ─────────────────────────────────── + + +class TestExternalAgentAction: + + @pytest.mark.asyncio + async def test_run_agent_success(self): + adapter = ConcreteExternalAgent(name="test-ext") + action = ExternalAgentAction(adapters={"test-ext": adapter}) + result = await action('{"agent_name": "test-ext", "task": "do stuff"}', 'run_agent') + assert result.state == 0 # SUCCESS + + @pytest.mark.asyncio + async def test_run_agent_unknown(self): + action = ExternalAgentAction(adapters={}) + result = await action('{"agent_name": "nope", "task": "x"}', 'run_agent') + assert result.errmsg + assert "Unknown agent" in result.errmsg + + @pytest.mark.asyncio + async def test_list_agents(self): + adapter = ConcreteExternalAgent(name="a1", description="Agent One") + action = ExternalAgentAction(adapters={"a1": adapter}) + result = await action('{}', 'list_agents') + assert result.result + assert "a1" in result.result[0]["content"] + + @pytest.mark.asyncio + async def test_list_agents_empty(self): + action = ExternalAgentAction(adapters={}) + result = await action('{}', 'list_agents') + assert "No external agents" in result.result[0]["content"] diff --git a/tests/test_adapters/test_openai_agents.py b/tests/test_adapters/test_openai_agents.py new file mode 100644 index 00000000..2924acbe --- /dev/null +++ b/tests/test_adapters/test_openai_agents.py @@ -0,0 +1,185 @@ +"""Test OpenAI Agents adapter — validates the SDKAgentAdapter pattern +works with a real framework beyond Claude Code. + +Run: + python tests/test_adapters/test_openai_agents.py + python -m pytest tests/test_adapters/test_openai_agents.py -v -s + +Note: requires OPENAI_API_KEY or a compatible API endpoint. +""" +import asyncio +import os +import sys + +import pytest + +from lagent.adapters.openai_agents import OpenAIAgentsAdapter +from lagent.adapters.proxy import LLMProxyRecorder +from lagent.schema import AgentMessage + +WORKDIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Use Anthropic endpoint as OpenAI-compatible if no OpenAI key +ANTHROPIC_BASE = os.environ.get('ANTHROPIC_BASE_URL', '') +ANTHROPIC_KEY = os.environ.get('ANTHROPIC_AUTH_TOKEN', '') + +pytestmark = pytest.mark.skipif( + not os.environ.get('OPENAI_API_KEY') and not ANTHROPIC_KEY, + reason="No API key available", +) + + +def log(msg): + print(f" → {msg}") + + +class TestOpenAIAgentsAdapter: + + @pytest.mark.asyncio + async def test_single_turn(self): + """Single turn with OpenAI Agents SDK.""" + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='You are a math tutor. Answer with just the number.', + max_turns=3, + timeout=60, + ) + result = await agent("What is 7+8?") + log(f"result: {result.content}") + assert isinstance(result, AgentMessage) + assert '15' in result.content + + @pytest.mark.asyncio + async def test_multiturn(self): + """Real multi-turn via RunResult chaining.""" + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='You are a helpful assistant. Be very brief.', + max_turns=3, + timeout=60, + ) + r1 = await agent("Remember: my lucky number is 77. Just say OK.") + log(f"turn 1: {r1.content}") + + r2 = await agent("What is my lucky number? Just the number.") + log(f"turn 2: {r2.content}") + assert '77' in r2.content + + @pytest.mark.asyncio + async def test_with_proxy(self): + """Proxy captures LLM trace from OpenAI Agents.""" + proxy = LLMProxyRecorder( + real_api_key=os.environ.get('OPENAI_API_KEY', ''), + real_base_url='https://api.openai.com', + ) + await proxy.start() + try: + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='Answer briefly.', + max_turns=3, + timeout=60, + proxy=proxy, + ) + result = await agent("What is 3*3? Just the number.") + log(f"result: {result.content}") + + state = agent.state_dict() + log(f"state keys: {list(state.keys())}") + trace = state.get('llm_trace', []) + log(f"llm_trace: {len(trace)} records") + + assert len(trace) >= 1 + rec = trace[0] + assert 'messages' in rec['request'] + assert rec['response'] is not None + finally: + await proxy.stop() + + @pytest.mark.asyncio + async def test_multiturn_with_proxy(self): + """Multi-turn + Proxy: messages grow, training sample correct.""" + proxy = LLMProxyRecorder( + real_api_key=os.environ.get('OPENAI_API_KEY', ''), + real_base_url='https://api.openai.com', + ) + await proxy.start() + try: + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='Be very brief.', + max_turns=3, + timeout=60, + proxy=proxy, + ) + r1 = await agent("Remember: color=green. Just say OK.") + log(f"turn 1: {r1.content}") + r2 = await agent("What color? Just the word.") + log(f"turn 2: {r2.content}") + + trace = proxy.get_records(agent.session_id) + log(f"total LLM calls: {len(trace)}") + for i, rec in enumerate(trace): + msgs = rec['request'].get('messages', []) + log(f" call {i+1}: {len(msgs)} messages") + + # Messages should grow + if len(trace) >= 2: + msgs_first = len(trace[0]['request']['messages']) + msgs_last = len(trace[-1]['request']['messages']) + log(f"messages grew: {msgs_first} → {msgs_last}") + assert msgs_last > msgs_first + + # Training samples + samples = proxy.to_training_samples(agent.session_id) + log(f"training samples: {len(samples)}") + if samples: + log(f"sample[0]: {len(samples[0]['messages'])} msgs, {samples[0]['num_calls']} calls") + finally: + await proxy.stop() + + @pytest.mark.asyncio + async def test_state_dict_structure(self): + """state_dict has memory + trace.""" + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + max_turns=3, timeout=60, + ) + await agent("Say hello.") + + state = agent.state_dict() + assert 'memory' in state + assert len(state['memory']) == 2 # input + output + log(f"memory: {len(state['memory'])} entries") + log(f"sender[0]: {state['memory'][0]['sender']}") + log(f"sender[1]: {state['memory'][1]['sender']}") + + +# ── F5 Runner ──────────────────────────────────────────────────── + +async def _run_test(cls, name): + obj = cls() + print(f"\n{'='*60}") + print(f" {cls.__name__}.{name}") + print(f"{'='*60}") + try: + await getattr(obj, name)() + print(f" ✅ PASSED") + except Exception as e: + print(f" ❌ FAILED: {e}") + import traceback + traceback.print_exc() + + +async def run_all(): + for name in sorted(dir(TestOpenAIAgentsAdapter)): + if name.startswith('test_'): + await _run_test(TestOpenAIAgentsAdapter, name) + print(f"\n{'='*60}\n Done!\n{'='*60}") + + +if __name__ == '__main__': + if len(sys.argv) > 1: + asyncio.run(_run_test(TestOpenAIAgentsAdapter, sys.argv[1])) + else: + asyncio.run(run_all()) diff --git a/tests/test_adapters/test_openai_agents_live.py b/tests/test_adapters/test_openai_agents_live.py new file mode 100644 index 00000000..129c0f92 --- /dev/null +++ b/tests/test_adapters/test_openai_agents_live.py @@ -0,0 +1,171 @@ +"""Test OpenAI Agents adapter with InternClaw's API endpoint.""" +import asyncio +import json +import os +import sys + +from lagent.adapters.openai_agents import OpenAIAgentsAdapter +from lagent.adapters.proxy import LLMProxyRecorder + +API_KEY = ' ' +BASE_URL = 'http://35.220.164.252:3888/v1' +HTTP_PROXY = 'http://100.100.72.89:8899' + + +def log(msg): + print(f" → {msg}") + + +async def test_single_turn(): + print("\n" + "=" * 60) + print(" OpenAI Agents: Single Turn") + print("=" * 60) + + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='Answer with just the number, nothing else.', + max_turns=3, + timeout=30, + api_key=API_KEY, + base_url=BASE_URL, + http_proxy=HTTP_PROXY, + ) + + result = await agent("What is 7+8?") + log(f"result: {result.content}") + log(f"sender: {result.sender}") + assert '15' in result.content + print(" ✅ PASSED") + + +async def test_multiturn(): + print("\n" + "=" * 60) + print(" OpenAI Agents: Multi-Turn") + print("=" * 60) + + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='You are helpful. Be very brief.', + max_turns=3, + timeout=30, + api_key=API_KEY, + base_url=BASE_URL, + http_proxy=HTTP_PROXY, + ) + + r1 = await agent("Remember: my pet's name is Muffin. Just say OK.") + log(f"turn 1: {r1.content}") + + r2 = await agent("What is my pet's name? Just the name.") + log(f"turn 2: {r2.content}") + assert 'muffin' in r2.content.lower() + + log(f"memory: {len(agent.memory.get_memory())} entries") + print(" ✅ PASSED") + + +async def test_with_proxy(): + print("\n" + "=" * 60) + print(" OpenAI Agents: With LLM Proxy") + print("=" * 60) + + proxy = LLMProxyRecorder( + real_api_key=API_KEY, + real_base_url=BASE_URL, + ) + await proxy.start() + log(f"proxy on: {proxy.url}") + + try: + agent = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='Be brief.', + max_turns=3, + timeout=30, + api_key=API_KEY, + base_url=BASE_URL, + http_proxy=HTTP_PROXY, + proxy=proxy, + ) + + r1 = await agent("Remember: city=Tokyo. Just say OK.") + log(f"turn 1: {r1.content}") + r2 = await agent("What city? Just the name.") + log(f"turn 2: {r2.content}") + + state = agent.state_dict() + log(f"state keys: {list(state.keys())}") + + trace = state.get('llm_trace', []) + log(f"llm_trace: {len(trace)} LLM calls") + for i, rec in enumerate(trace): + msgs = rec['request'].get('messages', []) + resp = rec.get('response') or {} + usage = resp.get('usage', {}) + log(f" call {i+1}: {len(msgs)} msgs, " + f"in={usage.get('prompt_tokens', '?')} " + f"out={usage.get('completion_tokens', '?')}") + + # Training samples + samples = proxy.to_training_samples(agent.session_id) + log(f"training samples: {len(samples)}") + if samples: + s = samples[0] + log(f" sample: {len(s['messages'])} msgs, {s['num_calls']} calls, model={s['model']}") + + log(f"memory: {len(agent.memory.get_memory())} entries") + print(" ✅ PASSED") + finally: + await proxy.stop() + + +async def test_state_dict_comparison_with_claude(): + print("\n" + "=" * 60) + print(" Comparison: OpenAI Agents vs Claude Code state_dict") + print("=" * 60) + + from lagent.adapters.claude_code_sdk import ClaudeCodeSDKAdapter + + oai = OpenAIAgentsAdapter( + model='gpt-4o-mini', + instructions='Be brief.', + max_turns=3, timeout=30, + api_key=API_KEY, base_url=BASE_URL, http_proxy=HTTP_PROXY, + ) + claude = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, + cwd=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + ) + + await oai("Say 'hello'.") + await claude("Say 'hello'.") + + oai_state = oai.state_dict() + claude_state = claude.state_dict() + + log(f"OpenAI keys: {sorted(oai_state.keys())}") + log(f"Claude keys: {sorted(claude_state.keys())}") + + # Both have memory with same structure + log(f"OpenAI memory: {len(oai_state['memory'])} entries") + log(f"Claude memory: {len(claude_state['memory'])} entries") + assert len(oai_state['memory']) == 2 + assert len(claude_state['memory']) == 2 + + log(f"OpenAI sender: {oai_state['memory'][1]['sender']}") + log(f"Claude sender: {claude_state['memory'][1]['sender']}") + print(" ✅ PASSED") + + +async def main(): + await test_single_turn() + await test_multiturn() + await test_with_proxy() + await test_state_dict_comparison_with_claude() + print(f"\n{'='*60}") + print(" All OpenAI Agents tests passed!") + print(f"{'='*60}") + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/tests/test_adapters/test_openai_chat_live.py b/tests/test_adapters/test_openai_chat_live.py new file mode 100644 index 00000000..41c8e667 --- /dev/null +++ b/tests/test_adapters/test_openai_chat_live.py @@ -0,0 +1,179 @@ +"""Test OpenAI Chat adapter with InternClaw's API endpoint. + +Run: + python tests/test_adapters/test_openai_chat_live.py +""" +import asyncio +import json +import os +import sys + +from lagent.adapters.openai_chat import OpenAIChatAdapter +from lagent.adapters.proxy import LLMProxyRecorder + +API_KEY = ' ' +BASE_URL = 'http://35.220.164.252:3888/v1' +HTTP_PROXY = 'http://100.100.72.89:8899' + + +def log(msg): + print(f" → {msg}") + + +async def test_single_turn(): + print("\n" + "=" * 60) + print(" OpenAI Chat: Single Turn") + print("=" * 60) + + agent = OpenAIChatAdapter( + model='gpt-4o-mini', + api_key=API_KEY, base_url=BASE_URL, http_proxy=HTTP_PROXY, + system_prompt='Answer with just the number, nothing else.', + timeout=30, + ) + result = await agent("What is 7+8?") + log(f"result: {result.content}") + assert '15' in result.content + print(" ✅ PASSED") + + +async def test_multiturn(): + print("\n" + "=" * 60) + print(" OpenAI Chat: Multi-Turn (real)") + print("=" * 60) + + agent = OpenAIChatAdapter( + model='gpt-4o-mini', + api_key=API_KEY, base_url=BASE_URL, http_proxy=HTTP_PROXY, + system_prompt='You are helpful. Be very brief.', + timeout=30, + ) + + r1 = await agent("Remember: my pet is called Muffin. Just say OK.") + log(f"turn 1: {r1.content}") + + r2 = await agent("What is my pet's name? Just the name.") + log(f"turn 2: {r2.content}") + assert 'muffin' in r2.content.lower() + + r3 = await agent("Say it in uppercase.") + log(f"turn 3: {r3.content}") + assert 'MUFFIN' in r3.content.upper() + + log(f"memory: {len(agent.memory.get_memory())} entries") + log(f"internal messages: {len(agent._messages)}") + print(" ✅ PASSED") + + +async def test_with_proxy(): + print("\n" + "=" * 60) + print(" OpenAI Chat: With LLM Proxy") + print("=" * 60) + + proxy = LLMProxyRecorder( + real_api_key=API_KEY, + real_base_url=BASE_URL, + http_proxy=HTTP_PROXY, + ) + await proxy.start() + log(f"proxy on: {proxy.url}") + + try: + agent = OpenAIChatAdapter( + model='gpt-4o-mini', + api_key=API_KEY, base_url=BASE_URL, http_proxy=HTTP_PROXY, + system_prompt='Be brief.', + timeout=30, + proxy=proxy, + ) + + r1 = await agent("Remember: city=Paris. Just say OK.") + log(f"turn 1: {r1.content}") + r2 = await agent("What city? Just the name.") + log(f"turn 2: {r2.content}") + + state = agent.state_dict() + log(f"state keys: {list(state.keys())}") + + trace = state.get('llm_trace', []) + log(f"llm_trace: {len(trace)} LLM calls") + for i, rec in enumerate(trace): + msgs = rec['request'].get('messages', []) + resp = rec.get('response') or {} + usage = resp.get('usage', {}) + log(f" call {i+1}: {len(msgs)} msgs, " + f"tokens: {json.dumps(usage)[:100]}") + + # Messages should grow + if len(trace) >= 2: + m1 = len(trace[0]['request']['messages']) + m2 = len(trace[-1]['request']['messages']) + log(f"messages grew: {m1} → {m2}") + assert m2 > m1 + + # Training samples + samples = proxy.to_training_samples(agent.session_id) + log(f"training samples: {len(samples)}") + if samples: + s = samples[0] + log(f" sample: {len(s['messages'])} msgs, {s['meta']['num_calls']} calls") + log(f" all_usage: {s['all_usage']}") + + print(" ✅ PASSED") + finally: + await proxy.stop() + + +async def test_comparison_with_claude(): + print("\n" + "=" * 60) + print(" Comparison: OpenAI Chat vs Claude Code SDK state_dict") + print("=" * 60) + + from lagent.adapters.claude_code_sdk import ClaudeCodeSDKAdapter + + oai = OpenAIChatAdapter( + model='gpt-4o-mini', + api_key=API_KEY, base_url=BASE_URL, http_proxy=HTTP_PROXY, + system_prompt='Be brief.', + timeout=30, + ) + claude = ClaudeCodeSDKAdapter( + max_turns=3, timeout=60, + cwd=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + ) + + await oai("Say 'hello'.") + await claude("Say 'hello'.") + + oai_state = oai.state_dict() + claude_state = claude.state_dict() + + log(f"OpenAI state keys: {sorted(oai_state.keys())}") + log(f"Claude state keys: {sorted(claude_state.keys())}") + log(f"OpenAI memory: {len(oai_state['memory'])} entries") + log(f"Claude memory: {len(claude_state['memory'])} entries") + + # Both have 2 memory entries (input + output) + assert len(oai_state['memory']) == 2 + assert len(claude_state['memory']) == 2 + assert oai_state['memory'][0]['sender'] == 'user' + assert claude_state['memory'][0]['sender'] == 'user' + + # Show content + log(f"OpenAI output: {oai_state['memory'][1]['content']}") + log(f"Claude output: {claude_state['memory'][1]['content']}") + print(" ✅ PASSED") + + +async def main(): + await test_single_turn() + await test_multiturn() + await test_with_proxy() + await test_comparison_with_claude() + print(f"\n{'='*60}") + print(" All tests passed!") + print(f"{'='*60}") + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/tests/test_adapters/test_openclaw_integration.py b/tests/test_adapters/test_openclaw_integration.py new file mode 100644 index 00000000..8b0cd0ff --- /dev/null +++ b/tests/test_adapters/test_openclaw_integration.py @@ -0,0 +1,246 @@ +"""Integration tests for OpenClaw adapter (CLI + Proxy). + +Requires: + - Node 22+ via nvm + - openclaw installed globally (npm install -g openclaw) + - ~/.openclaw/openclaw.json configured with a working provider + - http proxy at http://100.100.72.89:8899 for outbound access + +Run: + python tests/test_adapters/test_openclaw_integration.py + python -m pytest tests/test_adapters/test_openclaw_integration.py -v -s +""" +import asyncio +import json +import os +import pathlib +import shutil +import sys + +import pytest + +from lagent.adapters.openclaw import OpenClawAdapter +from lagent.adapters.proxy import LLMProxyRecorder +from lagent.schema import AgentMessage + +NVM_DIR = '/mnt/shared-storage-user/liukuikun/.nvm' +HTTP_PROXY = 'http://100.100.72.89:8899' +API_KEY = ' ' +REAL_BASE_URL = 'http://35.220.164.252:3888/v1' +OPENCLAW_HOME = pathlib.Path.home() / '.openclaw' + +# Skip if openclaw not set up +pytestmark = pytest.mark.skipif( + not (OPENCLAW_HOME / 'openclaw.json').exists(), + reason="OpenClaw not configured (~/.openclaw/openclaw.json missing)", +) + + +def log(msg): + print(f" → {msg}") + + +def _setup_openclaw_config(proxy_url=None): + """Write openclaw.json and auth-profiles for testing.""" + base_url = proxy_url or REAL_BASE_URL + config = { + 'models': {'mode': 'merge', 'providers': {'custom-openai': { + 'baseUrl': base_url, + 'api': 'openai-completions', + 'models': [{'id': 'gpt-4o-mini', 'name': 'GPT-4o Mini', + 'reasoning': False, 'input': ['text'], + 'contextWindow': 128000, 'maxTokens': 16384}], + }}}, + 'agents': {'defaults': {'model': {'primary': 'custom-openai/gpt-4o-mini'}}}, + } + (OPENCLAW_HOME / 'openclaw.json').write_text(json.dumps(config, indent=2)) + + key = f'sk-proxy-test' if proxy_url else API_KEY + auth = {'version': 1, 'profiles': {'default': { + 'type': 'api_key', 'provider': 'custom-openai', 'key': key, + }}} + auth_path = OPENCLAW_HOME / 'agents/main/agent/auth-profiles.json' + auth_path.parent.mkdir(parents=True, exist_ok=True) + auth_path.write_text(json.dumps(auth)) + + # Clear sessions and state + sd = OPENCLAW_HOME / 'agents/main/sessions' + if sd.exists(): + shutil.rmtree(sd) + sd.mkdir(parents=True, exist_ok=True) + (OPENCLAW_HOME / 'agents/main/agent/auth-state.json').unlink(missing_ok=True) + + +def _make_agent(**kwargs): + """Create an OpenClawAdapter with standard test config.""" + defaults = dict( + thinking='off', + timeout=120, + json_output=False, + nvm_dir=NVM_DIR, + env_vars={'http_proxy': HTTP_PROXY, 'https_proxy': HTTP_PROXY}, + ) + defaults.update(kwargs) + return OpenClawAdapter(**defaults) + + +class TestOpenClawAdapter: + + @pytest.mark.asyncio + async def test_single_turn(self): + """Single turn returns correct result.""" + _setup_openclaw_config() + agent = _make_agent() + result = await agent("What is 9+3? Just the number.") + log(f"result: {result.content}") + assert isinstance(result, AgentMessage) + assert '12' in result.content + + @pytest.mark.asyncio + async def test_multiturn(self): + """Real multi-turn via session-id.""" + _setup_openclaw_config() + agent = _make_agent() + + r1 = await agent("Remember: fruit=mango. Just say OK.") + log(f"turn 1: {r1.content}") + + r2 = await agent("What fruit? Just the name.") + log(f"turn 2: {r2.content}") + assert 'mango' in r2.content.lower() + + @pytest.mark.asyncio + async def test_memory_accumulates(self): + """Memory stores input+output for each turn.""" + _setup_openclaw_config() + agent = _make_agent() + await agent("Say 'alpha'.") + await agent("Say 'beta'.") + + memory = agent.memory.get_memory() + log(f"memory: {len(memory)} entries") + assert len(memory) == 4 + assert memory[0].sender == 'user' + assert memory[1].sender == 'openclaw' + + @pytest.mark.asyncio + async def test_with_proxy(self): + """Proxy captures LLM calls from OpenClaw.""" + proxy = LLMProxyRecorder( + real_api_key=API_KEY, + real_base_url=REAL_BASE_URL, + http_proxy=HTTP_PROXY, + ) + await proxy.start() + try: + _setup_openclaw_config(proxy_url=f'{proxy.url}/v1') + + # No http_proxy in env — OpenClaw talks to localhost proxy directly + agent = _make_agent(env_vars={}) + r1 = await agent("What is 5*6? Just the number.") + log(f"result: {r1.content}") + assert '30' in r1.content + + all_records = [] + for recs in proxy._records.values(): + all_records.extend(recs) + log(f"LLM calls: {len(all_records)}") + assert len(all_records) >= 1 + + # Verify response was parsed + proxy._records['test'] = all_records + norm = proxy.get_normalized_records('test') + for n in norm: + resp_text = n['response'].get('content', '') + log(f" response: '{resp_text[:50]}'") + assert any(n['response'].get('content', '') for n in norm) + finally: + await proxy.stop() + + @pytest.mark.asyncio + async def test_proxy_multiturn_training_sample(self): + """Multi-turn + Proxy produces correct training sample.""" + proxy = LLMProxyRecorder( + real_api_key=API_KEY, + real_base_url=REAL_BASE_URL, + http_proxy=HTTP_PROXY, + ) + await proxy.start() + try: + _setup_openclaw_config(proxy_url=f'{proxy.url}/v1') + + # No http_proxy — OpenClaw talks to localhost proxy directly + agent = _make_agent(env_vars={}) + r1 = await agent("Remember: planet=Mars. Just say OK.") + log(f"turn 1: {r1.content}") + r2 = await agent("What planet? Just the name.") + log(f"turn 2: {r2.content}") + assert 'mars' in r2.content.lower() + + all_records = [] + for recs in proxy._records.values(): + all_records.extend(recs) + log(f"LLM calls: {len(all_records)}") + + proxy._records['test'] = all_records + chains = proxy.rebuild_chains('test') + log(f"chains: {len(chains)}") + + samples = proxy.to_training_samples('test') + log(f"training samples: {len(samples)}") + assert len(samples) >= 1 + + s = samples[0] + log(f"sample: {len(s['messages'])} msgs, {s['meta']['num_calls']} calls") + assert s['meta']['num_calls'] >= 1 + assert len(s['messages']) >= 2 + + # Messages should have system + user + assistant pattern + roles = [m['role'] for m in s['messages']] + log(f"roles: {roles}") + assert 'assistant' in roles + assert 'user' in roles + finally: + await proxy.stop() + + @pytest.mark.asyncio + async def test_state_dict(self): + """state_dict has memory.""" + _setup_openclaw_config() + agent = _make_agent() + await agent("Say hello.") + + state = agent.state_dict() + log(f"state keys: {list(state.keys())}") + assert 'memory' in state + assert len(state['memory']) == 2 + + +# ── F5 Runner ──────────────────────────────────────────────────── + +async def _run_test(cls, name): + obj = cls() + print(f"\n{'='*60}") + print(f" {cls.__name__}.{name}") + print(f"{'='*60}") + try: + await getattr(obj, name)() + print(f" ✅ PASSED") + except Exception as e: + print(f" ❌ FAILED: {e}") + import traceback + traceback.print_exc() + + +async def run_all(): + for name in sorted(dir(TestOpenClawAdapter)): + if name.startswith('test_'): + await _run_test(TestOpenClawAdapter, name) + print(f"\n{'='*60}\n Done!\n{'='*60}") + + +if __name__ == '__main__': + if len(sys.argv) > 1: + asyncio.run(_run_test(TestOpenClawAdapter, sys.argv[1])) + else: + asyncio.run(run_all()) diff --git a/tests/test_agent_project.py b/tests/test_agent_project.py new file mode 100644 index 00000000..4fc49b45 --- /dev/null +++ b/tests/test_agent_project.py @@ -0,0 +1,409 @@ +"""Tests for agent project discovery (config.py / AGENT.md). + +Two discovery modes: + 1. config.py → "pyconfig" (agent_config dict) + 2. AGENT.md → "markdown" (YAML frontmatter) + +Run: + python tests/test_agent_project.py +""" + +import asyncio +import shutil +import sys +import tempfile +from pathlib import Path + +# ── Ensure lagent is importable ────────────────────────────────────── +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from lagent.services.agent_loader import ( + AgentLoader, + AgentSpec, + _detect_kind, +) +from lagent.services.agent import ( + AgentService, +) + +# ── Helpers ────────────────────────────────────────────────────────── + +PASS = 0 +FAIL = 0 + + +def check(name: str, condition: bool, detail: str = ""): + global PASS, FAIL + if condition: + PASS += 1 + print(f" ✅ {name}") + else: + FAIL += 1 + print(f" ❌ {name}: {detail}") + + +# ── Fixtures: create temp agent projects ───────────────────────────── + +def create_test_workspace(tmp: Path) -> Path: + """Create a workspace with agent projects.""" + agents = tmp / "agents" + + # ── 1. PyConfig agent (config.py) ───────────────────────────── + reviewer = agents / "code-reviewer" + reviewer.mkdir(parents=True) + + (reviewer / "config.py").write_text( + """\ +# Agent project config — discovered by AgentLoader +from lagent.agents.agent import Agent + +agent_config = dict( + type=Agent, + name="code-reviewer", +) + +name = "code-reviewer" +description = "Reviews code for quality and best practices" +system_prompt = "You are a code reviewer." +max_turns = 50 +background = False +extra = {"tools": ["Read", "Grep"], "model": "gpt-4"} +""", + encoding="utf-8", + ) + + # ── 2. PyConfig agent with build() ──────────────────────────── + translator = agents / "translator" + translator.mkdir(parents=True) + + (translator / "config.py").write_text( + """\ +from lagent.agents.agent import Agent + +agent_config = dict( + type=Agent, + name="translator", +) + +name = "translator" +description = "Translates text between languages" +max_turns = 10 + +def build(config): + \"\"\"Custom build function.\"\"\" + from lagent.utils import create_object + agent = create_object(config) + return agent +""", + encoding="utf-8", + ) + + # ── 3. Markdown agent (AGENT.md) ────────────────────────────── + summarizer = agents / "summarizer" + summarizer.mkdir(parents=True) + + (summarizer / "AGENT.md").write_text( + """\ +--- +name: summarizer +description: Summarizes documents concisely +tools: + - Read + - Glob +model: gpt-4o-mini +max_turns: 30 +background: false +--- + +You are a document summarizer. When given a document, provide a concise +summary highlighting the key points. +""", + encoding="utf-8", + ) + + # ── 4. Empty dir (should be skipped) ────────────────────────── + empty = agents / "empty-dir" + empty.mkdir(parents=True) + + return tmp + + +# ── Test: _detect_kind ─────────────────────────────────────────────── + +def test_detect_kind(workspace: Path): + print("\n── Test: _detect_kind ──") + agents = workspace / "agents" + + check( + "config.py → pyconfig", + _detect_kind(agents / "code-reviewer") == "pyconfig", + ) + check( + "AGENT.md → markdown", + _detect_kind(agents / "summarizer") == "markdown", + ) + check( + "empty dir → None", + _detect_kind(agents / "empty-dir") is None, + ) + + +# ── Test: list ─────────────────────────────────────────────────────── + +async def test_list(workspace: Path): + print("\n── Test: AgentLoader.list() ──") + loader = AgentLoader(workspace) + entries = await loader.list() + + names = {e["name"] for e in entries} + check("discovers 3 agents", len(entries) == 3, f"got {len(entries)}: {names}") + check("code-reviewer found", "code-reviewer" in names) + check("translator found", "translator" in names) + check("summarizer found", "summarizer" in names) + check("empty-dir skipped", "empty-dir" not in names) + + # Check kinds + kinds = {e["name"]: e["kind"] for e in entries} + check("code-reviewer is pyconfig", kinds.get("code-reviewer") == "pyconfig") + check("translator is pyconfig", kinds.get("translator") == "pyconfig") + check("summarizer is markdown", kinds.get("summarizer") == "markdown") + + +# ── Test: load (pyconfig) ──────────────────────────────────────────── + +async def test_load_pyconfig(workspace: Path): + print("\n── Test: load (pyconfig) ──") + loader = AgentLoader(workspace) + spec = await loader.load("code-reviewer") + + check("spec loaded", spec is not None) + if spec is None: + return + + check("name", spec.name == "code-reviewer", spec.name) + check("description", "Reviews code" in spec.description, spec.description) + check("system_prompt", "code reviewer" in spec.system_prompt, spec.system_prompt) + check("max_turns", spec.max_turns == 50, str(spec.max_turns)) + check("agent_config is dict", isinstance(spec.agent_config, dict)) + check("agent_config has type", "type" in (spec.agent_config or {})) + check("extra has tools", spec.extra.get("tools") == ["Read", "Grep"]) + check("extra has model", spec.extra.get("model") == "gpt-4") + check("build is None", spec.build is None) + check("project_dir set", spec.project_dir is not None) + + +# ── Test: load (pyconfig with build) ───────────────────────────────── + +async def test_load_pyconfig_build(workspace: Path): + print("\n── Test: load (pyconfig with build) ──") + loader = AgentLoader(workspace) + spec = await loader.load("translator") + + check("spec loaded", spec is not None) + if spec is None: + return + + check("name", spec.name == "translator", spec.name) + check("description", "Translates" in spec.description, spec.description) + check("max_turns", spec.max_turns == 10, str(spec.max_turns)) + check("build is callable", callable(spec.build)) + check("agent_config is dict", isinstance(spec.agent_config, dict)) + + +# ── Test: load (markdown) ──────────────────────────────────────────── + +async def test_load_markdown(workspace: Path): + print("\n── Test: load (markdown) ──") + loader = AgentLoader(workspace) + spec = await loader.load("summarizer") + + check("spec loaded", spec is not None) + if spec is None: + return + + check("name", spec.name == "summarizer", spec.name) + check("description", "Summarizes" in spec.description, spec.description) + check("system_prompt", "document summarizer" in spec.system_prompt, spec.system_prompt) + check("max_turns", spec.max_turns == 30, str(spec.max_turns)) + check("tools in extra", spec.extra.get("tools") == ["Read", "Glob"]) + check("model in extra", spec.extra.get("model") == "gpt-4o-mini") + check("no agent_config", spec.agent_config is None) + + +# ── Test: load_all ─────────────────────────────────────────────────── + +async def test_load_all(workspace: Path): + print("\n── Test: load_all ──") + loader = AgentLoader(workspace) + specs = await loader.load_all() + + check("3 specs loaded", len(specs) == 3, f"got {len(specs)}") + check("code-reviewer in specs", "code-reviewer" in specs) + check("translator in specs", "translator" in specs) + check("summarizer in specs", "summarizer" in specs) + + +# ── Test: build_agents_summary ─────────────────────────────────────── + +async def test_summary(workspace: Path): + print("\n── Test: build_agents_summary ──") + loader = AgentLoader(workspace) + summary = await loader.build_agents_summary() + + check("not empty", len(summary) > 0) + check("contains ", "" in summary) + check("contains code-reviewer", "code-reviewer" in summary) + check("contains translator", "translator" in summary) + check("contains summarizer", "summarizer" in summary) + + +# ── Test: AgentSpec.create() (pyconfig) ────────────────────────────── + +async def test_spec_create_pyconfig(workspace: Path): + print("\n── Test: AgentSpec.create() (pyconfig) ──") + loader = AgentLoader(workspace) + spec = await loader.load("code-reviewer") + assert spec is not None + + agent = spec.create() + check("agent created", agent is not None) + check("agent has memory", hasattr(agent, "memory")) + check("agent name", agent.name == "code-reviewer", agent.name) + + +# ── Test: AgentSpec.create() (pyconfig with build) ─────────────────── + +async def test_spec_create_build(workspace: Path): + print("\n── Test: AgentSpec.create() (pyconfig with build) ──") + loader = AgentLoader(workspace) + spec = await loader.load("translator") + assert spec is not None + + agent = spec.create() + check("agent created", agent is not None) + check("agent has memory", hasattr(agent, "memory")) + check("agent name", agent.name == "translator", agent.name) + + +# ── Test: AgentService with mixed project types ───────────────────── + +async def test_agent_service_mixed(workspace: Path): + print("\n── Test: AgentService with mixed project types ──") + loader = AgentLoader(workspace) + service = AgentService(agent_loader=loader) + await service.load_specs() + + check("3 types available", len(service.available_types) == 3, + str(service.available_types)) + check("code-reviewer registered", "code-reviewer" in service.available_types) + check("translator registered", "translator" in service.available_types) + check("summarizer registered", "summarizer" in service.available_types) + + # Check that pyconfig spec has agent_config + cr_spec = service.get_spec("code-reviewer") + check("pyconfig spec has agent_config", + cr_spec is not None and cr_spec.agent_config is not None) + + # Check that translator spec has build + tr_spec = service.get_spec("translator") + check("translator spec has build", + tr_spec is not None and callable(tr_spec.build)) + + +# ── Test: AgentSpec serialization roundtrip ────────────────────────── + +def test_spec_serialization(): + print("\n── Test: AgentSpec serialization roundtrip ──") + spec = AgentSpec( + name="test", + description="A test agent", + project_dir="/tmp/agents/test", + agent_config=dict(type="lagent.agents.Agent"), + extra={"tools": ["Read"]}, + ) + d = spec.to_dict() + check("to_dict has project_dir", d["project_dir"] == "/tmp/agents/test") + + restored = AgentSpec.from_dict(d) + check("roundtrip name", restored.name == spec.name) + check("roundtrip project_dir", restored.project_dir == spec.project_dir) + check("roundtrip extra", restored.extra == spec.extra) + check("build is None after roundtrip", restored.build is None) + + +# ── Test: priority (config.py wins over AGENT.md) ─────────────────── + +async def test_priority(workspace: Path): + """If a project has both config.py AND AGENT.md, config.py wins.""" + print("\n── Test: config.py priority over AGENT.md ──") + agents = workspace / "agents" + hybrid = agents / "hybrid" + hybrid.mkdir(parents=True, exist_ok=True) + + # Create both config.py and AGENT.md + (hybrid / "config.py").write_text( + """\ +from lagent.agents.agent import Agent +agent_config = dict(type=Agent, name="hybrid-from-config") +name = "hybrid" +description = "From config.py" +""", + encoding="utf-8", + ) + (hybrid / "AGENT.md").write_text( + """\ +--- +name: hybrid +description: From AGENT.md +--- +System prompt from markdown. +""", + encoding="utf-8", + ) + + kind = _detect_kind(hybrid) + check("config.py wins", kind == "pyconfig") + + loader = AgentLoader(workspace) + spec = await loader.load("hybrid") + check("spec loaded", spec is not None) + if spec: + check("description from config.py", spec.description == "From config.py", + spec.description) + + +# ── Main ───────────────────────────────────────────────────────────── + +async def main(): + tmp = Path(tempfile.mkdtemp(prefix="test_agent_project_")) + try: + workspace = create_test_workspace(tmp) + print(f"Workspace: {workspace}") + + test_detect_kind(workspace) + await test_list(workspace) + await test_load_pyconfig(workspace) + await test_load_pyconfig_build(workspace) + await test_load_markdown(workspace) + await test_load_all(workspace) + await test_summary(workspace) + await test_spec_create_pyconfig(workspace) + await test_spec_create_build(workspace) + await test_agent_service_mixed(workspace) + test_spec_serialization() + await test_priority(workspace) + + finally: + shutil.rmtree(tmp, ignore_errors=True) + + print(f"\n{'='*60}") + print(f"Results: {PASS} passed, {FAIL} failed") + print(f"{'='*60}") + return FAIL == 0 + + +if __name__ == "__main__": + ok = asyncio.run(main()) + sys.exit(0 if ok else 1) diff --git a/tests/test_agent_service_e2e.py b/tests/test_agent_service_e2e.py new file mode 100644 index 00000000..76c16163 --- /dev/null +++ b/tests/test_agent_service_e2e.py @@ -0,0 +1,902 @@ +"""End-to-end integration test for AgentService + AgentAction + CompactAction. + +Three test levels (run selectively via command-line flags): + + 1. **Unit** (no LLM, no network): + Mock everything, verify wiring and data flow. + + 2. **Integration** (real LLM, no sandbox): + Use a live OpenAI-compatible endpoint to test the full + Policy → fork → compact → summary pipeline. + + 3. **Full E2E** (real LLM + sandbox): + Same as the __main__ block in internclaw_agent.py — tests the + complete InternClawAgent with AgentService, CompactAction, etc. + +Usage:: + + # Unit tests only (fast, no network) + python tests/test_agent_service_e2e.py --unit + + # Integration tests (needs LLM endpoint) + python tests/test_agent_service_e2e.py --integration + + # Full E2E (needs LLM + sandbox) + python tests/test_agent_service_e2e.py --e2e + + # All tests + python tests/test_agent_service_e2e.py --all +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys +import os +from pathlib import Path +from dataclasses import dataclass +from typing import Any + +# Ensure lagent is importable +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from lagent.agents.agent import AsyncAgent +from lagent.services.agent_loader import AgentSpec +from lagent.services.agent import ( + AgentEntry, + AgentService, + AgentStatus, +) +from lagent.actions.subagent import AgentAction, AsyncAgentAction +from lagent.actions.compact import ( + AsyncCompactAction, + CompactAction, + estimate_token_count, + COMPACT_PROMPT, +) +from lagent.memory.memory import BaseMemoryStore, ClaudeCodeMemory +from lagent.schema import ActionReturn, ActionStatusCode, AgentMessage + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s: %(message)s", +) +logger = logging.getLogger("test_e2e") + + +# ===================================================================== +# Helpers / Mocks +# ===================================================================== + +class MockLLM: + """A mock LLM that returns canned responses.""" + + def __init__(self, responses: list[dict] | None = None): + self._responses = responses or [] + self._call_count = 0 + + async def chat(self, messages: list[dict], tools=None, **kwargs) -> dict: + self._call_count += 1 + if self._responses: + idx = min(self._call_count - 1, len(self._responses) - 1) + return self._responses[idx] + # Default: return a simple text response + return { + "role": "assistant", + "content": f"Mock response #{self._call_count}", + "tool_calls": [], + } + + @property + def call_count(self): + return self._call_count + + +class MockAggregator: + """Minimal aggregator that just passes messages through.""" + + def aggregate(self, memory, name, output_format, template): + msgs = [] + for m in memory.get_memory(): + if isinstance(m, AgentMessage): + msgs.append({"role": "user", "content": m.content or ""}) + if not msgs: + msgs = [{"role": "user", "content": "hello"}] + return msgs, None + + +class SimpleTestAgent(AsyncAgent): + """A minimal async agent for testing purposes. + + Uses a MockLLM and MockAggregator. Returns the LLM response + directly as an AgentMessage. + """ + + def __init__(self, llm=None, name="test-agent", **kwargs): + super().__init__( + llm=llm or MockLLM(), + aggregator=MockAggregator(), + name=name, + **kwargs, + ) + + async def forward(self, *messages, **kwargs): + formatted, tools = self.aggregator.aggregate( + self.memory, self.name, self.output_format, self.template + ) + resp = await self.llm.chat(formatted, tools=tools, **kwargs) + return AgentMessage( + sender=self.name, + content=resp.get("content", ""), + tool_calls=resp.get("tool_calls") or [], + ) + + +# ===================================================================== +# Level 1: Unit Tests (no network) +# ===================================================================== + +async def test_agent_service_basic(): + """AgentService: register, spawn (with mock agent), list, query.""" + print("\n" + "=" * 60) + print("TEST: AgentService basic lifecycle (unit)") + print("=" * 60) + + # 1. Create service without loader + svc = AgentService() + assert svc.available_types == [], "Should start empty" + print(" ✅ AgentService() created without loader") + + # 2. Register spec with agent_config + spec = AgentSpec( + name="echo-agent", + description="Echoes back the task", + agent_config=dict( + type=f"{SimpleTestAgent.__module__}.{SimpleTestAgent.__qualname__}", + llm=dict( + type=f"{MockLLM.__module__}.{MockLLM.__qualname__}", + responses=[{ + "role": "assistant", + "content": "Echo: placeholder", + "tool_calls": [], + }], + ), + name="echo-agent", + ), + ) + svc.register_spec(spec) + assert "echo-agent" in svc.available_types + print(f" ✅ register_spec() → available_types={svc.available_types}") + + # 3. Spawn sync + entry = await svc.spawn("echo-agent", "Hello world!", mode="sync") + assert entry.status == AgentStatus.STOPPED, f"Expected STOPPED, got {entry.status}" + assert entry.result is not None + print(f" ✅ spawn(sync) → status={entry.status}, result={entry.result[:50]}") + + # 4. List + entries = svc.list() + assert len(entries) == 1 + print(f" ✅ list() → {len(entries)} entries") + + # 5. Query + queried = svc.get(entry.id) + assert queried is not None and queried.id == entry.id + print(f" ✅ get({entry.id}) → found") + + # 6. Spawn async + entry2 = await svc.spawn("echo-agent", "Background task", mode="async") + assert entry2.status in (AgentStatus.PENDING, AgentStatus.RUNNING, AgentStatus.STOPPED) + await asyncio.sleep(0.2) # let async task complete + updated = svc.get(entry2.id) + print(f" ✅ spawn(async) → status={updated.status}") + + print(" 🎉 AgentService basic lifecycle: ALL PASSED\n") + + +async def test_agent_action_wiring(): + """AgentAction: tool_api methods route to AgentService correctly.""" + print("\n" + "=" * 60) + print("TEST: AgentAction wiring (unit)") + print("=" * 60) + + svc = AgentService() + spec = AgentSpec( + name="helper", + description="A helper agent", + agent_config=dict( + type=f"{SimpleTestAgent.__module__}.{SimpleTestAgent.__qualname__}", + llm=dict( + type=f"{MockLLM.__module__}.{MockLLM.__qualname__}", + responses=[{ + "role": "assistant", + "content": "Done: something", + "tool_calls": [], + }], + ), + name="helper", + ), + ) + svc.register_spec(spec) + + action = AgentAction(agent_service=svc) + + # spawn + result = await action.spawn(agent_type="helper", task="Do something") + assert result.state == ActionStatusCode.SUCCESS, f"spawn failed: {result.errmsg}" + content = result.result[0]["content"] + assert "Done" in content + print(f" ✅ spawn() → {content[:60]}...") + + # list_agents + result = await action.list_agents() + assert result.state == ActionStatusCode.SUCCESS + content = result.result[0]["content"] + assert "helper" in content or "Do something" in content + print(f" ✅ list_agents() → {content[:60]}...") + + # spawn unknown type + result = await action.spawn(agent_type="nonexistent", task="fail") + assert result.state == ActionStatusCode.API_ERROR + print(f" ✅ spawn(unknown) → error: {result.errmsg[:60]}...") + + print(" 🎉 AgentAction wiring: ALL PASSED\n") + + +async def test_compact_action_unit(): + """CompactAction: should_compact + compact with mock service.""" + print("\n" + "=" * 60) + print("TEST: CompactAction unit (mock LLM)") + print("=" * 60) + + svc = AgentService() + + # Create a policy agent with mock LLM that returns a summary + policy = SimpleTestAgent( + llm=MockLLM([{ + "role": "assistant", + "content": "## Summary\nUser asked to test compact.\n## Pending\nNothing.", + "tool_calls": [], + }]), + name="policy", + ) + # Feed some history into policy's memory + policy.memory.add(AgentMessage(sender="user", content="Hello, please help me test")) + policy.memory.add(AgentMessage(sender="policy", content="Sure, I'll help")) + + compact = CompactAction( + agent_service=svc, + policy_agent=policy, + max_context_tokens=1000, + threshold_ratio=0.5, + ) + + # should_compact + assert not compact.should_compact(400), "400 < 500 threshold" + assert compact.should_compact(600), "600 > 500 threshold" + print(" ✅ should_compact() threshold logic works") + + # compact (fork-based) + result = await compact.compact() + assert result.state == ActionStatusCode.SUCCESS, f"compact failed: {result.errmsg}" + summary = result.result[0]["content"] + assert "Summary" in summary + print(f" ✅ compact() → summary: {summary[:60]}...") + + # circuit breaker + compact._consecutive_failures = 3 + assert not compact.should_compact(9999), "Circuit breaker should block" + compact._consecutive_failures = 0 + print(" ✅ Circuit breaker works") + + print(" 🎉 CompactAction unit: ALL PASSED\n") + + +async def test_claude_code_memory_unit(): + """ClaudeCodeMemory: unified interface with CompactAction.""" + print("\n" + "=" * 60) + print("TEST: ClaudeCodeMemory unit") + print("=" * 60) + + svc = AgentService() + policy = SimpleTestAgent( + llm=MockLLM([{ + "role": "assistant", + "content": "Compacted summary here.", + "tool_calls": [], + }]), + name="policy", + ) + policy.memory.add(AgentMessage(sender="user", content="test")) + + compact = CompactAction( + agent_service=svc, + policy_agent=policy, + max_context_tokens=1000, + threshold_ratio=0.5, + ) + + mem = ClaudeCodeMemory(compact_action=compact) + + # get_info returns empty (no persistent storage) + info = await mem.get_info() + assert info == {} + print(" ✅ get_info() → {}") + + # should_compact delegates to CompactAction + assert not mem.should_compact(400) + assert mem.should_compact(600) + print(" ✅ should_compact() delegates correctly") + + # compact returns summary string + summary = await mem.compact() + assert summary is not None and "Compacted summary" in summary + print(f" ✅ compact() → {summary[:50]}...") + + # actions is empty + assert mem.actions == [] + print(" ✅ actions → []") + + print(" 🎉 ClaudeCodeMemory unit: ALL PASSED\n") + + +async def test_spawn_with_state(): + """Test spawn with state transfer (replaces fork).""" + print("\n" + "=" * 60) + print("TEST: Spawn with state transfer (unit)") + print("=" * 60) + + svc = AgentService() + + spec = AgentSpec( + name="worker", + description="A worker agent", + agent_config=dict( + type=f"{SimpleTestAgent.__module__}.{SimpleTestAgent.__qualname__}", + llm=dict( + type=f"{MockLLM.__module__}.{MockLLM.__qualname__}", + responses=[{ + "role": "assistant", + "content": "Worker done", + "tool_calls": [], + }], + ), + name="worker", + ), + ) + svc.register_spec(spec) + + # Spawn mode (normal) + entry1 = await svc.spawn("worker", "Task A", mode="sync") + assert entry1.agent_type == "worker" + assert entry1.result is not None + print(f" ✅ Spawn mode: type={entry1.agent_type}, result={entry1.result[:40]}") + + # Spawn with state (replaces fork) + fake_state = {"memory": [{"sender": "user", "content": "Context from prev session"}]} + entry2 = await svc.spawn("worker", "Continue work", mode="sync", state=fake_state) + assert entry2.agent_type == "worker" + assert entry2.result is not None + print(f" ✅ Spawn with state: type={entry2.agent_type}, result={entry2.result[:40]}") + + print(" 🎉 Spawn with state transfer: ALL PASSED\n") + + +async def test_agent_spec_create(): + """AgentSpec.create() and acreate() with agent_config (PyConfig).""" + print("\n" + "=" * 60) + print("TEST: AgentSpec.create() + acreate() (unit)") + print("=" * 60) + + # agent_config pointing to SimpleTestAgent + spec = AgentSpec( + name="pyconfig-agent", + agent_config=dict( + type=f"{SimpleTestAgent.__module__}.{SimpleTestAgent.__qualname__}", + llm=dict( + type=f"{MockLLM.__module__}.{MockLLM.__qualname__}", + responses=[{ + "role": "assistant", + "content": "PyConfig agent works!", + "tool_calls": [], + }], + ), + name="pyconfig-test", + ), + ) + + try: + agent = spec.create() + assert isinstance(agent, AsyncAgent), f"Got {type(agent)}" + print(f" ✅ spec.create() created {type(agent).__name__}") + + # Run the agent + response = await agent("test task") + print(f" ✅ Agent response: {response.content[:50]}") + except Exception as exc: + # create_object may not resolve the mock classes; that's OK for + # a unit test — the important thing is the error is clear + print(f" ⚠️ spec.create() raised (expected in unit test): {exc}") + + # Test async create + try: + agent2 = await spec.acreate() + assert isinstance(agent2, AsyncAgent), f"Got {type(agent2)}" + print(f" ✅ spec.acreate() created {type(agent2).__name__}") + except Exception as exc: + print(f" ⚠️ spec.acreate() raised (expected in unit test): {exc}") + + # Test error case: no agent_config + spec_empty = AgentSpec(name="empty") + try: + spec_empty.create() + assert False, "Should have raised ValueError" + except ValueError as exc: + assert "no agent_config" in str(exc) + print(f" ✅ No agent_config → ValueError: {str(exc)[:60]}...") + + print(" 🎉 AgentSpec.create(): ALL PASSED\n") + + +async def test_agent_service_persistence(): + """Save and load agent entries.""" + print("\n" + "=" * 60) + print("TEST: AgentService persistence (unit)") + print("=" * 60) + + import tempfile + + svc = AgentService() + spec = AgentSpec( + name="test-type", + description="For persistence test", + agent_config=dict( + type=f"{SimpleTestAgent.__module__}.{SimpleTestAgent.__qualname__}", + name="test-type", + ), + ) + svc.register_spec(spec) + + entry = await svc.spawn("test-type", "Persist me", mode="sync") + + with tempfile.TemporaryDirectory() as tmpdir: + save_path = Path(tmpdir) / "agents" + await svc.save_all(save_path) + assert (save_path / "agents.json").exists() + print(f" ✅ save_all() → {save_path / 'agents.json'}") + + # Load into a fresh service + svc2 = AgentService() + await svc2.load_all(save_path) + entries = svc2.list() + assert len(entries) == 1 + assert entries[0].id == entry.id + assert entries[0].agent_type == "test-type" + print(f" ✅ load_all() → restored {len(entries)} entries") + + print(" 🎉 AgentService persistence: ALL PASSED\n") + + +# ===================================================================== +# Level 2: Integration Tests (real LLM) +# ===================================================================== + +async def test_real_llm_compact(): + """Integration: CompactAction with a real LLM endpoint. + + Requires environment variables: + - LLM_BASE_URL (e.g. http://35.220.164.252:3888/v1) + - LLM_API_KEY + - LLM_MODEL (e.g. gpt-5.4) + Or falls back to the hardcoded defaults from internclaw_agent.py. + """ + print("\n" + "=" * 60) + print("TEST: Real LLM compact (integration)") + print("=" * 60) + + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + from lagent.agents.aggregator.context import InternClawContextBuilder + from lagent.agents.internclaw_agent import AsyncPolicyAgent + + base_url = os.environ.get("LLM_BASE_URL", "http://35.220.164.252:3888/v1") + api_key = os.environ.get("LLM_API_KEY", " ") + model_name = os.environ.get("LLM_MODEL", "gpt-5.4") + proxy = os.environ.get("LLM_PROXY", "http://100.100.72.89:8899") + + print(f" Using model={model_name}, base_url={base_url}") + + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=base_url, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=120, + max_retry=3, + sleep_interval=2, + ) + + workspace = Path(__file__).parent.parent / "workspace" + aggregator = InternClawContextBuilder(workspace, tools=None) + + # Build policy agent + policy = AsyncPolicyAgent( + llm=model, + aggregator=aggregator, + name="policy", + ) + + # Simulate a conversation + policy.memory.add(AgentMessage( + sender="user", + content="Please help me write a Python function that calculates fibonacci numbers.", + )) + policy.memory.add(AgentMessage( + sender="policy", + content="Sure! Here's a recursive fibonacci function:\n\n```python\ndef fib(n):\n if n <= 1: return n\n return fib(n-1) + fib(n-2)\n```", + tool_calls=[], + )) + policy.memory.add(AgentMessage( + sender="user", + content="Can you make it iterative and add memoization?", + )) + + # Create AgentService + CompactAction + svc = AgentService() + + compact = CompactAction( + agent_service=svc, + policy_agent=policy, + max_context_tokens=1000, + threshold_ratio=0.1, # Low threshold to force trigger + ) + + # Check should_compact + token_est = estimate_token_count( + [{"content": "x" * 200}] # simulate some tokens + ) + print(f" Token estimate: {token_est}, threshold: {compact.threshold_tokens}") + assert compact.should_compact(token_est) or token_est < compact.threshold_tokens + + # Force compact + print(" Calling compact() with real LLM (this may take 10-30s)...") + result = await compact.compact() + print(f" compact() state: {result.state}") + if result.state == ActionStatusCode.SUCCESS: + summary = result.result[0]["content"] + print(f" ✅ Got summary ({len(summary)} chars):") + print(f" {summary[:200]}...") + assert len(summary) > 50, "Summary too short" + else: + print(f" ⚠️ compact() returned error: {result.errmsg}") + print(" (This may be expected if the LLM endpoint is down)") + + print(" 🎉 Real LLM compact: DONE\n") + + +async def test_real_llm_agent_service_spawn(): + """Integration: AgentService.spawn with real LLM using PyConfig.""" + print("\n" + "=" * 60) + print("TEST: Real LLM AgentService spawn (integration)") + print("=" * 60) + + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + from lagent.agents.aggregator.context import InternClawContextBuilder + from lagent.agents.internclaw_agent import AsyncPolicyAgent + + base_url = os.environ.get("LLM_BASE_URL", "http://35.220.164.252:3888/v1") + api_key = os.environ.get("LLM_API_KEY", " ") + model_name = os.environ.get("LLM_MODEL", "gpt-5.4") + proxy = os.environ.get("LLM_PROXY", "http://100.100.72.89:8899") + + workspace = Path(__file__).parent.parent / "workspace" + + # AgentService without loader + svc = AgentService() + + # Register an agent type with PyConfig + model_cfg = ModelConfig(model=model_name, base_url=base_url, api_key=api_key, proxy=proxy) + sample_params = SampleParameters(temperature=0.7, top_p=1.0, top_k=50) + + # We can't use full PyConfig here because AsyncPolicyAgent needs + # a ContextBuilder which needs a workspace path. So we use a + # custom factory instead. + async def real_llm_factory(spec: AgentSpec, task: str): + model = AsyncAPIClient( + model=model_cfg, + sample_params=sample_params, + timeout=120, + max_retry=3, + sleep_interval=2, + ) + aggregator = InternClawContextBuilder(workspace, tools=None) + agent = AsyncPolicyAgent( + llm=model, + aggregator=aggregator, + name=spec.name, + ) + return agent + + svc._factory = real_llm_factory + + svc.register("qa-agent", description="Answers questions") + + print(" Spawning qa-agent with real LLM (sync mode)...") + entry = await svc.spawn( + "qa-agent", + "What is 2+2? Answer in one word.", + mode="sync", + ) + print(f" spawn() → status={entry.status}") + if entry.status == AgentStatus.STOPPED: + print(f" ✅ Result: {entry.result[:100]}") + elif entry.status == AgentStatus.FAILED: + print(f" ⚠️ Failed: {entry.error}") + else: + print(f" ⚠️ Unexpected status: {entry.status}") + + # List + entries = svc.list() + print(f" ✅ list() → {len(entries)} entries") + + print(" 🎉 Real LLM AgentService spawn: DONE\n") + + +async def test_real_llm_full_pipeline(): + """Integration: Full pipeline — InternClawAgent with AgentService + ClaudeCodeMemory.""" + print("\n" + "=" * 60) + print("TEST: Full pipeline with real LLM (integration)") + print("=" * 60) + + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + from lagent.agents.aggregator.context import InternClawContextBuilder + from lagent.agents.internclaw_agent import ( + AsyncPolicyAgent, + AsyncEnvAgent, + InternClawAgent, + ) + from lagent.hooks.logger import MessageLogger + + base_url = os.environ.get("LLM_BASE_URL", "http://35.220.164.252:3888/v1") + api_key = os.environ.get("LLM_API_KEY", " ") + model_name = os.environ.get("LLM_MODEL", "gpt-5.4") + proxy = os.environ.get("LLM_PROXY", "http://100.100.72.89:8899") + + workspace = Path(__file__).parent.parent / "workspace" + + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=base_url, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=120, + max_retry=3, + sleep_interval=2, + ) + + aggregator = InternClawContextBuilder(workspace, tools=None) + + # Step 1: Create AgentService + svc = AgentService() + + # Step 2: Create PolicyAgent + policy = AsyncPolicyAgent( + llm=model, + aggregator=aggregator, + name="policy", + hooks=[MessageLogger()], + ) + + # Step 3: Create CompactAction → ClaudeCodeMemory + compact = CompactAction( + agent_service=svc, + policy_agent=policy, + max_context_tokens=128_000, + threshold_ratio=0.85, + ) + memory_store = ClaudeCodeMemory(compact_action=compact) + + # Step 4: Create AgentAction (for sub-agent spawning) + agent_action = AgentAction(agent_service=svc) + + # Step 5: Create EnvAgent with memory_store + # Note: For this test we don't use real shell actions — just + # the AgentAction so policy can spawn sub-agents + env = AsyncEnvAgent( + actions=[agent_action], + skills=None, + memory_store=memory_store, + ) + + # Step 6: Create InternClawAgent + # Use max_turn=3 to limit the test + agent = InternClawAgent( + policy_agent=policy, + env_agent=env, + max_turn=3, + ) + + print(" Running InternClawAgent with max_turn=3...") + print(" (Policy → Env → Policy → Env → ...)") + try: + response = await agent("What is the capital of France? Answer briefly.") + print(f" ✅ Agent finished: {response.content[:100]}") + except Exception as exc: + print(f" ⚠️ Agent failed: {exc}") + import traceback + traceback.print_exc() + + print(" 🎉 Full pipeline: DONE\n") + + +# ===================================================================== +# Level 3: Full E2E (real LLM + sandbox) +# ===================================================================== + +async def test_e2e_with_sandbox(): + """Full E2E: InternClawAgent with sandbox MCP actions. + + This is essentially the same as the __main__ block in + internclaw_agent.py, but structured as a proper test. + """ + print("\n" + "=" * 60) + print("TEST: Full E2E with sandbox (e2e)") + print("=" * 60) + + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + from lagent.agents.aggregator.context import InternClawContextBuilder + from lagent.agents.internclaw_agent import ( + AsyncPolicyAgent, + AsyncEnvAgent, + InternClawAgent, + ) + from lagent.skills.skills import SkillsLoader, SandboxSkillsBackend + from lagent.memory.memory import SandboxMemoryBackend + from lagent.actions.mcp_client import AsyncMCPClientSandbox + from lagent.hooks.logger import MessageLogger + + base_url = os.environ.get("LLM_BASE_URL", "http://35.220.164.252:3888/v1") + api_key = os.environ.get("LLM_API_KEY", " ") + model_name = os.environ.get("LLM_MODEL", "gpt-5.4") + proxy = os.environ.get("LLM_PROXY", "http://100.100.72.89:8899") + sandbox_url = os.environ.get("SANDBOX_URL", "http://simple-shell.ailab.ailab.ai/mcp") + init_dir = os.environ.get("INIT_DIR", "/mnt/shared-storage-user/llmit/user/liukuikun/workspace/lagent/workspace") + + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=base_url, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=10, + sleep_interval=5, + ) + + shell_action = AsyncMCPClientSandbox('http', url=sandbox_url, init_dir=init_dir) + + try: + # Discover workspace + home_path = await shell_action.run(command='pwd') + import json as _json + cwd = _json.loads(home_path.result[0]['content'])['cwd'] + workspace_path = os.path.join(cwd, 'workspace') + print(f" Workspace: {workspace_path}") + + actions = [shell_action] + aggregator = InternClawContextBuilder(Path(workspace_path), tools=None) + + # AgentService + CompactAction + svc = AgentService() + + policy = AsyncPolicyAgent( + llm=model, + aggregator=aggregator, + name="policy", + hooks=[MessageLogger()], + ) + + compact = CompactAction( + agent_service=svc, + policy_agent=policy, + max_context_tokens=128_000, + threshold_ratio=0.85, + ) + memory_store = ClaudeCodeMemory(compact_action=compact) + + # AgentAction for sub-agent management + agent_action = AgentAction(agent_service=svc) + + env = AsyncEnvAgent( + actions=actions + [agent_action], + skills=None, + memory_store=memory_store, + ) + + agent = InternClawAgent( + policy_agent=policy, + env_agent=env, + max_turn=5, + ) + + print(" Running InternClawAgent with sandbox (max_turn=5)...") + response = await agent("List the files in the current directory using ls -la") + print(f" ✅ Agent finished: {response.content[:200]}") + + except Exception as exc: + print(f" ⚠️ E2E test failed: {exc}") + import traceback + traceback.print_exc() + finally: + try: + await shell_action.close() + except Exception: + pass + + print(" 🎉 Full E2E with sandbox: DONE\n") + + +# ===================================================================== +# Runner +# ===================================================================== + +async def run_unit_tests(): + """Run all unit tests (no network required).""" + print("\n" + "#" * 60) + print("# UNIT TESTS") + print("#" * 60) + await test_agent_service_basic() + await test_agent_action_wiring() + await test_compact_action_unit() + await test_claude_code_memory_unit() + await test_fork_spawn_modes() + await test_default_agent_factory() + await test_agent_service_persistence() + print("=" * 60) + print("ALL UNIT TESTS PASSED ✅") + print("=" * 60) + + +async def run_integration_tests(): + """Run integration tests (requires LLM endpoint).""" + print("\n" + "#" * 60) + print("# INTEGRATION TESTS (requires LLM endpoint)") + print("#" * 60) + await test_real_llm_compact() + await test_real_llm_agent_service_spawn() + await test_real_llm_full_pipeline() + print("=" * 60) + print("ALL INTEGRATION TESTS DONE ✅") + print("=" * 60) + + +async def run_e2e_tests(): + """Run full E2E tests (requires LLM + sandbox).""" + print("\n" + "#" * 60) + print("# E2E TESTS (requires LLM + sandbox)") + print("#" * 60) + await test_e2e_with_sandbox() + print("=" * 60) + print("ALL E2E TESTS DONE ✅") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="AgentService E2E Tests") + parser.add_argument("--unit", action="store_true", help="Run unit tests (no network)") + parser.add_argument("--integration", action="store_true", help="Run integration tests (needs LLM)") + parser.add_argument("--e2e", action="store_true", help="Run full E2E tests (needs LLM + sandbox)") + parser.add_argument("--all", action="store_true", help="Run all tests") + args = parser.parse_args() + + # Default to unit tests if nothing specified + if not any([args.unit, args.integration, args.e2e, args.all]): + args.unit = True + async def run_all(): + if args.unit or args.all: + await run_unit_tests() + if args.integration or args.all: + await run_integration_tests() + if args.e2e or args.all: + await run_e2e_tests() + + asyncio.run(run_integration_tests()) + + +if __name__ == "__main__": + main() diff --git a/tests/test_agents/test_compact_agent.py b/tests/test_agents/test_compact_agent.py new file mode 100644 index 00000000..a45bfea2 --- /dev/null +++ b/tests/test_agents/test_compact_agent.py @@ -0,0 +1,312 @@ +"""Tests for lagent.agents.compact_agent + +Includes: + - Unit tests with mock LLM (always run) + - Integration test with real model (run with --real flag) + +Usage: + python tests/test_agents/test_compact_agent.py # mock only + python tests/test_agents/test_compact_agent.py --real # include real model +""" + +import asyncio +import sys + +from lagent.agents.compact_agent import AsyncCompactAgent, estimate_token_count, COMPACT_PROMPT +from lagent.schema import AgentMessage + + +# ── Mock tests ──────────────────────────────────────────────────── + +async def test_threshold_tokens(): + compact = AsyncCompactAgent(max_context_tokens=1000, threshold_ratio=0.8) + assert compact.threshold_tokens == 800 + + +async def test_below_threshold(): + compact = AsyncCompactAgent(max_context_tokens=1000, threshold_ratio=0.8) + msg = AgentMessage(sender='policy', content='test', extra_info={'context_tokens': 500}) + assert compact.should_compact(msg) is False + + +async def test_above_threshold(): + compact = AsyncCompactAgent(max_context_tokens=1000, threshold_ratio=0.8) + msg = AgentMessage(sender='policy', content='test', extra_info={'context_tokens': 900}) + assert compact.should_compact(msg) is True + + +async def test_no_extra_info(): + compact = AsyncCompactAgent(max_context_tokens=1000, threshold_ratio=0.8) + msg = AgentMessage(sender='policy', content='test') + assert compact.should_compact(msg) is False + + +async def test_circuit_breaker(): + compact = AsyncCompactAgent(max_context_tokens=1000, threshold_ratio=0.8) + compact._consecutive_failures = 3 + msg = AgentMessage(sender='policy', content='test', extra_info={'context_tokens': 900}) + assert compact.should_compact(msg) is False + + +async def test_circuit_breaker_reset(): + compact = AsyncCompactAgent(max_context_tokens=1000, threshold_ratio=0.8) + compact._consecutive_failures = 3 + msg = AgentMessage(sender='policy', content='test', extra_info={'context_tokens': 900}) + assert compact.should_compact(msg) is False + compact._consecutive_failures = 0 + assert compact.should_compact(msg) is True + + +async def test_default_template(): + compact = AsyncCompactAgent() + assert compact.template == COMPACT_PROMPT + + +async def test_custom_template(): + compact = AsyncCompactAgent(template="Custom prompt") + assert compact.template == "Custom prompt" + + +async def test_estimate_token_count_basic(): + messages = [ + {"role": "user", "content": "Hello world"}, + {"role": "assistant", "content": "Hi there!"}, + ] + assert estimate_token_count(messages) == 5 # 20 chars / 4 + + +async def test_estimate_token_count_with_tools(): + messages = [{"role": "user", "content": "Hi"}] + tools = [{"type": "function", "function": {"name": "test"}}] + assert estimate_token_count(messages, tools) > estimate_token_count(messages) + + +async def test_estimate_token_count_empty(): + assert estimate_token_count([]) == 0 + + +async def test_forward_with_mock_llm(): + """Test that CompactAgent forward works with string content.""" + + class MockLLM: + async def chat(self, messages, **kwargs): + assert len(messages) >= 1 + return {"content": "## Summary\nThis is a test summary."} + + compact = AsyncCompactAgent(llm=MockLLM()) + + input_msg = AgentMessage( + sender='orchestrator', + content='USER: Hello\nASSISTANT: Hi there\nUSER: Help me refactor', + ) + + result = await compact(input_msg) + content = result.content + if isinstance(content, dict): + content = content.get('content', '') + assert 'Summary' in content + assert result.sender == 'AsyncCompactAgent' + + +async def test_forward_with_list_dict_content(): + """Test CompactAgent with list[dict] input (formatted_messages from policy aggregator).""" + + received_messages = [] + + class MockLLM: + async def chat(self, messages, **kwargs): + received_messages.extend(messages) + return {"content": "## Summary\nCompact summary of conversation."} + + compact = AsyncCompactAgent(llm=MockLLM()) + + # This is what orchestrator actually passes: formatted_messages from policy's aggregator + formatted_messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': 'Hello, help me refactor memory'}, + {'role': 'assistant', 'content': 'Sure, let me analyze the code.'}, + {'role': 'user', 'content': 'Focus on base_memory.py'}, + {'role': 'assistant', 'content': 'I see the Memory class needs reset().', 'tool_calls': [ + {'function': {'name': 'read_file', 'arguments': {'path': 'base_memory.py'}}} + ]}, + ] + + input_msg = AgentMessage(sender='orchestrator', content=formatted_messages) + result = await compact(input_msg) + + # Verify CompactAggregator formatted correctly + assert len(received_messages) == 2 # system (COMPACT_PROMPT) + user (formatted text) + assert received_messages[0]['role'] == 'system' + assert 'CRITICAL' in received_messages[0]['content'] # COMPACT_PROMPT + + user_content = received_messages[1]['content'] + assert 'USER: Hello, help me refactor memory' in user_content + assert 'ASSISTANT: Sure, let me analyze the code.' in user_content + assert 'read_file' in user_content # tool_calls included + assert 'SYSTEM: You are a helpful assistant.' in user_content + + content = result.content + if isinstance(content, dict): + content = content.get('content', '') + assert 'Summary' in content + + +async def test_compact_aggregator_empty_content(): + """Test CompactAggregator handles empty/None content gracefully.""" + from lagent.agents.aggregator.compact_aggregator import CompactAggregator + + agg = CompactAggregator() + + from lagent.memory import Memory + mem = Memory() + mem.add(AgentMessage(sender='user', content=[ + {'role': 'user', 'content': 'test'}, + {'role': 'assistant', 'content': None, 'tool_calls': [ + {'function': {'name': 'shell', 'arguments': {}}} + ]}, + ])) + + messages, tools = agg.aggregate(mem, 'compact', system_instruction='Summarize') + assert messages[0]['role'] == 'system' + user_text = messages[1]['content'] + assert 'USER: test' in user_text + assert 'shell' in user_text # tool call captured even with None content + + +async def test_forward_error_handling(): + """Test circuit breaker increments on failure.""" + + class FailingLLM: + async def chat(self, messages, **kwargs): + raise RuntimeError("LLM unavailable") + + compact = AsyncCompactAgent(llm=FailingLLM()) + assert compact._consecutive_failures == 0 + + input_msg = AgentMessage(sender='orchestrator', content='test history') + result = await compact(input_msg) + + # forward raises inside Agent.__call__, but the agent should handle it + # Check that the compact agent's state reflects the attempt + # Note: error handling depends on whether forward() catches or propagates + + +# ── Real model test ─────────────────────────────────────────────── + +async def test_real_model_compact(): + """Integration test: CompactAgent with a real LLM. + + Run with: python tests/test_agents/test_compact_agent.py --real + """ + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + + model_name = "gpt-5.4-mini" + api_base = "http://35.220.164.252:3888/v1" + api_key = "" + proxy = "http://100.100.72.89:8899" + extra_body = {} + # model_name = "/mnt/shared-storage-user/llmit1/user/liujiangning/exp/s2_preview/agent_rl/s2-preview-thinker_sft_0228b_rl0312rc1/20260316082019/hf-15" + # api_base = "http://10.102.245.34:23333/v1" + + # api_key='YOUR KEY' + # extra_body = {'enable_thinking': True, 'spaces_between_special_tokens': False} + # proxy = None + + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=api_base, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=500, + sleep_interval=5, + extra_body=extra_body, + ) + + compact = AsyncCompactAgent( + llm=model, + max_context_tokens=128_000, + threshold_ratio=0.85, + ) + + # Simulate real formatted_messages from policy's aggregator (list[dict]) + formatted_messages = [ + {'role': 'system', 'content': 'You are InternClaw, a helpful AI assistant.'}, + {'role': 'user', 'content': "I'm working on refactoring the memory system in lagent."}, + {'role': 'assistant', 'content': 'I can help with that. What aspects of the memory system need refactoring?'}, + {'role': 'user', 'content': 'The main issue is that BaseMemoryStore couples compact and long-term memory together.'}, + {'role': 'assistant', 'content': 'I see. We should split them into independent modules. Let me analyze the current code.', 'tool_calls': [ + {'function': {'name': 'read_file', 'arguments': {'path': 'lagent/memory/memory.py'}}} + ]}, + {'role': 'tool', 'content': 'class BaseMemoryStore(ABC):\n def get_info(self)...\n def should_compact(self)...\n def compact(self)...', 'name': 'read_file'}, + {'role': 'assistant', 'content': 'I can see the coupling. BaseMemoryStore has both compact and LTM methods in one interface.'}, + {'role': 'user', 'content': 'Yes, also the CompactAction forks the policy agent which does not match lagent architecture.'}, + {'role': 'assistant', 'content': 'Right. Instead of forking, we can make CompactAgent a standard AsyncAgent.'}, + {'role': 'user', 'content': 'Good. The providers for different frameworks (OpenClaw vs Claude Code) should be independent.'}, + {'role': 'assistant', 'content': 'Agreed. Each provider just needs get_info() for reading, writes go through separate Actions.'}, + {'role': 'user', 'content': 'Make sure Memory is a pure list with no state tracking.'}, + {'role': 'assistant', 'content': 'Done. Memory now only has add/get/delete/save/load/reset.'}, + ] + + input_msg = AgentMessage( + sender='orchestrator', + content=formatted_messages, + ) + + print("\n Calling real model for compact...") + result = await compact(input_msg) + + content = result.content + if isinstance(content, dict): + content = content.get('content', '') + + print(f" Summary length: {len(content)} chars") + print(f" First 200 chars: {content[:200]}...") + + # Basic assertions + assert content, "Summary should not be empty" + assert len(content) > 100, "Summary should be substantial" + content_lower = content.lower() + assert any(word in content_lower for word in ['memory', 'refactor', 'compact']), \ + "Summary should mention key topics" + + print(" Real model compact: OK") + + +# ── Runner ──────────────────────────────────────────────────────── + +async def main(): + run_real = True + + mock_tests = [ + test_threshold_tokens, + test_below_threshold, + test_above_threshold, + test_no_extra_info, + test_circuit_breaker, + test_circuit_breaker_reset, + test_default_template, + test_custom_template, + test_estimate_token_count_basic, + test_estimate_token_count_with_tools, + test_estimate_token_count_empty, + test_forward_with_mock_llm, + test_forward_with_list_dict_content, + test_compact_aggregator_empty_content, + ] + + for test in mock_tests: + await test() + print(f" {test.__name__}: OK") + print(" Mock tests: ALL PASSED") + + if run_real: + print("\n --- Real model tests ---") + await test_real_model_compact() + print(" Real model tests: ALL PASSED") + else: + print("\n (skip real model tests, use --real to enable)") + + print("\n ALL PASSED") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_agents/test_compact_integration.py b/tests/test_agents/test_compact_integration.py new file mode 100644 index 00000000..2b9c01b3 --- /dev/null +++ b/tests/test_agents/test_compact_integration.py @@ -0,0 +1,363 @@ +"""Integration tests for CompactAgent + Provider + InternClawAgent. + +Tests the compact data flow end-to-end: + 1. Provider + CompactAgent interaction + 2. Full InternClawAgent loop with compact triggering + 3. ContextBuilder handling compact_boundary + summary + +Usage: + python tests/test_agents/test_compact_integration.py # mock only + python tests/test_agents/test_compact_integration.py --real # real model +""" + +import asyncio +import sys +import tempfile +from pathlib import Path + +from lagent.agents.compact_agent import AsyncCompactAgent, estimate_token_count +from lagent.agents.internclaw_agent import ( + AsyncEnvAgent, AsyncPolicyAgent, InternClawAgent, +) +from lagent.agents.aggregator.context import InternClawContextBuilder +from lagent.memory import Memory, OpenClawMemoryProvider +from lagent.actions.save_memory import SaveMemoryAction +from lagent.schema import AgentMessage + + +# ── Helpers ─────────────────────────────────────────────────────── + +class MockLLM: + """Mock LLM that returns predictable responses.""" + + def __init__(self, responses=None): + self._responses = responses or [] + self._call_count = 0 + + async def chat(self, messages, **kwargs): + self._call_count += 1 + if self._responses: + resp = self._responses.pop(0) + if callable(resp): + return resp(messages, kwargs) + return resp + # Default: echo back with tool call to keep loop going + return {"content": f"Response #{self._call_count}"} + + +# ── Test 1: Provider + EnvAgent get_info injection ──────────────── + +async def test_provider_injects_into_env_info(): + """Verify provider.get_info() content appears in env_info['memory'].""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + mem_dir = workspace / "memory" + mem_dir.mkdir() + (mem_dir / "MEMORY.md").write_text("# Facts\n- User prefers Python\n- Project uses lagent") + + provider = OpenClawMemoryProvider(workspace) + env = AsyncEnvAgent(actions=[], long_term_memory=provider) + + info = await env.get_env_info() + assert info['memory']['available'] is True + assert 'Python' in info['memory']['long_term'] + assert 'lagent' in info['memory']['long_term'] + + +# ── Test 2: CompactAgent with formatted_messages ────────────────── + +async def test_compact_formats_and_summarizes(): + """Verify CompactAgent correctly formats list[dict] input and produces summary.""" + received = {} + + class CaptureLLM: + async def chat(self, messages, **kwargs): + received['messages'] = messages + return {"content": "## Summary\nUser is refactoring lagent memory system."} + + compact = AsyncCompactAgent(llm=CaptureLLM(), max_context_tokens=100, threshold_ratio=0.5) + + formatted_messages = [ + {'role': 'system', 'content': 'You are helpful.'}, + {'role': 'user', 'content': 'Refactor memory'}, + {'role': 'assistant', 'content': 'Sure, analyzing...'}, + ] + + input_msg = AgentMessage( + sender='orchestrator', + content=formatted_messages, + extra_info={'context_tokens': 80}, + ) + + # Should trigger compact (80 > 100 * 0.5 = 50) + assert compact.should_compact(input_msg) is True + + result = await compact(input_msg) + + # Verify LLM received properly formatted messages + assert received['messages'][0]['role'] == 'system' + assert 'CRITICAL' in received['messages'][0]['content'] # COMPACT_PROMPT + assert 'USER: Refactor memory' in received['messages'][1]['content'] + + content = result.content + if isinstance(content, dict): + content = content.get('content', '') + assert 'Summary' in content + + +# ── Test 3: ContextBuilder handles compact_boundary ─────────────── + +async def test_context_builder_with_compact_boundary(): + """Verify ContextBuilder skips messages before boundary and prepends summary.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + builder = InternClawContextBuilder(workspace) + + mem = Memory() + # Simulate 6 messages, compact happened at index 4 + summary = "## Summary\nUser discussed memory refactoring." + env_info_with_compact = { + 'conversation_summary': summary, + 'compact_boundary': 4, + } + + mem.add(AgentMessage(sender='user', content='msg0', role='user')) + mem.add(AgentMessage(sender='agent', content='msg1', role='assistant')) + mem.add(AgentMessage(sender='user', content='msg2', role='user')) + mem.add(AgentMessage(sender='agent', content='msg3', role='assistant')) + # This message carries the compact info + mem.add(AgentMessage( + sender='user', content='msg4 (after compact)', role='user', + env_info=env_info_with_compact, + )) + mem.add(AgentMessage(sender='agent', content='msg5', role='assistant')) + + messages, tools = builder.aggregate(mem, name='agent') + + # Should have: system + summary + msg4 + msg5 + # msg0-msg3 should be skipped (before boundary index 4) + contents = [m.get('content', '') for m in messages] + + # First is system prompt + assert messages[0]['role'] == 'system' + + # Second should be the injected summary + assert 'Conversation Summary' in contents[1] + assert 'memory refactoring' in contents[1] + + # Messages after boundary should be present + assert any('msg4' in c for c in contents), f"msg4 not found in {contents}" + assert any('msg5' in c for c in contents), f"msg5 not found in {contents}" + + # Messages before boundary should NOT be present + assert not any('msg0' in c for c in contents), "msg0 should be skipped" + assert not any('msg1' in c for c in contents), "msg1 should be skipped" + + +async def test_context_builder_without_compact(): + """Verify ContextBuilder works normally without compact info.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + builder = InternClawContextBuilder(workspace) + + mem = Memory() + mem.add(AgentMessage(sender='user', content='hello', role='user')) + mem.add(AgentMessage(sender='agent', content='hi', role='assistant')) + + messages, tools = builder.aggregate(mem, name='agent') + + contents = [m.get('content', '') for m in messages] + assert any('hello' in c for c in contents) + assert any('hi' in c for c in contents) + + +# ── Test 4: Provider + SaveMemoryAction + ContextBuilder ────────── + +async def test_provider_action_contextbuilder_flow(): + """Full flow: write via action → read via provider → inject into ContextBuilder.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + + # Write memory via action + action = SaveMemoryAction(workspace) + await action.save_memory( + memory_update="# Facts\n- User likes Rust", + history_entry="[2026-04-09 10:00] User stated preference for Rust", + ) + + # Read via provider + provider = OpenClawMemoryProvider(workspace) + info = await provider.get_info() + assert 'Rust' in info['long_term'] + + # Build context with provider info + builder = InternClawContextBuilder(workspace) + mem = Memory() + mem.add(AgentMessage( + sender='user', content='What do I like?', role='user', + env_info={'memory': info}, + )) + + messages, tools = builder.aggregate(mem, name='agent') + system_prompt = messages[0]['content'] + assert 'Rust' in system_prompt # Memory injected into system prompt + + +# ── Test 5: Full InternClawAgent loop with compact ──────────────── + +async def test_internclaw_compact_triggers(): + """Verify compact triggers during InternClawAgent loop when tokens exceed threshold.""" + compact_called = {'count': 0} + + class PolicyLLM: + """Simulates policy: returns tool_calls for first N turns, then stops.""" + def __init__(self): + self._turn = 0 + + async def chat(self, messages, **kwargs): + self._turn += 1 + if self._turn <= 3: + return { + 'content': f'Let me check turn {self._turn}', + 'tool_calls': [{'id': f'call_{self._turn}', 'function': {'name': 'test_tool', 'arguments': '{}'}}], + } + return {'content': 'All done, no more tools needed.'} + + class CompactLLM: + async def chat(self, messages, **kwargs): + compact_called['count'] += 1 + return {'content': '## Summary\nCompacted conversation.'} + + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + + # Very low threshold so compact triggers easily + compact = AsyncCompactAgent( + llm=CompactLLM(), + max_context_tokens=50, + threshold_ratio=0.1, # triggers at 5 tokens + ) + + from lagent.agents.aggregator import DefaultAggregator + policy = AsyncPolicyAgent(llm=PolicyLLM(), aggregator=DefaultAggregator()) + + # Minimal env that just passes through + env = AsyncEnvAgent(actions=[]) + + agent = InternClawAgent( + policy_agent=policy, + env_agent=env, + compact_agent=compact, + max_turn=5, + ) + + result = await agent("Start a conversation about memory refactoring") + + # Compact should have been called at least once + assert compact_called['count'] > 0, \ + f"Compact should have triggered, but was called {compact_called['count']} times" + + +# ── Real model test ─────────────────────────────────────────────── + +async def test_real_compact_with_provider(): + """Integration: real LLM + provider + compact.""" + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + + model_name = "gpt-5.4-mini" + api_base = "http://35.220.164.252:3888/v1" + api_key = "" + proxy = "http://100.100.72.89:8899" + + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=api_base, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=500, + sleep_interval=5, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + + # Setup provider with some existing memory + mem_dir = workspace / "memory" + mem_dir.mkdir() + (mem_dir / "MEMORY.md").write_text( + "# Project Context\n- Working on lagent memory refactoring\n- User prefers minimal abstractions" + ) + + provider = OpenClawMemoryProvider(workspace) + compact = AsyncCompactAgent( + llm=model, + max_context_tokens=128_000, + ) + + # Build formatted_messages that include provider content + env_info = await provider.get_info() + + formatted_messages = [ + {'role': 'system', 'content': f"You are helpful.\n\nMemory:\n{env_info.get('long_term', '')}"}, + {'role': 'user', 'content': 'Help me design the memory system'}, + {'role': 'assistant', 'content': 'Based on the project context, I see you prefer minimal abstractions. Let me propose a design.'}, + {'role': 'user', 'content': 'Yes, Memory should be a pure list, no LTM base class'}, + {'role': 'assistant', 'content': 'Agreed. Provider=read, Action=write, both independent.'}, + {'role': 'user', 'content': 'What about compact?'}, + {'role': 'assistant', 'content': 'CompactAgent is a standard AsyncAgent with its own aggregator.'}, + ] + + input_msg = AgentMessage(sender='orchestrator', content=formatted_messages) + + print("\n Calling real model for compact + provider integration...") + result = await compact(input_msg) + + content = result.content + if isinstance(content, dict): + content = content.get('content', '') + + print(f" Summary length: {len(content)} chars") + print(f" First 300 chars:\n {content[:300]}") + + assert content and len(content) > 50, "Summary should be substantial" + content_lower = content.lower() + assert any(w in content_lower for w in ['memory', 'compact', 'provider']), \ + "Summary should mention key topics from the conversation" + + # Verify the provider's memory content influenced the summary + assert any(w in content_lower for w in ['minimal', 'abstraction', 'lagent']), \ + "Summary should reflect project context from provider" + + print(" Real compact + provider: OK") + + +# ── Runner ──────────────────────────────────────────────────────── + +async def main(): + run_real = True + + mock_tests = [ + test_provider_injects_into_env_info, + test_compact_formats_and_summarizes, + test_context_builder_with_compact_boundary, + test_context_builder_without_compact, + test_provider_action_contextbuilder_flow, + test_internclaw_compact_triggers, + ] + + for test in mock_tests: + await test() + print(f" {test.__name__}: OK") + print(" Mock tests: ALL PASSED") + + if run_real: + print("\n --- Real model integration tests ---") + await test_real_compact_with_provider() + print(" Real integration tests: ALL PASSED") + else: + print("\n (skip real model tests, use --real to enable)") + + print("\n ALL PASSED") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_agents/test_framework_agents.py b/tests/test_agents/test_framework_agents.py new file mode 100644 index 00000000..cb680e72 --- /dev/null +++ b/tests/test_agents/test_framework_agents.py @@ -0,0 +1,90 @@ +"""Tests for InternClawAgent initialization and provider integration.""" + +import asyncio +import tempfile +from pathlib import Path + +from lagent.agents.internclaw_agent import AsyncEnvAgent, AsyncPolicyAgent, InternClawAgent +from lagent.agents.compact_agent import AsyncCompactAgent +from lagent.memory.openclaw_provider import OpenClawMemoryProvider +from lagent.memory.claude_code_provider import ClaudeCodeMemoryProvider + + +class MockLLM: + async def chat(self, messages, **kwargs): + return {"content": "mock response"} + + +async def test_internclaw_with_compact(): + llm = MockLLM() + agent = InternClawAgent( + policy_agent=AsyncPolicyAgent(llm=llm), + env_agent=AsyncEnvAgent(actions=[]), + compact_agent=AsyncCompactAgent(llm=llm), + consolidate_agent=None, + max_turn=10, + ) + assert agent.compact_agent is not None + assert agent.consolidate_agent is None + + +async def test_internclaw_without_compact(): + """Claude Code style — no compact, no consolidate.""" + llm = MockLLM() + agent = InternClawAgent( + policy_agent=AsyncPolicyAgent(llm=llm), + env_agent=AsyncEnvAgent(actions=[]), + max_turn=10, + ) + assert agent.compact_agent is None + assert agent.consolidate_agent is None + + +async def test_env_with_openclaw_provider(): + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + mem_dir = workspace / "memory" + mem_dir.mkdir() + (mem_dir / "MEMORY.md").write_text("# User likes Python") + + provider = OpenClawMemoryProvider(workspace) + env = AsyncEnvAgent(actions=[], long_term_memory=provider) + info = await env.get_env_info() + assert 'Python' in info['memory']['long_term'] + + +async def test_env_with_claude_code_provider(): + with tempfile.TemporaryDirectory() as tmpdir: + d = Path(tmpdir) + (d / "MEMORY.md").write_text("- [Pref](pref.md) — likes Rust\n") + (d / "pref.md").write_text("User prefers Rust.\n") + + provider = ClaudeCodeMemoryProvider(d) + env = AsyncEnvAgent(actions=[], long_term_memory=provider) + info = await env.get_env_info() + assert info['memory']['available'] is True + assert 'Rust' in info['memory']['memories'][0] + + +async def test_env_without_provider(): + env = AsyncEnvAgent(actions=[]) + info = await env.get_env_info() + assert info['memory'] == '' + + +async def main(): + tests = [ + test_internclaw_with_compact, + test_internclaw_without_compact, + test_env_with_openclaw_provider, + test_env_with_claude_code_provider, + test_env_without_provider, + ] + for test in tests: + await test() + print(f" {test.__name__}: OK") + print(" ALL PASSED") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_agents/test_internclaw_e2e.py b/tests/test_agents/test_internclaw_e2e.py new file mode 100644 index 00000000..6938dc33 --- /dev/null +++ b/tests/test_agents/test_internclaw_e2e.py @@ -0,0 +1,441 @@ +"""End-to-end tests for InternClawAgent with all memory modules. + +Tests: + 1. Consolidate agent (SaveMemoryAction writes to MEMORY.md + HISTORY.md) + 2. Full pipeline: policy + env + compact + consolidate + provider + contextbuilder + 3. Real model full pipeline + +Usage: + python tests/test_agents/test_internclaw_e2e.py # mock only + python tests/test_agents/test_internclaw_e2e.py --real # real model +""" + +import asyncio +import json +import sys +import tempfile +from pathlib import Path + +from lagent.agents.compact_agent import AsyncCompactAgent, estimate_token_count +from lagent.agents.internclaw_agent import ( + AsyncEnvAgent, AsyncPolicyAgent, InternClawAgent, +) +from lagent.agents.aggregator.context import InternClawContextBuilder +from lagent.agents.aggregator.default_aggregator import DefaultAggregator +from lagent.memory import Memory, OpenClawMemoryProvider +from lagent.actions.save_memory import SaveMemoryAction, AsyncSaveMemoryAction +from lagent.schema import AgentMessage + + +# ── Test 1: Consolidate agent writes to LTM ────────────────────── + +async def test_consolidate_agent_writes_memory(): + """Consolidate agent = InternClawAgent with SaveMemoryAction. + + Verify it calls save_memory tool and writes to MEMORY.md + HISTORY.md. + """ + + class ConsolidateLLM: + """Mock LLM that calls save_memory tool.""" + async def chat(self, messages, **kwargs): + return { + 'content': '', + 'tool_calls': [{ + 'id': 'call_1', + 'function': { + 'name': 'AsyncSaveMemoryAction', + 'arguments': json.dumps({ + 'history_entry': '[2026-04-09 12:00] Discussed memory refactoring. Decided to split compact and LTM.', + 'memory_update': '# Facts\n- User prefers minimal abstractions\n- Memory is a pure list\n- CompactAgent is a standard AsyncAgent', + }), + }, + }], + } + + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + save_action = AsyncSaveMemoryAction(workspace) + + consolidate_policy = AsyncPolicyAgent(llm=ConsolidateLLM()) + consolidate_env = AsyncEnvAgent(actions=[save_action]) + + consolidate_agent = InternClawAgent( + policy_agent=consolidate_policy, + env_agent=consolidate_env, + max_turn=1, + finish_condition=None, + ) + + # Run consolidation + input_msg = AgentMessage( + sender='orchestrator', + content='Consolidate the following conversation...', + ) + await consolidate_agent(input_msg) + + # Verify MEMORY.md was written + memory_file = workspace / "memory" / "MEMORY.md" + assert memory_file.exists(), "MEMORY.md should exist" + content = memory_file.read_text() + assert 'minimal abstractions' in content + assert 'pure list' in content + + # Verify HISTORY.md was appended + history_file = workspace / "memory" / "HISTORY.md" + assert history_file.exists(), "HISTORY.md should exist" + history = history_file.read_text() + assert '[2026-04-09 12:00]' in history + assert 'memory refactoring' in history + + +# ── Test 2: Consolidate + Provider round-trip ───────────────────── + +async def test_consolidate_then_provider_reads(): + """Consolidate writes → provider reads back the same content.""" + + class ConsolidateLLM: + async def chat(self, messages, **kwargs): + return { + 'content': '', + 'tool_calls': [{ + 'id': 'call_1', + 'function': { + 'name': 'AsyncSaveMemoryAction', + 'arguments': json.dumps({ + 'history_entry': '[2026-04-09] Round-trip test', + 'memory_update': '# Memory\n- Round-trip test passed', + }), + }, + }], + } + + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + + # Consolidate writes + consolidate = InternClawAgent( + policy_agent=AsyncPolicyAgent(llm=ConsolidateLLM()), + env_agent=AsyncEnvAgent(actions=[AsyncSaveMemoryAction(workspace)]), + max_turn=1, + finish_condition=None, + ) + await consolidate(AgentMessage(sender='test', content='consolidate')) + + # Provider reads + provider = OpenClawMemoryProvider(workspace) + info = await provider.get_info() + assert info['available'] is True + assert 'Round-trip test passed' in info['long_term'] + + +# ── Test 3: Full pipeline mock — compact + consolidate ──────────── + +async def test_full_pipeline_compact_and_consolidate(): + """Full InternClawAgent pipeline: policy loops → compact triggers → + consolidate writes LTM → compact compresses context. + """ + consolidate_called = {'count': 0} + compact_called = {'count': 0} + + class PolicyLLM: + def __init__(self): + self._turn = 0 + + async def chat(self, messages, **kwargs): + self._turn += 1 + if self._turn <= 2: + return { + 'content': f'Working on turn {self._turn}...', + 'tool_calls': [{ + 'id': f'call_{self._turn}', + 'function': {'name': 'test_tool', 'arguments': '{}'}, + }], + } + return {'content': 'All done.'} + + class ConsolidateLLM: + async def chat(self, messages, **kwargs): + consolidate_called['count'] += 1 + return { + 'content': '', + 'tool_calls': [{ + 'id': 'cons_1', + 'function': { + 'name': 'AsyncSaveMemoryAction', + 'arguments': json.dumps({ + 'history_entry': '[2026-04-09] Consolidated', + 'memory_update': '# Consolidated memory', + }), + }, + }], + } + + class CompactLLM: + async def chat(self, messages, **kwargs): + compact_called['count'] += 1 + return {'content': '## Summary\nConversation was about testing.'} + + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + provider = OpenClawMemoryProvider(workspace) + + # Consolidate agent + consolidate = InternClawAgent( + policy_agent=AsyncPolicyAgent(llm=ConsolidateLLM()), + env_agent=AsyncEnvAgent(actions=[AsyncSaveMemoryAction(workspace)]), + max_turn=1, + finish_condition=None, + ) + + # Compact agent — very low threshold to trigger + compact = AsyncCompactAgent( + llm=CompactLLM(), + max_context_tokens=50, + threshold_ratio=0.1, + ) + + # Main agent + agent = InternClawAgent( + policy_agent=AsyncPolicyAgent( + llm=PolicyLLM(), + aggregator=DefaultAggregator(), + ), + env_agent=AsyncEnvAgent( + actions=[], + long_term_memory=provider, + ), + compact_agent=compact, + consolidate_agent=consolidate, + max_turn=5, + ) + + result = await agent("Test the full pipeline") + + # Both should have been called + assert compact_called['count'] > 0, "Compact should have triggered" + assert consolidate_called['count'] > 0, "Consolidate should have triggered" + + # MEMORY.md should have been written by consolidation + memory_file = workspace / "memory" / "MEMORY.md" + assert memory_file.exists(), "Consolidation should have written MEMORY.md" + + +# ── Test 4: ContextBuilder with compact + provider together ────── + +async def test_context_builder_full_assembly(): + """ContextBuilder assembles: system prompt (with LTM) + compact summary + recent messages.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + mem_dir = workspace / "memory" + mem_dir.mkdir() + (mem_dir / "MEMORY.md").write_text("# Facts\n- User likes Python") + + builder = InternClawContextBuilder(workspace) + provider = OpenClawMemoryProvider(workspace) + env_info = await provider.get_info() + + # Build a session with compact boundary + mem = Memory() + # Old messages (before compact) + mem.add(AgentMessage(sender='user', content='old msg 0', role='user')) + mem.add(AgentMessage(sender='agent', content='old msg 1', role='assistant')) + mem.add(AgentMessage(sender='user', content='old msg 2', role='user')) + mem.add(AgentMessage(sender='agent', content='old msg 3', role='assistant')) + # Message carrying compact info + provider memory + env_info_with_compact = { + **env_info, + 'conversation_summary': '## Summary\nDiscussed Python preferences.', + 'compact_boundary': 4, + } + mem.add(AgentMessage( + sender='user', content='new msg after compact', role='user', + env_info=env_info_with_compact, + )) + mem.add(AgentMessage(sender='agent', content='response after compact', role='assistant')) + + messages, tools = builder.aggregate(mem, name='agent') + + # System prompt should contain LTM content + system_prompt = messages[0]['content'] + assert 'Python' in system_prompt, "System prompt should include LTM facts" + + # Summary should be injected + all_content = ' '.join(m.get('content', '') or '' for m in messages) + assert 'Summary' in all_content, "Compact summary should be present" + assert 'Python preferences' in all_content + + # New messages should be present + assert 'new msg after compact' in all_content + assert 'response after compact' in all_content + + # Old messages should NOT be present + assert 'old msg 0' not in all_content, "Old messages should be skipped" + assert 'old msg 3' not in all_content + + +# ── Test 5: Multiple compact rounds ────────────────────────────── + +async def test_multiple_compact_rounds(): + """Verify compact can trigger multiple times in a long session.""" + compact_count = {'n': 0} + + class PolicyLLM: + def __init__(self): + self._turn = 0 + + async def chat(self, messages, **kwargs): + self._turn += 1 + if self._turn <= 6: + return { + 'content': f'Turn {self._turn} ' + 'x' * 100, # padding for tokens + 'tool_calls': [{ + 'id': f'c_{self._turn}', + 'function': {'name': 'noop', 'arguments': '{}'}, + }], + } + return {'content': 'Done.'} + + class CompactLLM: + async def chat(self, messages, **kwargs): + compact_count['n'] += 1 + return {'content': f'## Summary round {compact_count["n"]}'} + + compact = AsyncCompactAgent( + llm=CompactLLM(), + max_context_tokens=30, + threshold_ratio=0.1, + ) + + agent = InternClawAgent( + policy_agent=AsyncPolicyAgent( + llm=PolicyLLM(), + aggregator=DefaultAggregator(), + ), + env_agent=AsyncEnvAgent(actions=[]), + compact_agent=compact, + max_turn=8, + ) + + await agent("Start a long conversation") + assert compact_count['n'] >= 2, f"Expected multiple compacts, got {compact_count['n']}" + + +# ── Real model test ─────────────────────────────────────────────── + +async def test_real_full_pipeline(): + """Real model: consolidate + compact + provider full round-trip.""" + from lagent.llms.model import AsyncAPIClient, ModelConfig, SampleParameters + + model_name = "gpt-5.4-mini" + api_base = "http://35.220.164.252:3888/v1" + api_key = "" + proxy = "http://100.100.72.89:8899" + + model = AsyncAPIClient( + model=ModelConfig(model=model_name, base_url=api_base, api_key=api_key, proxy=proxy), + sample_params=SampleParameters(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=500, + sleep_interval=5, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + + # Seed initial memory + mem_dir = workspace / "memory" + mem_dir.mkdir() + (mem_dir / "MEMORY.md").write_text("# Initial\n- Project: lagent memory refactoring") + + provider = OpenClawMemoryProvider(workspace) + + # Consolidate agent with real model + consolidate = InternClawAgent( + policy_agent=AsyncPolicyAgent(llm=model), + env_agent=AsyncEnvAgent(actions=[AsyncSaveMemoryAction(workspace)]), + max_turn=1, + finish_condition=None, + ) + + # Compact agent with real model + compact = AsyncCompactAgent(llm=model, max_context_tokens=128_000) + + # Main agent — just test consolidate independently first + print("\n Testing consolidation with real model...") + consolidate_input = AgentMessage( + sender='test', + content=( + 'Consolidate this conversation:\n' + 'USER: I want to refactor the memory system\n' + 'ASSISTANT: Sure, we should split compact from LTM\n' + 'USER: Memory should be a pure list\n' + 'ASSISTANT: Agreed, no state tracking in Memory class\n' + '\nCall the save_memory tool with your consolidation.' + ), + ) + await consolidate(consolidate_input) + + # Check if consolidation wrote to files + memory_content = (workspace / "memory" / "MEMORY.md").read_text() + print(f" MEMORY.md after consolidation ({len(memory_content)} chars):") + print(f" {memory_content[:200]}...") + + # Provider should reflect the update + info = await provider.get_info() + print(f" Provider get_info: available={info.get('available')}") + + # Now test compact with the same model + print("\n Testing compact with real model...") + formatted_messages = [ + {'role': 'system', 'content': f"You are helpful.\n\nMemory:\n{info.get('long_term', '')}"}, + {'role': 'user', 'content': 'Refactor memory system'}, + {'role': 'assistant', 'content': 'Split compact from LTM'}, + {'role': 'user', 'content': 'Memory = pure list'}, + {'role': 'assistant', 'content': 'Done, no state tracking'}, + ] + + compact_result = await compact(AgentMessage( + sender='test', content=formatted_messages, + )) + + content = compact_result.content + if isinstance(content, dict): + content = content.get('content', '') + print(f" Compact summary ({len(content)} chars):") + print(f" {content[:200]}...") + + assert content and len(content) > 20, "Compact summary should be substantial" + + print("\n Real full pipeline: OK") + + +# ── Runner ──────────────────────────────────────────────────────── + +async def main(): + run_real = True + + mock_tests = [ + test_consolidate_agent_writes_memory, + test_consolidate_then_provider_reads, + test_full_pipeline_compact_and_consolidate, + test_context_builder_full_assembly, + test_multiple_compact_rounds, + ] + + for test in mock_tests: + await test() + print(f" {test.__name__}: OK") + print(" Mock tests: ALL PASSED") + + if run_real: + print("\n --- Real model E2E tests ---") + await test_real_full_pipeline() + print(" Real E2E tests: ALL PASSED") + else: + print("\n (skip real model tests, use --real to enable)") + + print("\n ALL PASSED") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_agents/test_rewoo.py b/tests/test_agents/test_rewoo.py deleted file mode 100644 index 52fa1ed7..00000000 --- a/tests/test_agents/test_rewoo.py +++ /dev/null @@ -1,87 +0,0 @@ -from unittest import TestCase, mock - -from lagent.actions import ActionExecutor -from lagent.actions.llm_qa import LLMQA -from lagent.actions.serper_search import SerperSearch -from lagent.agents.rewoo import ReWOO, ReWOOProtocol -from lagent.schema import ActionReturn, ActionStatusCode - - -class TestReWOO(TestCase): - - @mock.patch.object(SerperSearch, 'run') - @mock.patch.object(LLMQA, 'run') - @mock.patch.object(ReWOOProtocol, 'parse_worker') - def test_normal_chat(self, mock_parse_worker_func, mock_qa_func, - mock_search_func): - mock_model = mock.Mock() - mock_model.generate_from_template.return_value = 'LLM response' - - mock_parse_worker_func.return_value = (['Thought1', 'Thought2' - ], ['LLMQA', 'SerperSearch'], - ['abc', 'abc']) - - search_return = ActionReturn(args=None) - search_return.state = ActionStatusCode.SUCCESS - search_return.result = dict(text='search_return') - mock_search_func.return_value = search_return - - qa_return = ActionReturn(args=None) - qa_return.state = ActionStatusCode.SUCCESS - qa_return.result = dict(text='qa_return') - mock_qa_func.return_value = qa_return - - chatbot = ReWOO( - llm=mock_model, - action_executor=ActionExecutor(actions=[ - LLMQA(mock_model), - SerperSearch(api_key=''), - ])) - agent_return = chatbot.chat('abc') - self.assertEqual(agent_return.response, 'LLM response') - - def test_parse_worker(self): - prompt = ReWOOProtocol() - message = """ - Plan: a. - #E1 = tool1["a"] - #E2 = tool2["b"] - """ - try: - thoughts, actions, actions_input = prompt.parse_worker(message) - except Exception as e: - self.assertEqual( - 'Each Plan should only correspond to only ONE action', str(e)) - else: - self.assertFalse( - True, 'it should raise exception when the format is incorrect') - - message = """ - Plan: a. - #E1 = tool1("a") - Plan: b. - #E2 = tool2["b"] - """ - try: - thoughts, actions, actions_input = prompt.parse_worker(message) - except Exception as e: - self.assertIsInstance(e, BaseException) - else: - self.assertFalse( - True, 'it should raise exception when the format is incorrect') - - message = """ - Plan: a. - #E1 = tool1["a"] - Plan: b. - #E2 = tool2["b"] - """ - try: - thoughts, actions, actions_input = prompt.parse_worker(message) - except Exception: - self.assertFalse( - True, - 'it should not raise exception when the format is correct') - self.assertEqual(thoughts, ['a.', 'b.']) - self.assertEqual(actions, ['tool1', 'tool2']) - self.assertEqual(actions_input, ['"a"', '"b"']) diff --git a/tests/test_agents/test_skills.py b/tests/test_agents/test_skills.py new file mode 100644 index 00000000..ec062f06 --- /dev/null +++ b/tests/test_agents/test_skills.py @@ -0,0 +1,76 @@ +import json +from pathlib import Path +from unittest import TestCase, mock + +from lagent.skills.skills import FilesystemSkillsBackend, SkillsLoader, SandboxSkillsBackend +from lagent.schema import ActionReturn, ActionStatusCode + + +class TestFilesystemSkillsBackend(TestCase): + + def test_list_and_read_skill(self): + with mock.patch('tempfile.TemporaryDirectory'): + pass + + def test_workspace_skill_overrides_builtin(self): + with mock.patch('tempfile.TemporaryDirectory'): + pass + + +class TestSkillsLoader(TestCase): + + def test_list_skills_filters_unavailable(self): + with mock.patch('tempfile.TemporaryDirectory'): + pass + + def test_build_skills_summary_marks_missing_requirements(self): + with mock.patch('tempfile.TemporaryDirectory'): + pass + + def test_get_always_skills(self): + with mock.patch('tempfile.TemporaryDirectory'): + pass + + +class DummySandboxAction: + + def __init__(self, responses): + self.responses = responses + self.calls = [] + + async def run(self, session_id, command): + self.calls.append((session_id, command)) + payload = self.responses.get(command) + action_return = ActionReturn(args={'session_id': session_id, 'command': command}) + if payload is None: + action_return.state = ActionStatusCode.API_ERROR + action_return.result = [] + action_return.errmsg = 'not found' + return action_return + + action_return.state = ActionStatusCode.SUCCESS + action_return.result = [dict(type='text', content=payload)] + return action_return + + +class TestSandboxSkillsBackend(TestCase): + + def test_list_skill_entries_with_builtin_fallback(self): + with mock.patch('tempfile.TemporaryDirectory'): + pass + + def test_read_skill_prefers_backend_then_builtin(self): + with mock.patch('tempfile.TemporaryDirectory'): + pass + + def test_default_command_builder(self): + backend = SandboxSkillsBackend(action=mock.Mock(), workspace_root='/workspace', session_id='s1') + self.assertIn("Path('/workspace/skills')", backend._default_command_builder('list', '')) + self.assertEqual(backend._default_command_builder('read', 'demo'), "cat '/workspace/skills/demo/SKILL.md'") + + def test_run_raises_inside_active_loop(self): + backend = SandboxSkillsBackend(action=mock.Mock()) + + with mock.patch('asyncio.get_running_loop', return_value=object()): + with self.assertRaisesRegex(RuntimeError, 'active event loop'): + backend._run('echo 1') diff --git a/tests/test_e2e_subagent.py b/tests/test_e2e_subagent.py new file mode 100644 index 00000000..9b9883ed --- /dev/null +++ b/tests/test_e2e_subagent.py @@ -0,0 +1,227 @@ +"""End-to-end integration test: AgentLoader → AgentService → AsyncAgentAction. + +No mocks on the three core modules. Agent projects live under +``tests/data/agents/`` and are loaded via real AgentLoader. + +Two agent types: + - ``simple-agent``: EchoAgent, no LLM, no network — always runs. + - ``e2e-agent``: InternClawAgent with real LLM — needs network. + +Run: + pytest tests/test_e2e_subagent.py -v -s +""" + +import asyncio +import sys +from pathlib import Path + +import pytest +import pytest_asyncio + +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from lagent.actions.subagent import AsyncAgentAction +from lagent.schema import ActionStatusCode +from lagent.services.agent import AgentService, AgentStatus +from lagent.services.agent_loader import AgentLoader + +DATA_DIR = Path(__file__).parent / "data" + + +# ── Fixtures ───────────────────────────────────────────────────────── + + +@pytest_asyncio.fixture +async def service(): + """Real AgentService backed by tests/data/agents/.""" + loader = AgentLoader(DATA_DIR) + svc = AgentService(agent_loader=loader) + await svc.load_specs() + yield svc + await svc.shutdown() + + +@pytest.fixture +def action(service): + """Real AsyncAgentAction wrapping the service.""" + return AsyncAgentAction(agent_service=service) + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _text(ret): + if ret.result and len(ret.result) > 0: + return ret.result[0].get("content", "") + return "" + + +# ── Discovery ──────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_loader_discovers_test_agents(): + loader = AgentLoader(DATA_DIR) + entries = await loader.list() + names = {e["name"] for e in entries} + assert "simple-agent" in names + assert "e2e-agent" in names + + +@pytest.mark.asyncio +async def test_service_loads_specs(service): + assert "simple-agent" in service.available_types + assert "e2e-agent" in service.available_types + + +# ── simple-agent (no LLM, no network) ─────────────────────────────── + + +@pytest.mark.asyncio +async def test_simple_spawn_sync(service, action): + ret = await action.spawn( + agent_type="simple-agent", + task="hello world", + label="simple-sync", + mode="sync", + ) + assert ret.state != ActionStatusCode.API_ERROR, ret.errmsg + assert "echo: hello world" in _text(ret) + + +@pytest.mark.asyncio +async def test_simple_spawn_async_and_query(service, action): + ret = await action.spawn( + agent_type="simple-agent", + task="background task", + label="simple-async", + mode="async", + ) + assert ret.state != ActionStatusCode.API_ERROR, ret.errmsg + + # Wait for completion + entries = service.list(agent_type="simple-agent") + for e in entries: + task = service._tasks.get(e.id) + if task: + await task + + # Query result + entry = [e for e in service.list() if e.label == "simple-async"][0] + query_ret = await action.query_agent(agent_id=entry.id) + assert "echo: background task" in _text(query_ret) + + +@pytest.mark.asyncio +async def test_simple_list_agents(service, action): + await action.spawn( + agent_type="simple-agent", task="t1", label="label-a", mode="sync", + ) + await action.spawn( + agent_type="simple-agent", task="t2", label="label-b", mode="sync", + ) + + ret = await action.list_agents() + text = _text(ret) + assert "label-a" in text + assert "label-b" in text + + +@pytest.mark.asyncio +async def test_simple_resume(service, action): + # First run + ret = await action.spawn( + agent_type="simple-agent", task="ping", label="resumable", mode="sync", + ) + assert "echo: ping" in _text(ret) + + # Find the entry + entry = [e for e in service.list() if e.label == "resumable"][0] + + # Resume + resume_ret = await action.resume_agent(agent_id=entry.id, message="pong") + assert resume_ret.state != ActionStatusCode.API_ERROR, resume_ret.errmsg + assert "echo: pong" in _text(resume_ret) + + +@pytest.mark.asyncio +async def test_simple_stop(service, action): + # Need a slow agent to test stop — simple-agent finishes instantly, + # so spawn it async and try to stop (may already be done) + ret = await action.spawn( + agent_type="simple-agent", task="fast", label="stoppable", mode="async", + ) + entry = [e for e in service.list() if e.label == "stoppable"][0] + + # Try stop — might already be done + stop_ret = await action.stop_agent(agent_id=entry.id) + # Either stopped or already finished — both are valid + assert stop_ret.state != ActionStatusCode.API_ERROR or "not running" in (stop_ret.errmsg or "") + + +# ── Error paths ────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_spawn_unknown_type(service, action): + ret = await action.spawn(agent_type="no-such-agent", task="fail") + assert ret.state == ActionStatusCode.API_ERROR + assert "Unknown agent type" in ret.errmsg + + +@pytest.mark.asyncio +async def test_query_nonexistent(service, action): + ret = await action.query_agent(agent_id="nonexistent") + assert ret.state == ActionStatusCode.API_ERROR + assert "not found" in ret.errmsg + + +@pytest.mark.asyncio +async def test_resume_nonexistent(service, action): + ret = await action.resume_agent(agent_id="nonexistent", message="hi") + assert ret.state == ActionStatusCode.API_ERROR + assert "not found" in ret.errmsg + + +@pytest.mark.asyncio +async def test_stop_nonexistent(service, action): + ret = await action.stop_agent(agent_id="nonexistent") + assert ret.state == ActionStatusCode.API_ERROR + + +# ── e2e-agent (real LLM) ──────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_e2e_spawn_sync_real_llm(service, action): + """Full e2e with real LLM call.""" + ret = await action.spawn( + agent_type="e2e-agent", + task="Reply with exactly: HELLO_E2E", + label="e2e-real-llm", + mode="sync", + ) + print(f"\n[e2e spawn] state={ret.state}, errmsg={ret.errmsg}") + if ret.result: + print(f"[e2e spawn] {_text(ret)[:200]}") + + assert ret.state != ActionStatusCode.API_ERROR, f"LLM call failed: {ret.errmsg}" + assert "HELLO_E2E" in _text(ret) + + +@pytest.mark.asyncio +async def test_e2e_resume_real_llm(service, action): + """Spawn then resume with real LLM.""" + ret = await action.spawn( + agent_type="e2e-agent", task="Say PING", label="e2e-resume", mode="sync", + ) + if ret.state == ActionStatusCode.API_ERROR: + pytest.skip(f"LLM not available: {ret.errmsg}") + + entry = [e for e in service.list() if e.label == "e2e-resume"][0] + + resume_ret = await action.resume_agent(agent_id=entry.id, message="Now say PONG") + print(f"\n[e2e resume] {_text(resume_ret)[:200]}") + assert resume_ret.state != ActionStatusCode.API_ERROR, resume_ret.errmsg diff --git a/tests/test_llms/test_anthropic.py b/tests/test_llms/test_anthropic.py new file mode 100644 index 00000000..025a7bde --- /dev/null +++ b/tests/test_llms/test_anthropic.py @@ -0,0 +1,47 @@ +import anthropic + +client = anthropic.Anthropic(api_key=" ", base_url='http://100.100.72.89:3888') +model_name = "claude-sonnet-4-20250514-thinking" + +# 第一轮请求:开启 Extended Thinking +print("=== 第一轮请求 (Claude) ===") +response1 = client.messages.create( + model=model_name, + max_tokens=4096, + thinking={ + "type": "enabled", + "budget_tokens": 1024 # 给 1024 个 token 的思考预算 + }, + messages=[{"role": "user", "content": "请用极其复杂的逻辑证明 1+1=2,必须深入思考。"}] +) + +# 打印解析结果 +for block in response1.content: + if block.type == "thinking": + print(f"🧠 [思考过程]: {block.thinking[:50]}...") + print(f"🔒[防篡改签名]: {block.signature[:30]}...\n") + elif block.type == "text": + print(f"💬 [最终回答]: {block.text[:50]}...\n") + +# 第二轮请求:状态复用(必须原封不动传回完整 Content) +print("=== 第二轮请求(复用思考上下文) ===") +messages_history =[ + {"role": "user", "content": "请用极其复杂的逻辑证明 1+1=2,必须深入思考。"}, + # 【核心点】直接传入上一次的 response1.content (包含 text 和 thinking 及其 signature) + {"role": "assistant", "content": response1.content}, + {"role": "user", "content": "基于你刚才的思考,再推导一下 1+2=3。"} +] + +try: + response2 = client.messages.create( + model=model_name, + max_tokens=4096, + thinking={"type": "enabled", "budget_tokens": 1024}, + messages=messages_history + ) + print("✅ 第二轮请求成功!Claude 成功验证了上一次的思考签名。") + + # 【破坏性试验】如果你尝试修改 response1.content 里 thinking 的任何一个字, + # 再次请求时,Anthropic API 会直接抛出 400 错误:Signature Validation Failed. +except Exception as e: + print(f"❌ 请求失败: {e}") \ No newline at end of file diff --git a/tests/test_llms/test_gemini.py b/tests/test_llms/test_gemini.py new file mode 100644 index 00000000..05874504 --- /dev/null +++ b/tests/test_llms/test_gemini.py @@ -0,0 +1,28 @@ +from google import genai +from google.genai import types +from google.genai.types import HttpOptions + +client = genai.Client(api_key=" ", base_url='http://100.100.72.89:3888') +model_name = "gemini-2.0-flash-thinking-exp-01-21" # 确保使用支持思考的模型 + +print("=== 第一轮请求 (Gemini) ===") +# 使用 Stateful 的 Chats API 可以最直观地体现上下文复用 +chat = client.chats.create(model=model_name) + +response1 = chat.send_message("解释一下量子纠缠,要先自己仔细思考一遍。") + +# 【核心验证点】:遍历 Parts,查看 Google 是如何把思考和内容分开的 +for part in response1.candidates[0].content.parts: + # 思考块在 SDK 层会有额外的标识 (取决于具体版本,有时是以纯文本加上特定标记返回,目前最新版有单独处理) + if getattr(part, 'thought', False): # 或者检查是否有 thought 属性 + print(f"🧠 [思考过程 (Thought/Summary)]: {part.text[:50]}...") + else: + print(f"💬 [最终回答]: {part.text[:50]}...\n") + +print("=== 第二轮请求(Chat对象自动回传上下文) ===") +# 当我们调用 send_message 时,SDK 底层抓取了 response1 的全套 Parts (包含思考和加密状态) 发回。 +response2 = chat.send_message("能不能用通俗的比喻把你刚才想的过程讲一遍?") + +print("✅ 第二轮请求成功!") +print(f"💬 [第二轮回答]: {response2.text[:50]}...") +print("💡 结论:Gemini 把复用机制封装在了 Chat Session 里,内部自动携带了第一轮的 thought parts,防止开发者手动拼接出错或篡改。") \ No newline at end of file diff --git a/tests/test_llms/test_openai.py b/tests/test_llms/test_openai.py new file mode 100644 index 00000000..eb2da3af --- /dev/null +++ b/tests/test_llms/test_openai.py @@ -0,0 +1,38 @@ +from openai import OpenAI + +client = OpenAI(api_key=" ", base_url='http://100.100.72.89:3888/v1') +model_name = "o3-mini" + +print("=== 第一轮请求 (OpenAI o3-mini) ===") +# 使用标准 API 进行测试 +messages =[{"role": "user", "content": "用复杂逻辑分析为什么天空是蓝色的。"}] + +response1 = client.chat.completions.create( + model=model_name, + reasoning_effort="high", # 控制思考深度 + messages=messages +) + +assistant_message = response1.choices[0].message +# 【核心验证点 1】:content 里直接就是最终答案,没有任何思考过程 +print(f"💬[最终回答]: {assistant_message.content[:50]}...") + +# 【核心验证点 2】:真实的思考 Token 数量被藏在 usage 里统计(你得付钱,但看不到内容) +reasoning_tokens = response1.usage.completion_tokens_details.reasoning_tokens +print(f"🧠[隐藏的思考Token数量]: {reasoning_tokens}\n") + +print("=== 第二轮请求(依靠会话历史维持逻辑) ===") +# 把它返回的干净 message (没有任何隐藏思考块) 追加回去 +messages.append(assistant_message) +messages.append({"role": "user", "content": "好的,根据你刚才的逻辑,解释一下晚霞为什么是红的。"}) + +response2 = client.chat.completions.create( + model=model_name, + reasoning_effort="high", + messages=messages +) +print("✅ 第二轮请求成功。虽然你看不到思考过程,但模型通过上一轮生成的最终回答重建了逻辑链。") + +# 备注:在 Assistants API 中,你甚至不需要传 messages 数组, +# 只需要传入 thread_id: client.beta.threads.messages.create(thread_id="xxx", ...) +# 服务端会自动提取隐藏在该 Thread 里的加密思考上下文(Encrypted Reasoning Items)。 \ No newline at end of file diff --git a/tests/test_memory/test_base_memory.py b/tests/test_memory/test_base_memory.py new file mode 100644 index 00000000..9180db21 --- /dev/null +++ b/tests/test_memory/test_base_memory.py @@ -0,0 +1,100 @@ +"""Tests for lagent.memory.base_memory.Memory""" + +import asyncio + +from lagent.memory.base_memory import Memory +from lagent.schema import AgentMessage + + +async def test_add_and_len(): + m = Memory() + m.add(AgentMessage(sender='user', content='hello')) + m.add(AgentMessage(sender='assistant', content='hi')) + m.add(AgentMessage(sender='user', content='how are you')) + assert len(m.memory) == 3 + + +async def test_delete_multi_index(): + m = Memory() + for i in range(5): + m.add(AgentMessage(sender='user', content=f'msg{i}')) + m.delete([1, 3]) + remaining = [msg.content for msg in m.memory] + assert remaining == ['msg0', 'msg2', 'msg4'] + + +async def test_delete_single_index(): + m = Memory() + for i in range(3): + m.add(AgentMessage(sender='user', content=f'msg{i}')) + m.delete(1) + remaining = [msg.content for msg in m.memory] + assert remaining == ['msg0', 'msg2'] + + +async def test_reset(): + m = Memory() + m.add(AgentMessage(sender='user', content='hello')) + m.reset() + assert len(m.memory) == 0 + + +async def test_save_load_roundtrip(): + m = Memory() + m.add(AgentMessage(sender='user', content='test')) + saved = m.save() + m2 = Memory() + m2.load(saved) + assert len(m2.memory) == 1 + assert m2.memory[0].content == 'test' + + +async def test_get_memory_all(): + m = Memory() + for i in range(5): + m.add(AgentMessage(sender='user', content=f'msg{i}')) + assert len(m.get_memory()) == 5 + + +async def test_get_memory_with_filter(): + m = Memory() + m.add(AgentMessage(sender='user', content='keep')) + m.add(AgentMessage(sender='assistant', content='skip')) + m.add(AgentMessage(sender='user', content='keep2')) + filtered = m.get_memory(filter_func=lambda i, msg: msg.sender == 'user') + assert len(filtered) == 2 + + +async def test_add_string(): + m = Memory() + m.add('hello') + assert len(m.memory) == 1 + assert m.memory[0].content == 'hello' + + +async def test_add_none(): + m = Memory() + m.add(None) + assert len(m.memory) == 0 + + +async def main(): + tests = [ + test_add_and_len, + test_delete_multi_index, + test_delete_single_index, + test_reset, + test_save_load_roundtrip, + test_get_memory_all, + test_get_memory_with_filter, + test_add_string, + test_add_none, + ] + for test in tests: + await test() + print(f" {test.__name__}: OK") + print(" ALL PASSED") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_memory/test_claude_code_provider.py b/tests/test_memory/test_claude_code_provider.py new file mode 100644 index 00000000..c669cd9d --- /dev/null +++ b/tests/test_memory/test_claude_code_provider.py @@ -0,0 +1,90 @@ +"""Tests for lagent.memory.claude_code_provider""" + +import asyncio +import tempfile +from pathlib import Path + +from lagent.memory.claude_code_provider import ClaudeCodeMemoryProvider + + +async def test_get_info_empty(): + with tempfile.TemporaryDirectory() as tmpdir: + provider = ClaudeCodeMemoryProvider(Path(tmpdir)) + info = await provider.get_info() + assert info == {} + + +async def test_get_info_with_files(): + with tempfile.TemporaryDirectory() as tmpdir: + d = Path(tmpdir) + (d / "user_role.md").write_text( + "---\nname: user role\ntype: user\n---\nSenior engineer.\n" + ) + (d / "feedback.md").write_text( + "---\nname: feedback\ntype: feedback\n---\nNo mocks.\n" + ) + (d / "MEMORY.md").write_text( + "- [Role](user_role.md) — engineer\n" + "- [Feedback](feedback.md) — no mocks\n" + ) + + provider = ClaudeCodeMemoryProvider(d) + info = await provider.get_info() + assert info['available'] is True + assert len(info['memories']) == 2 + assert 'Senior engineer' in info['memories'][0] + assert 'No mocks' in info['memories'][1] + + +async def test_broken_link_skipped(): + with tempfile.TemporaryDirectory() as tmpdir: + d = Path(tmpdir) + (d / "exists.md").write_text("Content here.\n") + (d / "MEMORY.md").write_text( + "- [Exists](exists.md) — ok\n" + "- [Missing](gone.md) — not found\n" + ) + + provider = ClaudeCodeMemoryProvider(d) + info = await provider.get_info() + assert len(info['memories']) == 1 + + +async def test_no_actions(): + with tempfile.TemporaryDirectory() as tmpdir: + provider = ClaudeCodeMemoryProvider(Path(tmpdir)) + assert provider.actions == [] + + +async def test_parse_various_link_formats(): + with tempfile.TemporaryDirectory() as tmpdir: + d = Path(tmpdir) + (d / "a.md").write_text("A") + (d / "b.md").write_text("B") + (d / "MEMORY.md").write_text( + "- [Title with spaces](a.md) — description\n" + " - [Indented](b.md) — also works\n" + "some text without links\n" + ) + + provider = ClaudeCodeMemoryProvider(d) + info = await provider.get_info() + assert len(info['memories']) == 2 + + +async def main(): + tests = [ + test_get_info_empty, + test_get_info_with_files, + test_broken_link_skipped, + test_no_actions, + test_parse_various_link_formats, + ] + for test in tests: + await test() + print(f" {test.__name__}: OK") + print(" ALL PASSED") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_memory/test_openclaw_provider.py b/tests/test_memory/test_openclaw_provider.py new file mode 100644 index 00000000..eb2fd33a --- /dev/null +++ b/tests/test_memory/test_openclaw_provider.py @@ -0,0 +1,93 @@ +"""Tests for lagent.memory.openclaw_provider + lagent.actions.save_memory""" + +import asyncio +import tempfile +from pathlib import Path + +from lagent.memory.openclaw_provider import OpenClawMemoryProvider +from lagent.actions.save_memory import SaveMemoryAction + + +async def test_get_info_empty(): + with tempfile.TemporaryDirectory() as tmpdir: + provider = OpenClawMemoryProvider(Path(tmpdir)) + info = await provider.get_info() + assert info == {} + + +async def test_save_then_read(): + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + provider = OpenClawMemoryProvider(workspace) + action = SaveMemoryAction(workspace) + + await action.save_memory( + history_entry="[2026-04-08 10:00] User discussed refactoring", + memory_update="# Facts\n- Refactoring lagent memory", + ) + + info = await provider.get_info() + assert info['available'] is True + assert 'Refactoring' in info['long_term'] + + +async def test_history_append(): + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + action = SaveMemoryAction(workspace) + + await action.save_memory(history_entry="[2026-04-08] First entry") + await action.save_memory(history_entry="[2026-04-08] Second entry") + + history = (workspace / "memory" / "HISTORY.md").read_text() + assert "First entry" in history + assert "Second entry" in history + + +async def test_memory_overwrite(): + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + provider = OpenClawMemoryProvider(workspace) + action = SaveMemoryAction(workspace) + + await action.save_memory(memory_update="version 1") + await action.save_memory(memory_update="version 2") + + info = await provider.get_info() + assert 'version 2' in info['long_term'] + assert 'version 1' not in info['long_term'] + + +async def test_provider_and_action_independent(): + """Provider and action don't reference each other.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace = Path(tmpdir) + provider = OpenClawMemoryProvider(workspace) + action = SaveMemoryAction(workspace) + + # No shared state + assert not hasattr(provider, '_action') + assert not hasattr(action, '_provider') + + # But they operate on the same storage + await action.save_memory(memory_update="shared data") + info = await provider.get_info() + assert 'shared data' in info['long_term'] + + +async def main(): + tests = [ + test_get_info_empty, + test_save_then_read, + test_history_append, + test_memory_overwrite, + test_provider_and_action_independent, + ] + for test in tests: + await test() + print(f" {test.__name__}: OK") + print(" ALL PASSED") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_services/test_agent.py b/tests/test_services/test_agent.py new file mode 100644 index 00000000..90ce3c65 --- /dev/null +++ b/tests/test_services/test_agent.py @@ -0,0 +1,1003 @@ +"""Unit tests for AgentService. + +Run: + pytest tests/test_agent_service.py -v +""" + +import asyncio +import json +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from lagent.services.agent import AgentEntry, AgentService, AgentStatus, _now_ms +from lagent.services.agent_loader import AgentLoader, AgentSpec + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _make_mock_agent(name="mock", content="done"): + """Create a mock async agent that returns an AgentMessage-like object.""" + response = MagicMock() + response.content = content + + agent = AsyncMock() + agent.name = name + agent.return_value = response + agent.state_dict = MagicMock(return_value={"memory": [{"content": content}]}) + agent.load_state_dict = MagicMock() + return agent + + +def _make_slow_agent(delay=10): + """Create a mock agent that takes a long time (for testing stop/cancel).""" + async def slow(*args, **kwargs): + await asyncio.sleep(delay) + return MagicMock(content="slow done") + + agent = AsyncMock(side_effect=slow) + agent.state_dict = MagicMock(return_value={}) + return agent + + +def _make_failing_agent(error_msg="fail"): + """Create a mock agent that raises on call.""" + agent = AsyncMock(side_effect=RuntimeError(error_msg)) + agent.state_dict = MagicMock(return_value={}) + return agent + + +def _make_spec(name="test-agent", build_agent=None): + """Create an AgentSpec with a mock build function.""" + agent = build_agent or _make_mock_agent(name) + + def build(config): + return agent + + async def abuild(config): + return agent + + return AgentSpec( + name=name, + description=f"Test agent: {name}", + agent_config={"type": "mock"}, + build=build, + ), agent + + +# ── AgentEntry ─────────────────────────────────────────────────────── + + +class TestAgentEntry: + + def test_default_id_generated(self): + entry = AgentEntry() + assert len(entry.id) == 8 + + def test_to_dict(self): + entry = AgentEntry( + id="abc12345", + agent_type="reviewer", + label="review task", + task="review this code", + status=AgentStatus.RUNNING, + ) + d = entry.to_dict() + assert d["id"] == "abc12345" + assert d["agent_type"] == "reviewer" + assert d["status"] == "running" + assert d["result"] is None + + def test_from_dict(self): + data = { + "id": "xyz", + "agent_type": "translator", + "label": "translate", + "task": "translate this", + "status": "stopped", + "result": "translated text", + "error": None, + "created_at_ms": 1000, + "finished_at_ms": 2000, + } + entry = AgentEntry.from_dict(data) + assert entry.id == "xyz" + assert entry.agent_type == "translator" + assert entry.status == "stopped" + assert entry.result == "translated text" + assert entry.finished_at_ms == 2000 + + def test_from_dict_defaults(self): + entry = AgentEntry.from_dict({}) + assert len(entry.id) == 8 + assert entry.status == AgentStatus.PENDING + + def test_roundtrip(self): + entry = AgentEntry( + agent_type="test", label="lbl", task="tsk", + status=AgentStatus.STOPPED, result="ok", + ) + restored = AgentEntry.from_dict(entry.to_dict()) + assert restored.agent_type == entry.agent_type + assert restored.result == entry.result + + +# ── AgentService: init & specs ─────────────────────────────────────── + + +class TestAgentServiceSpecs: + + def test_init_empty(self): + service = AgentService() + assert service.available_types == [] + + def test_register_spec(self): + service = AgentService() + spec, _ = _make_spec("my-agent") + service.register_spec(spec) + assert "my-agent" in service.available_types + assert service.get_spec("my-agent") is spec + + def test_get_spec_not_found(self): + service = AgentService() + assert service.get_spec("nonexistent") is None + + @pytest.mark.asyncio + async def test_load_specs_from_loader(self, tmp_path): + agents = tmp_path / "agents" + agent_dir = agents / "simple" + agent_dir.mkdir(parents=True) + (agent_dir / "config.py").write_text( + "from lagent.agents.agent import Agent\n" + "agent_config = dict(type=Agent, name='simple')\n" + "name = 'simple'\n" + "description = 'A simple agent'\n", + encoding="utf-8", + ) + + loader = AgentLoader(tmp_path) + service = AgentService(agent_loader=loader) + await service.load_specs() + + assert "simple" in service.available_types + + @pytest.mark.asyncio + async def test_load_specs_without_loader(self): + service = AgentService() + await service.load_specs() # should not raise + assert service.available_types == [] + + +# ── AgentService: spawn (sync) ─────────────────────────────────────── + + +class TestAgentServiceSpawnSync: + + @pytest.mark.asyncio + async def test_spawn_sync_success(self): + service = AgentService() + spec, mock_agent = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "do something", mode="sync") + + assert entry.status == AgentStatus.STOPPED + assert entry.result == "done" + assert entry.finished_at_ms is not None + assert entry.agent_type == "worker" + mock_agent.assert_awaited_once() + + @pytest.mark.asyncio + async def test_spawn_sync_with_label(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", label="my label", mode="sync") + assert entry.label == "my label" + + @pytest.mark.asyncio + async def test_spawn_sync_auto_label(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + long_task = "x" * 100 + entry = await service.spawn("worker", long_task, mode="sync") + assert len(entry.label) <= 41 # 40 chars + ellipsis + + @pytest.mark.asyncio + async def test_spawn_unknown_type_raises(self): + service = AgentService() + with pytest.raises(ValueError, match="Unknown agent type"): + await service.spawn("nonexistent", "task") + + @pytest.mark.asyncio + async def test_spawn_sync_failure(self): + agent = AsyncMock(side_effect=RuntimeError("boom")) + agent.state_dict = MagicMock(return_value={}) + + spec = AgentSpec( + name="failing", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("failing", "task", mode="sync") + assert entry.status == AgentStatus.FAILED + assert "boom" in entry.error + + @pytest.mark.asyncio + async def test_spawn_saves_state_after_completion(self): + """After sync execution, agent state should be saved to _saved_states.""" + service = AgentService() + spec, mock_agent = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="sync") + + assert entry.id in service._saved_states + assert service._saved_states[entry.id] == {"memory": [{"content": "done"}]} + + @pytest.mark.asyncio + async def test_spawn_saves_state_even_on_failure(self): + """State should be saved even if the agent fails.""" + agent = AsyncMock(side_effect=RuntimeError("fail")) + agent.state_dict = MagicMock(return_value={"memory": []}) + + spec = AgentSpec( + name="failing", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("failing", "task", mode="sync") + assert entry.status == AgentStatus.FAILED + assert entry.id in service._saved_states + + @pytest.mark.asyncio + async def test_spawn_with_state_transfer(self): + """If state dict is provided, load_state_dict is called.""" + service = AgentService() + spec, mock_agent = _make_spec("worker") + service.register_spec(spec) + + state = {"memory": [{"content": "previous"}]} + await service.spawn("worker", "continue", mode="sync", state=state) + + mock_agent.load_state_dict.assert_called_once_with(state) + + @pytest.mark.asyncio + async def test_spawn_agent_removed_from_live_after_completion(self): + """Live _agents dict should be empty after sync completion.""" + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + await service.spawn("worker", "task", mode="sync") + assert len(service._agents) == 0 + + +# ── AgentService: spawn (async) ────────────────────────────────────── + + +class TestAgentServiceSpawnAsync: + + @pytest.mark.asyncio + async def test_spawn_async_returns_immediately(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="async") + + # Should return with PENDING or RUNNING, not STOPPED + assert entry.status in (AgentStatus.PENDING, AgentStatus.RUNNING) + assert entry.id in service._tasks + + # Wait for completion + await service._tasks[entry.id] + assert entry.status == AgentStatus.STOPPED + assert entry.result == "done" + + @pytest.mark.asyncio + async def test_spawn_async_saves_state(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="async") + await service._tasks[entry.id] + + assert entry.id in service._saved_states + + @pytest.mark.asyncio + async def test_spawn_async_on_complete_callback(self): + completed = [] + + async def on_complete(entry): + completed.append(entry) + + service = AgentService(on_complete=on_complete) + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="async") + await service._tasks[entry.id] + + assert len(completed) == 1 + assert completed[0].id == entry.id + + @pytest.mark.asyncio + async def test_spawn_async_failure(self): + agent = AsyncMock(side_effect=RuntimeError("async boom")) + agent.state_dict = MagicMock(return_value={}) + + spec = AgentSpec( + name="failing", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("failing", "task", mode="async") + # Wait for the task to complete + task = service._tasks.get(entry.id) + if task is not None: + await task + + assert entry.status == AgentStatus.FAILED + assert "async boom" in entry.error + + +# ── AgentService: query ────────────────────────────────────────────── + + +class TestAgentServiceQuery: + + @pytest.mark.asyncio + async def test_list_all(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + await service.spawn("worker", "task1", mode="sync") + await service.spawn("worker", "task2", mode="sync") + + entries = service.list() + assert len(entries) == 2 + + @pytest.mark.asyncio + async def test_list_filter_by_status(self): + service = AgentService() + + spec_ok, _ = _make_spec("ok-agent") + service.register_spec(spec_ok) + + spec_fail = AgentSpec( + name="fail-agent", + agent_config={"type": "mock"}, + build=lambda cfg: _make_failing_agent("fail"), + ) + service.register_spec(spec_fail) + + await service.spawn("ok-agent", "ok task", mode="sync") + await service.spawn("fail-agent", "bad task", mode="sync") + + stopped = service.list(status=AgentStatus.STOPPED) + failed = service.list(status=AgentStatus.FAILED) + assert len(stopped) == 1 + assert len(failed) == 1 + + @pytest.mark.asyncio + async def test_list_filter_by_agent_type(self): + service = AgentService() + spec1, _ = _make_spec("type-a") + spec2, _ = _make_spec("type-b") + service.register_spec(spec1) + service.register_spec(spec2) + + await service.spawn("type-a", "task", mode="sync") + await service.spawn("type-b", "task", mode="sync") + + a_only = service.list(agent_type="type-a") + assert len(a_only) == 1 + assert a_only[0].agent_type == "type-a" + + @pytest.mark.asyncio + async def test_get_by_id(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="sync") + assert service.get(entry.id) is entry + + def test_get_not_found(self): + service = AgentService() + assert service.get("nonexistent") is None + + +# ── AgentService: resume ───────────────────────────────────────────── + + +class TestAgentServiceResume: + + @pytest.mark.asyncio + async def test_resume_reuses_entry(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "first task", mode="sync") + original_id = entry.id + + resumed = await service.resume(original_id, "continue") + assert resumed.id == original_id + assert resumed.status == AgentStatus.STOPPED + + @pytest.mark.asyncio + async def test_resume_restores_state(self): + service = AgentService() + + call_count = 0 + agents = [] + + def build(cfg): + nonlocal call_count + call_count += 1 + agent = _make_mock_agent(content=f"result-{call_count}") + agents.append(agent) + return agent + + spec = AgentSpec( + name="stateful", + agent_config={"type": "mock"}, + build=build, + ) + service.register_spec(spec) + + entry = await service.spawn("stateful", "first", mode="sync") + assert entry.result == "result-1" + assert entry.id in service._saved_states + + resumed = await service.resume(entry.id, "second") + assert resumed.result == "result-2" + # The second agent should have had load_state_dict called + assert agents[1].load_state_dict.called + + @pytest.mark.asyncio + async def test_resume_not_found_raises(self): + service = AgentService() + with pytest.raises(ValueError, match="not found"): + await service.resume("nonexistent", "msg") + + @pytest.mark.asyncio + async def test_resume_running_raises(self): + service = AgentService() + spec = AgentSpec( + name="slow", + agent_config={"type": "mock"}, + build=lambda cfg: _make_slow_agent(), + ) + service.register_spec(spec) + + entry = await service.spawn("slow", "task", mode="async") + await asyncio.sleep(0.05) # let it start running + # Entry should be running + with pytest.raises(ValueError, match="still running"): + await service.resume(entry.id, "msg") + + # Cleanup + await service.shutdown() + + @pytest.mark.asyncio + async def test_resume_resets_entry_fields(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="sync") + assert entry.result is not None + assert entry.finished_at_ms is not None + + # Resume clears old result/error before re-running + resumed = await service.resume(entry.id, "new task") + assert resumed.result == "done" + assert resumed.error is None + + @pytest.mark.asyncio + async def test_resume_spec_not_found_raises(self): + service = AgentService() + spec, _ = _make_spec("temp") + service.register_spec(spec) + + entry = await service.spawn("temp", "task", mode="sync") + + # Remove the spec to simulate missing spec + del service._specs["temp"] + + with pytest.raises(ValueError, match="Agent spec.*not found"): + await service.resume(entry.id, "msg") + + +# ── AgentService: stop ─────────────────────────────────────────────── + + +class TestAgentServiceStop: + + @pytest.mark.asyncio + async def test_stop_running_agent(self): + spec = AgentSpec( + name="slow", + agent_config={"type": "mock"}, + build=lambda cfg: _make_slow_agent(), + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("slow", "task", mode="async") + await asyncio.sleep(0.05) # let it start + + result = await service.stop(entry.id) + assert result is True + assert entry.status == AgentStatus.STOPPED + + @pytest.mark.asyncio + async def test_stop_nonexistent_returns_false(self): + service = AgentService() + result = await service.stop("nonexistent") + assert result is False + + +# ── AgentService: persistence ──────────────────────────────────────── + + +class TestAgentServicePersistence: + + @pytest.mark.asyncio + async def test_save_and_load_entries(self, tmp_path): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="sync") + + save_path = tmp_path / "agent_data" + await service.save_all(save_path) + + # Verify files created + assert (save_path / "agents.json").exists() + + # Load into a new service + service2 = AgentService() + await service2.load_all(save_path) + + assert entry.id in service2._entries + loaded = service2._entries[entry.id] + assert loaded.agent_type == "worker" + assert loaded.result == "done" + + @pytest.mark.asyncio + async def test_save_includes_saved_states(self, tmp_path): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="sync") + + save_path = tmp_path / "agent_data" + await service.save_all(save_path) + + # saved_states should also be persisted + states_dir = save_path / "states" + # Note: save_all saves _agents (running), not _saved_states. + # But _saved_states are populated after completion. + # The current save_all only saves live _agents, so states_dir + # may be empty for completed agents. + + @pytest.mark.asyncio + async def test_load_from_nonexistent_path(self, tmp_path): + service = AgentService() + await service.load_all(tmp_path / "nonexistent") + assert len(service._entries) == 0 + + @pytest.mark.asyncio + async def test_load_restores_saved_states(self, tmp_path): + save_path = tmp_path / "agent_data" + save_path.mkdir(parents=True) + + # Write entries + entries = [{"id": "abc", "agent_type": "worker", "status": "stopped", + "task": "t", "label": "l", "result": "r"}] + (save_path / "agents.json").write_text( + json.dumps(entries), encoding="utf-8", + ) + + # Write state + states_dir = save_path / "states" + states_dir.mkdir() + (states_dir / "abc.json").write_text( + json.dumps({"memory": [{"msg": "hello"}]}), encoding="utf-8", + ) + + service = AgentService() + await service.load_all(save_path) + + assert "abc" in service._entries + assert "abc" in service._saved_states + assert service._saved_states["abc"]["memory"][0]["msg"] == "hello" + + +# ── AgentService: shutdown ─────────────────────────────────────────── + + +class TestAgentServiceShutdown: + + @pytest.mark.asyncio + async def test_shutdown_stops_all_running(self): + spec = AgentSpec( + name="slow", + agent_config={"type": "mock"}, + build=lambda cfg: _make_slow_agent(), + ) + service = AgentService() + service.register_spec(spec) + + await service.spawn("slow", "task1", mode="async") + await service.spawn("slow", "task2", mode="async") + await asyncio.sleep(0.05) + + assert len(service._tasks) == 2 + + await service.shutdown() + assert len(service._tasks) == 0 + + +# ── AgentService: remove ───────────────────────────────────────────── + + +class TestAgentServiceRemove: + + @pytest.mark.asyncio + async def test_remove_stopped_entry(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="sync") + assert service.remove(entry.id) is True + assert service.get(entry.id) is None + + @pytest.mark.asyncio + async def test_remove_running_returns_false(self): + spec = AgentSpec( + name="slow", + agent_config={"type": "mock"}, + build=lambda cfg: _make_slow_agent(), + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("slow", "task", mode="async") + await asyncio.sleep(0.05) + + assert service.remove(entry.id) is False + await service.shutdown() + + def test_remove_nonexistent_returns_false(self): + service = AgentService() + assert service.remove("nonexistent") is False + + @pytest.mark.asyncio + async def test_remove_cleans_saved_states(self): + service = AgentService() + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="sync") + assert entry.id in service._saved_states + + service.remove(entry.id) + assert entry.id not in service._saved_states + + +# ── AgentService: concurrency ──────────────────────────────────────── + + +class TestAgentServiceConcurrency: + + @pytest.mark.asyncio + async def test_semaphore_limits_concurrent(self): + """max_concurrent limits how many agents run simultaneously.""" + running = [] + max_seen = 0 + + async def slow_call(*args, **kwargs): + nonlocal max_seen + running.append(1) + max_seen = max(max_seen, len(running)) + await asyncio.sleep(0.1) + running.pop() + return MagicMock(content="done") + + def make_agent(cfg): + agent = AsyncMock(side_effect=slow_call) + agent.state_dict = MagicMock(return_value={}) + return agent + + spec = AgentSpec( + name="slow", + agent_config={"type": "mock"}, + build=make_agent, + ) + service = AgentService(max_concurrent=2) + service.register_spec(spec) + + # Spawn 4 async agents + entries = [] + for i in range(4): + e = await service.spawn("slow", f"task-{i}", mode="async") + entries.append(e) + + # Wait for all to complete + await asyncio.gather(*[service._tasks[e.id] for e in entries]) + + assert max_seen <= 2 + + +# ── Edge cases: warning/error branches ─────────────────────────────── + + +@pytest.mark.asyncio +async def test_spawn_dynamic_load_from_loader(tmp_path): + """L197-199: spawn loads spec via loader when not pre-registered.""" + agents = tmp_path / "agents" + agent_dir = agents / "dynamic" + agent_dir.mkdir(parents=True) + (agent_dir / "config.py").write_text( + "from lagent.agents.agent import Agent\n" + "agent_config = dict(type=Agent, name='dynamic')\n" + "name = 'dynamic'\n", + encoding="utf-8", + ) + + loader = AgentLoader(tmp_path) + service = AgentService(agent_loader=loader) + # Don't call load_specs — let spawn discover it dynamically + + # Agent will fail (no llm), but the spec should be loaded and cached + entry = await service.spawn("dynamic", "hello", mode="sync") + assert "dynamic" in service._specs # spec was dynamically loaded and cached + + +@pytest.mark.asyncio +async def test_spawn_state_transfer_failure_warns(): + """L226-227: load_state_dict failure is warned, not fatal.""" + agent = _make_mock_agent() + agent.load_state_dict = MagicMock(side_effect=RuntimeError("bad state")) + + spec = AgentSpec( + name="worker", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn( + "worker", "task", mode="sync", + state={"memory": ["corrupted"]}, + ) + # Should succeed despite state transfer failure + assert entry.status == AgentStatus.STOPPED + assert entry.result == "done" + + +@pytest.mark.asyncio +async def test_run_sync_state_dict_failure_warns(): + """L263-264: state_dict() failure in finally is warned, not fatal.""" + agent = AsyncMock(return_value=MagicMock(content="ok")) + agent.state_dict = MagicMock(side_effect=RuntimeError("state broken")) + + spec = AgentSpec( + name="fragile", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("fragile", "task", mode="sync") + assert entry.status == AgentStatus.STOPPED + # State should NOT be saved + assert entry.id not in service._saved_states + + +@pytest.mark.asyncio +async def test_run_async_state_dict_failure_warns(): + """L294-295: async state_dict() failure in finally is warned.""" + agent = AsyncMock(return_value=MagicMock(content="ok")) + agent.state_dict = MagicMock(side_effect=RuntimeError("state broken")) + + spec = AgentSpec( + name="fragile", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("fragile", "task", mode="async") + await service._tasks[entry.id] + + assert entry.status == AgentStatus.STOPPED + assert entry.id not in service._saved_states + + +@pytest.mark.asyncio +async def test_on_complete_callback_failure_logged(): + """L303-304: on_complete callback failure is logged, not fatal.""" + async def bad_callback(entry): + raise RuntimeError("callback boom") + + service = AgentService(on_complete=bad_callback) + spec, _ = _make_spec("worker") + service.register_spec(spec) + + entry = await service.spawn("worker", "task", mode="async") + await service._tasks[entry.id] + + # Agent should still complete successfully despite callback failure + assert entry.status == AgentStatus.STOPPED + assert entry.result == "done" + + +@pytest.mark.asyncio +async def test_resume_load_state_dict_failure_warns(): + """L361-362: resume load_state_dict failure is warned.""" + call_count = 0 + + def build(cfg): + nonlocal call_count + call_count += 1 + agent = _make_mock_agent(content=f"r{call_count}") + if call_count == 2: + agent.load_state_dict = MagicMock( + side_effect=RuntimeError("restore failed") + ) + return agent + + spec = AgentSpec( + name="stateful", + agent_config={"type": "mock"}, + build=build, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("stateful", "first", mode="sync") + + # Resume should succeed despite state restore failure + resumed = await service.resume(entry.id, "second") + assert resumed.status == AgentStatus.STOPPED + assert resumed.result == "r2" + + +@pytest.mark.asyncio +async def test_save_all_saves_running_agent_states(tmp_path): + """L415-424: save_all persists live agent states to disk.""" + save_started = asyncio.Event() + save_done = asyncio.Event() + + async def controlled_agent(*args, **kwargs): + save_started.set() + await save_done.wait() + return MagicMock(content="done") + + agent = AsyncMock(side_effect=controlled_agent) + agent.state_dict = MagicMock(return_value={"memory": [{"msg": "live"}]}) + + spec = AgentSpec( + name="live", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("live", "task", mode="async") + await save_started.wait() + + # Save while agent is running + save_path = tmp_path / "data" + await service.save_all(save_path) + + state_file = save_path / "states" / f"{entry.id}.json" + assert state_file.exists() + saved = json.loads(state_file.read_text()) + assert saved["memory"][0]["msg"] == "live" + + # Let agent finish + save_done.set() + await service._tasks[entry.id] + + +@pytest.mark.asyncio +async def test_save_all_state_dict_failure_warns(tmp_path): + """L415-424: save_all handles state_dict failure gracefully.""" + save_started = asyncio.Event() + save_done = asyncio.Event() + + async def controlled_agent(*args, **kwargs): + save_started.set() + await save_done.wait() + return MagicMock(content="done") + + agent = AsyncMock(side_effect=controlled_agent) + agent.state_dict = MagicMock(side_effect=RuntimeError("cant serialize")) + + spec = AgentSpec( + name="broken", + agent_config={"type": "mock"}, + build=lambda cfg: agent, + ) + service = AgentService() + service.register_spec(spec) + + entry = await service.spawn("broken", "task", mode="async") + await save_started.wait() + + save_path = tmp_path / "data" + await service.save_all(save_path) # should not raise + + # No state file should be written + state_file = save_path / "states" / f"{entry.id}.json" + assert not state_file.exists() + + save_done.set() + await service._tasks[entry.id] + + +@pytest.mark.asyncio +async def test_load_all_corrupted_state_warns(tmp_path): + """L451-452: corrupted state file is warned, not fatal.""" + save_path = tmp_path / "data" + save_path.mkdir() + + (save_path / "agents.json").write_text( + json.dumps([{"id": "abc", "agent_type": "x", "status": "stopped", + "task": "t", "label": "l"}]), + encoding="utf-8", + ) + + states_dir = save_path / "states" + states_dir.mkdir() + (states_dir / "abc.json").write_text("NOT VALID JSON {{{", encoding="utf-8") + + service = AgentService() + await service.load_all(save_path) + + assert "abc" in service._entries + assert "abc" not in service._saved_states # failed to load + + +# ── _now_ms ────────────────────────────────────────────────────────── + + +def test_now_ms(): + ts = _now_ms() + assert isinstance(ts, int) + assert ts > 0 diff --git a/tests/test_services/test_agent_dynamic.py b/tests/test_services/test_agent_dynamic.py new file mode 100644 index 00000000..515d753d --- /dev/null +++ b/tests/test_services/test_agent_dynamic.py @@ -0,0 +1,226 @@ +"""Unit tests for AgentService + SubAgentAction (spec-based + build kwargs). + +Tests the new design where: +- AgentService only manages lifecycle (no default_llm) +- LLM/actions are passed via spec.acreate(**kwargs) → build(config, **kwargs) +- SubAgentAction holds default_llm + parent_actions +""" + +import asyncio +import os +import sys +import types + +import pytest + +# --- bypass circular import --- +_here = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _here not in sys.path: + sys.path.insert(0, _here) +for pkg_name, subdir in [ + ("lagent.services", "lagent/services"), + ("lagent.agents", "lagent/agents"), +]: + if pkg_name not in sys.modules: + _pkg = types.ModuleType(pkg_name) + _pkg.__path__ = [os.path.join(_here, *subdir.split("/"))] + _pkg.__package__ = pkg_name + sys.modules[pkg_name] = _pkg + +from lagent.agents.agent import AsyncAgent +from lagent.services.agent import AgentService, AgentStatus +from lagent.services.agent_loader import AgentSpec + + +# ── Mock LLM ───────────────────────────────────────────────────────── + +class MockLLM: + def __init__(self, response: str = "mock response"): + self._response = response + self.call_count = 0 + + async def chat(self, messages, **kwargs): + self.call_count += 1 + return {"content": self._response} + + +# ── Helper: register a "default" spec with build function ──────────── + +def _register_default_spec(svc: AgentService, default_llm: MockLLM = None): + """Register a 'default' spec whose build accepts llm/actions kwargs.""" + + def build(config, llm=None, actions=None, system_prompt=None, **kw): + template = system_prompt or config.get("template", "") + return AsyncAgent(llm=llm, template=template) + + spec = AgentSpec( + name="default", + agent_config={"type": "AsyncAgent", "template": "default agent"}, + build=build, + ) + svc.register_spec(spec) + + +# ═══════════════════════════════════════════════════════════════════════ +# SPEC-BASED SPAWN +# ═══════════════════════════════════════════════════════════════════════ + +@pytest.mark.asyncio +class TestSpawnFromSpec: + async def test_spawn_with_spec(self): + llm = MockLLM(response="spec result") + svc = AgentService() + _register_default_spec(svc, llm) + + entry = await svc.spawn( + task="do something", + agent_type="default", + llm=llm, + mode="sync", + ) + + assert entry.status == AgentStatus.STOPPED + assert entry.agent_type == "default" + assert entry.result is not None + + async def test_spawn_unknown_type_raises(self): + svc = AgentService() + with pytest.raises(ValueError, match="Unknown agent type"): + await svc.spawn(task="test", agent_type="nonexistent") + + async def test_kwargs_forwarded_to_build(self): + """Build function receives llm and actions via **spec_kwargs.""" + received = {} + + def build(config, **kwargs): + received.update(kwargs) + return AsyncAgent(llm=kwargs.get("llm"), template="test") + + spec = AgentSpec(name="spy", agent_config={}, build=build) + svc = AgentService() + svc.register_spec(spec) + + llm = MockLLM() + await svc.spawn( + task="test", agent_type="spy", + llm=llm, actions=["fake"], custom_param="hello", + mode="sync", + ) + + assert received["llm"] is llm + assert received["actions"] == ["fake"] + assert received["custom_param"] == "hello" + + async def test_shared_llm_instance(self): + """Multiple spawns share the same LLM reference.""" + llm = MockLLM() + svc = AgentService() + _register_default_spec(svc) + + await svc.spawn(task="t1", agent_type="default", llm=llm, mode="sync") + await svc.spawn(task="t2", agent_type="default", llm=llm, mode="sync") + + assert llm.call_count == 2 + + async def test_system_prompt_override(self): + """system_prompt kwarg should override config template.""" + received_template = {} + + def build(config, llm=None, system_prompt=None, **kw): + t = system_prompt or config.get("template", "") + received_template["value"] = t + return AsyncAgent(llm=llm, template=t) + + spec = AgentSpec( + name="default", + agent_config={"template": "original"}, + build=build, + ) + svc = AgentService() + svc.register_spec(spec) + + llm = MockLLM() + await svc.spawn( + task="test", agent_type="default", + llm=llm, system_prompt="overridden prompt", + mode="sync", + ) + + assert received_template["value"] == "overridden prompt" + + +# ═══════════════════════════════════════════════════════════════════════ +# SPAWN_AGENT (pre-built instance) +# ═══════════════════════════════════════════════════════════════════════ + +@pytest.mark.asyncio +class TestSpawnAgent: + async def test_spawn_prebuilt(self): + llm = MockLLM(response="prebuilt result") + agent = AsyncAgent(llm=llm, template="custom") + svc = AgentService() + + entry = await svc.spawn_agent(agent=agent, task="do it", mode="sync") + + assert entry.status == AgentStatus.STOPPED + assert entry.agent_type == "_custom" + assert entry.result is not None + + async def test_spawn_prebuilt_custom_type(self): + llm = MockLLM() + agent = AsyncAgent(llm=llm, template="test") + svc = AgentService() + + entry = await svc.spawn_agent( + agent=agent, task="test", agent_type="my_type", + ) + assert entry.agent_type == "my_type" + + async def test_spawn_prebuilt_async(self): + llm = MockLLM() + agent = AsyncAgent(llm=llm, template="test") + svc = AgentService() + + entry = await svc.spawn_agent(agent=agent, task="bg", mode="async") + await asyncio.sleep(0.5) + assert entry.status == AgentStatus.STOPPED + + +# ═══════════════════════════════════════════════════════════════════════ +# LIFECYCLE (list, get, stop) +# ═══════════════════════════════════════════════════════════════════════ + +@pytest.mark.asyncio +class TestLifecycle: + async def test_list_entries(self): + llm = MockLLM() + svc = AgentService() + _register_default_spec(svc) + + await svc.spawn(task="t1", agent_type="default", llm=llm, mode="sync") + await svc.spawn(task="t2", agent_type="default", llm=llm, mode="sync") + assert len(svc.list()) == 2 + + async def test_get_entry(self): + llm = MockLLM() + svc = AgentService() + _register_default_spec(svc) + + entry = await svc.spawn(task="t1", agent_type="default", llm=llm, mode="sync") + found = svc.get(entry.id) + assert found is not None + assert found.id == entry.id + + async def test_stop_async_agent(self): + llm = MockLLM() + svc = AgentService() + _register_default_spec(svc) + + entry = await svc.spawn(task="t", agent_type="default", llm=llm, mode="async") + await asyncio.sleep(0.1) + stopped = await svc.stop(entry.id) + assert isinstance(stopped, bool) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_services/test_agent_loader.py b/tests/test_services/test_agent_loader.py new file mode 100644 index 00000000..cb0ef800 --- /dev/null +++ b/tests/test_services/test_agent_loader.py @@ -0,0 +1,573 @@ +"""Unit tests for AgentLoader (pyconfig only). + +Run: + pytest tests/test_agent_loader.py -v +""" + +import sys +from pathlib import Path + +import pytest + +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from lagent.services.agent_loader import AgentLoader, AgentSpec, _import_module_from_path + + +# ── Fixtures ───────────────────────────────────────────────────────── + + +CONFIG_STANDARD = """\ +from lagent.agents.internclaw_agent import ( + AsyncEnvAgent, + AsyncPolicyAgent, + InternClawAgent, +) +from lagent.llms.model import AsyncAPIClient +from lagent.agents.aggregator.context import InternClawContextBuilder + +name = "code-reviewer" +description = "Reviews code for quality and best practices" +background = False + +llm = dict( + type=AsyncAPIClient, + model=dict( + model="gpt-4o", + base_url="http://localhost:8000/v1", + api_key="test-key", + ), + sample_params=dict(temperature=0.7), + timeout=60, + max_retry=3, +) + +agent_config = dict( + type=InternClawAgent, + policy_agent=dict( + type=AsyncPolicyAgent, + llm=llm, + aggregator=dict(type=InternClawContextBuilder), + name="policy", + ), + env_agent=dict( + type=AsyncEnvAgent, + actions=[], + name="env", + ), + max_turn=50, +) +""" + +CONFIG_WITH_BUILD = """\ +from lagent.agents.internclaw_agent import ( + AsyncEnvAgent, + AsyncPolicyAgent, + InternClawAgent, +) +from lagent.llms.model import AsyncAPIClient +from lagent.agents.aggregator.context import InternClawContextBuilder + +name = "translator" +description = "Translates text between languages" +background = True + +llm = dict( + type=AsyncAPIClient, + model=dict( + model="gpt-4o-mini", + base_url="http://localhost:8000/v1", + api_key="test-key", + ), + timeout=60, +) + +agent_config = dict( + type=InternClawAgent, + policy_agent=dict( + type=AsyncPolicyAgent, + llm=llm, + aggregator=dict(type=InternClawContextBuilder), + name="policy", + ), + env_agent=dict( + type=AsyncEnvAgent, + actions=[], + name="env", + ), + max_turn=10, +) + +def build(config): + from lagent.utils import create_object + return create_object(config) +""" + + +@pytest.fixture +def workspace(tmp_path): + """Create a workspace with various agent project dirs.""" + agents = tmp_path / "agents" + + # 1. Standard InternClawAgent config + reviewer = agents / "code-reviewer" + reviewer.mkdir(parents=True) + (reviewer / "config.py").write_text(CONFIG_STANDARD, encoding="utf-8") + + # 2. Config with custom build() + translator = agents / "translator" + translator.mkdir(parents=True) + (translator / "config.py").write_text(CONFIG_WITH_BUILD, encoding="utf-8") + + # 3. Empty dir (should be skipped) + (agents / "empty-dir").mkdir(parents=True) + + # 4. Dir with only AGENT.md (should be skipped) + md_only = agents / "md-only" + md_only.mkdir(parents=True) + (md_only / "AGENT.md").write_text( + "---\nname: md-only\n---\nHello\n", + encoding="utf-8", + ) + + return tmp_path + + +# ── AgentLoader.list() ─────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_list_discovers_pyconfig_only(workspace): + loader = AgentLoader(workspace) + entries = await loader.list() + names = {e["name"] for e in entries} + + assert names == {"code-reviewer", "translator"} + assert "empty-dir" not in names + assert "md-only" not in names + + +@pytest.mark.asyncio +async def test_list_returns_path(workspace): + loader = AgentLoader(workspace) + entries = await loader.list() + for entry in entries: + assert "path" in entry + assert Path(entry["path"]).is_dir() + + +# ── AgentLoader.load() ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_load_standard_config(workspace): + loader = AgentLoader(workspace) + spec = await loader.load("code-reviewer") + + assert spec is not None + assert spec.name == "code-reviewer" + assert spec.description == "Reviews code for quality and best practices" + assert spec.background is False + assert spec.build is None + assert spec.project_dir is not None + + # agent_config structure matches InternClawAgent pattern + cfg = spec.agent_config + assert cfg is not None + assert "policy_agent" in cfg + assert "env_agent" in cfg + assert cfg.get("max_turn") == 50 + assert cfg["policy_agent"]["name"] == "policy" + assert cfg["env_agent"]["name"] == "env" + + +@pytest.mark.asyncio +async def test_load_config_with_build(workspace): + loader = AgentLoader(workspace) + spec = await loader.load("translator") + + assert spec is not None + assert spec.name == "translator" + assert spec.description == "Translates text between languages" + assert spec.background is True + assert callable(spec.build) + + cfg = spec.agent_config + assert cfg is not None + assert cfg.get("max_turn") == 10 + + +@pytest.mark.asyncio +async def test_load_nonexistent_returns_none(workspace): + loader = AgentLoader(workspace) + assert await loader.load("does-not-exist") is None + + +@pytest.mark.asyncio +async def test_load_empty_dir_returns_none(workspace): + loader = AgentLoader(workspace) + assert await loader.load("empty-dir") is None + + +@pytest.mark.asyncio +async def test_load_md_only_returns_none(workspace): + loader = AgentLoader(workspace) + assert await loader.load("md-only") is None + + +# ── AgentLoader.load_all() ────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_load_all(workspace): + loader = AgentLoader(workspace) + specs = await loader.load_all() + + assert len(specs) == 2 + assert "code-reviewer" in specs + assert "translator" in specs + + +# ── AgentLoader.build_agents_summary() ────────────────────────────── + + +@pytest.mark.asyncio +async def test_build_agents_summary(workspace): + loader = AgentLoader(workspace) + summary = await loader.build_agents_summary() + + assert "" in summary + assert "code-reviewer" in summary + assert "translator" in summary + assert "true" in summary # translator is background + + +@pytest.mark.asyncio +async def test_build_agents_summary_empty(tmp_path): + loader = AgentLoader(tmp_path) + assert await loader.build_agents_summary() == "" + + +# ── AgentSpec.create() / acreate() ────────────────────────────────── +# NOTE: Real InternClawAgent instantiation requires live dependencies +# (LLM connections, workspace paths, etc.). These tests verify the +# config structure is correct. Full instantiation is covered by +# integration tests (test_agent_service_e2e.py). + + +@pytest.mark.asyncio +async def test_spec_config_has_internclaw_structure(workspace): + """Loaded config should have the InternClawAgent nested structure.""" + loader = AgentLoader(workspace) + spec = await loader.load("code-reviewer") + assert spec is not None + + cfg = spec.agent_config + assert cfg is not None + + # Top-level: InternClawAgent + from lagent.agents.internclaw_agent import InternClawAgent + assert cfg["type"] is InternClawAgent + + # Nested: policy_agent with LLM + policy = cfg["policy_agent"] + assert policy["name"] == "policy" + assert "llm" in policy + + # Nested: env_agent with actions + env = cfg["env_agent"] + assert env["name"] == "env" + assert "actions" in env + + # max_turn + assert cfg["max_turn"] == 50 + + +@pytest.mark.asyncio +async def test_spec_with_build_is_callable(workspace): + loader = AgentLoader(workspace) + spec = await loader.load("translator") + assert spec is not None + assert callable(spec.build) + assert spec.agent_config["max_turn"] == 10 + + +@pytest.mark.asyncio +async def test_spec_create_uses_build_when_provided(): + """When build is set, create() delegates to it.""" + created = [] + + def mock_build(config): + agent = type("MockAgent", (), {"name": config.get("name", "mock")})() + created.append(agent) + return agent + + spec = AgentSpec( + name="test", + agent_config={"name": "test-agent"}, + build=mock_build, + ) + agent = spec.create() + assert len(created) == 1 + assert agent.name == "test-agent" + + +@pytest.mark.asyncio +async def test_spec_acreate_uses_async_build(): + """acreate() handles async build functions.""" + async def async_build(config): + return type("MockAgent", (), {"name": "async-built"})() + + spec = AgentSpec( + name="test", + agent_config={"name": "test"}, + build=async_build, + ) + agent = await spec.acreate() + assert agent.name == "async-built" + + +def test_spec_create_no_config_raises(): + spec = AgentSpec(name="broken") + with pytest.raises(ValueError, match="has no agent_config"): + spec.create() + + +@pytest.mark.asyncio +async def test_spec_acreate_no_config_raises(): + spec = AgentSpec(name="broken") + with pytest.raises(ValueError, match="has no agent_config"): + await spec.acreate() + + +# ── AgentSpec serialization ───────────────────────────────────────── + + +def test_spec_to_dict_from_dict_roundtrip(): + spec = AgentSpec( + name="test", + description="A test agent", + background=True, + project_dir="/tmp/agents/test", + agent_config=dict(type="lagent.agents.Agent"), + extra={"model": "gpt-4"}, + ) + d = spec.to_dict() + assert d["name"] == "test" + assert d["project_dir"] == "/tmp/agents/test" + + restored = AgentSpec.from_dict(d) + assert restored.name == spec.name + assert restored.description == spec.description + assert restored.background == spec.background + assert restored.project_dir == spec.project_dir + assert restored.extra == spec.extra + assert restored.build is None + + +# ── sys.modules cleanup ───────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_load_does_not_pollute_sys_modules(workspace): + module_name = "_agentproject_code-reviewer_config" + assert module_name not in sys.modules + + loader = AgentLoader(workspace) + await loader.load("code-reviewer") + + assert module_name not in sys.modules + + +@pytest.mark.asyncio +async def test_load_twice_gives_fresh_spec(workspace): + loader = AgentLoader(workspace) + spec1 = await loader.load("code-reviewer") + spec2 = await loader.load("code-reviewer") + + assert spec1 is not None and spec2 is not None + assert spec1.name == spec2.name + assert spec1 is not spec2 + + +# ── Multiple search dirs ──────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_user_agents_dir(workspace, tmp_path): + user_dir = tmp_path / "user_agents" + custom = user_dir / "custom-agent" + custom.mkdir(parents=True) + (custom / "config.py").write_text( + CONFIG_STANDARD.replace("code-reviewer", "custom-agent"), + encoding="utf-8", + ) + + loader = AgentLoader(workspace, user_agents_dir=user_dir) + entries = await loader.list() + names = {e["name"] for e in entries} + + assert "custom-agent" in names + assert "code-reviewer" in names + + +@pytest.mark.asyncio +async def test_workspace_agents_take_priority(workspace, tmp_path): + """If same name exists in workspace and user dir, workspace wins.""" + user_dir = tmp_path / "user_agents" + dupe = user_dir / "code-reviewer" + dupe.mkdir(parents=True) + (dupe / "config.py").write_text( + CONFIG_STANDARD.replace( + 'description = "Reviews code for quality and best practices"', + 'description = "From user dir"', + ), + encoding="utf-8", + ) + + loader = AgentLoader(workspace, user_agents_dir=user_dir) + spec = await loader.load("code-reviewer") + + assert spec is not None + assert spec.description == "Reviews code for quality and best practices" + + +# ── Error handling ────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_load_broken_config_returns_none(workspace): + broken = workspace / "agents" / "broken" + broken.mkdir(parents=True) + (broken / "config.py").write_text("@@@ not python", encoding="utf-8") + + loader = AgentLoader(workspace) + assert await loader.load("broken") is None + + +@pytest.mark.asyncio +async def test_load_config_missing_agent_config_returns_none(workspace): + bad = workspace / "agents" / "no-config" + bad.mkdir(parents=True) + (bad / "config.py").write_text( + "name = 'no-config'\ndescription = 'missing agent_config'\n", + encoding="utf-8", + ) + + loader = AgentLoader(workspace) + assert await loader.load("no-config") is None + + +# ── create() / acreate() fallback to create_object ────────────────── + + +def test_spec_create_without_build_calls_create_object(): + """create() without build delegates to create_object().""" + from lagent.agents.agent import Agent + + spec = AgentSpec( + name="simple", + agent_config=dict(type=Agent, name="simple-agent"), + ) + agent = spec.create() + assert agent.name == "simple-agent" + + +@pytest.mark.asyncio +async def test_spec_acreate_without_build_calls_create_object(): + """acreate() without build delegates to create_object().""" + from lagent.agents.agent import Agent + + spec = AgentSpec( + name="simple", + agent_config=dict(type=Agent, name="simple-agent"), + ) + agent = await spec.acreate() + assert agent.name == "simple-agent" + + +# ── _import_module_from_path: spec_from_file_location returns None ── + + +def test_import_module_from_unloadable_file(tmp_path): + """If spec_from_file_location returns None, ImportError is raised.""" + # A file without .py extension causes spec_from_file_location to + # return None, triggering the ImportError guard on L156. + bad_file = tmp_path / "not_a_module" + bad_file.write_text("x = 1\n", encoding="utf-8") + with pytest.raises(ImportError, match="Cannot load spec"): + _import_module_from_path("_test_mod", bad_file, tmp_path) + + +# ── _import_module_from_path: prev_module restoration ─────────────── + + +def test_import_restores_prev_module(tmp_path): + """If a module with the same name already exists, it's restored after load.""" + module_name = "_test_prev_module_restore" + sentinel = object() + sys.modules[module_name] = sentinel + + try: + agent_dir = tmp_path / "dummy" + agent_dir.mkdir() + config_file = agent_dir / "config.py" + config_file.write_text("value = 42\n", encoding="utf-8") + + attrs = _import_module_from_path(module_name, config_file, agent_dir) + assert attrs["value"] == 42 + # Previous module should be restored + assert sys.modules.get(module_name) is sentinel + finally: + sys.modules.pop(module_name, None) + + +# ── list() deduplication across multiple dirs ──────────────────────── + + +@pytest.mark.asyncio +async def test_list_deduplicates_across_dirs(workspace, tmp_path): + """Same agent name in workspace and user dir: only listed once.""" + user_dir = tmp_path / "user_agents" + dupe = user_dir / "code-reviewer" + dupe.mkdir(parents=True) + (dupe / "config.py").write_text( + CONFIG_STANDARD.replace( + 'description = "Reviews code for quality and best practices"', + 'description = "From user dir"', + ), + encoding="utf-8", + ) + + loader = AgentLoader(workspace, user_agents_dir=user_dir) + entries = await loader.list() + names = [e["name"] for e in entries] + + # code-reviewer should appear exactly once (workspace wins) + assert names.count("code-reviewer") == 1 + + +# ── builtin_agents_dir ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_builtin_agents_dir(tmp_path): + """Agents in builtin_agents_dir are discovered.""" + workspace = tmp_path / "workspace" + workspace.mkdir() + + builtin = tmp_path / "builtin" + agent = builtin / "builtin-agent" + agent.mkdir(parents=True) + (agent / "config.py").write_text( + CONFIG_STANDARD.replace("code-reviewer", "builtin-agent"), + encoding="utf-8", + ) + + loader = AgentLoader(workspace, builtin_agents_dir=builtin) + entries = await loader.list() + names = {e["name"] for e in entries} + + assert "builtin-agent" in names diff --git a/tests/test_services/test_cron.py b/tests/test_services/test_cron.py new file mode 100644 index 00000000..12d2e40f --- /dev/null +++ b/tests/test_services/test_cron.py @@ -0,0 +1,510 @@ +"""Unit tests for CronService (lagent/services/cron.py).""" + +import asyncio +import json +import os +import sys +import tempfile +import types +from datetime import datetime, timezone, timedelta +from pathlib import Path + +import pytest + +# --- bypass circular import in lagent.services.__init__.py --- +_here = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _here not in sys.path: + sys.path.insert(0, _here) +if "lagent.services" not in sys.modules: + _pkg = types.ModuleType("lagent.services") + _pkg.__path__ = [os.path.join(_here, "lagent", "services")] + _pkg.__package__ = "lagent.services" + sys.modules["lagent.services"] = _pkg + +from lagent.services.cron import ( + CronService, CronJob, Schedule, JobState, compute_next_run, _now_ms, + _load_jobs, _save_jobs, +) + + +# ═══════════════════════════════════════════════════════════════════════ +# DATA MODEL +# ═══════════════════════════════════════════════════════════════════════ + +class TestSchedule: + def test_defaults(self): + s = Schedule() + assert s.kind == "at" + assert s.at is None + assert s.every_seconds is None + assert s.expr is None + assert s.tz is None + + +class TestCronJob: + def test_defaults(self): + job = CronJob() + assert len(job.id) == 8 + assert job.name == "" + assert job.enabled is True + assert isinstance(job.schedule, Schedule) + assert isinstance(job.state, JobState) + + def test_to_dict_from_dict_roundtrip(self): + job = CronJob( + name="test-job", + schedule=Schedule(kind="every", every_seconds=60), + payload={"message": "hello", "channel": "cli"}, + state=JobState(next_run_at_ms=12345, consecutive_errors=2), + delete_after_run=True, + ) + d = job.to_dict() + restored = CronJob.from_dict(d) + assert restored.name == "test-job" + assert restored.schedule.kind == "every" + assert restored.schedule.every_seconds == 60 + assert restored.state.next_run_at_ms == 12345 + assert restored.state.consecutive_errors == 2 + assert restored.delete_after_run is True + assert restored.payload["message"] == "hello" + + def test_from_dict_missing_fields(self): + job = CronJob.from_dict({}) + assert job.name == "" + assert job.enabled is True + assert job.schedule.kind == "at" + + def test_from_dict_ignores_extra_keys(self): + job = CronJob.from_dict({"name": "ok", "schedule": {"kind": "at", "bogus": 1}}) + assert job.name == "ok" + assert job.schedule.kind == "at" + + +# ═══════════════════════════════════════════════════════════════════════ +# compute_next_run +# ═══════════════════════════════════════════════════════════════════════ + +class TestComputeNextRun: + def test_at_future(self): + future = datetime.now(timezone.utc) + timedelta(hours=1) + s = Schedule(kind="at", at=future.isoformat()) + result = compute_next_run(s, _now_ms()) + assert result is not None + assert result > _now_ms() + + def test_at_past_returns_none(self): + past = datetime.now(timezone.utc) - timedelta(hours=1) + s = Schedule(kind="at", at=past.isoformat()) + result = compute_next_run(s, _now_ms()) + assert result is None + + def test_at_missing_returns_none(self): + s = Schedule(kind="at", at=None) + assert compute_next_run(s, _now_ms()) is None + + def test_at_invalid_iso_returns_none(self): + s = Schedule(kind="at", at="not-a-date") + assert compute_next_run(s, _now_ms()) is None + + def test_every(self): + s = Schedule(kind="every", every_seconds=30) + now = _now_ms() + result = compute_next_run(s, now) + assert result == now + 30_000 + + def test_every_zero_returns_none(self): + s = Schedule(kind="every", every_seconds=0) + assert compute_next_run(s, _now_ms()) is None + + def test_every_negative_returns_none(self): + s = Schedule(kind="every", every_seconds=-5) + assert compute_next_run(s, _now_ms()) is None + + def test_every_missing_returns_none(self): + s = Schedule(kind="every", every_seconds=None) + assert compute_next_run(s, _now_ms()) is None + + def test_cron_without_expr_returns_none(self): + s = Schedule(kind="cron", expr=None) + assert compute_next_run(s, _now_ms()) is None + + def test_cron_valid(self): + pytest.importorskip("croniter") + s = Schedule(kind="cron", expr="* * * * *") # every minute + result = compute_next_run(s, _now_ms()) + assert result is not None + assert result > _now_ms() + + def test_unknown_kind_returns_none(self): + s = Schedule(kind="bogus") + assert compute_next_run(s, _now_ms()) is None + + +# ═══════════════════════════════════════════════════════════════════════ +# PERSISTENCE HELPERS +# ═══════════════════════════════════════════════════════════════════════ + +class TestPersistenceHelpers: + def test_save_and_load(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + jobs = [ + CronJob(name="j1", schedule=Schedule(kind="every", every_seconds=10)), + CronJob(name="j2", schedule=Schedule(kind="at", at="2030-01-01T00:00:00Z")), + ] + _save_jobs(jobs, path) + assert path.exists() + + loaded = _load_jobs(path) + assert len(loaded) == 2 + assert loaded[0].name == "j1" + assert loaded[1].name == "j2" + + def test_load_missing_file(self): + result = _load_jobs(Path("/nonexistent/path.json")) + assert result == [] + + def test_load_invalid_json(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + path.write_text("not json", "utf-8") + result = _load_jobs(path) + assert result == [] + + def test_load_wrong_version(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + path.write_text(json.dumps({"version": 99, "jobs": []}), "utf-8") + result = _load_jobs(path) + assert result == [] + + def test_save_format(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + _save_jobs([CronJob(name="x")], path) + data = json.loads(path.read_text("utf-8")) + assert data["version"] == 1 + assert len(data["jobs"]) == 1 + assert data["jobs"][0]["name"] == "x" + + +# ═══════════════════════════════════════════════════════════════════════ +# CRON SERVICE +# ═══════════════════════════════════════════════════════════════════════ + +class TestCronServiceCRUD: + def _make_service(self, tmpdir) -> CronService: + path = Path(tmpdir) / "jobs.json" + return CronService(store_path=path) + + def test_add_job(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + job = svc.add_job( + name="test", + schedule=Schedule(kind="every", every_seconds=60), + message="hello", + ) + assert job.name == "test" + assert job.payload["message"] == "hello" + assert len(svc.list_jobs(include_disabled=True)) == 1 + + def test_add_job_validates_at(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + with pytest.raises(ValueError, match="'at'"): + svc.add_job( + name="bad", + schedule=Schedule(kind="at", at=None), + message="x", + ) + + def test_add_job_validates_every(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + with pytest.raises(ValueError, match="every_seconds"): + svc.add_job( + name="bad", + schedule=Schedule(kind="every", every_seconds=0), + message="x", + ) + + def test_add_job_validates_cron(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + with pytest.raises(ValueError, match="'expr'"): + svc.add_job( + name="bad", + schedule=Schedule(kind="cron", expr=None), + message="x", + ) + + def test_remove_job(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + job = svc.add_job( + name="test", + schedule=Schedule(kind="every", every_seconds=60), + message="hello", + ) + assert svc.remove_job(job.id) is True + assert len(svc.list_jobs(include_disabled=True)) == 0 + + def test_remove_nonexistent(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + assert svc.remove_job("nonexistent") is False + + def test_get_job(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + job = svc.add_job( + name="test", + schedule=Schedule(kind="every", every_seconds=60), + message="hello", + ) + found = svc.get_job(job.id) + assert found is not None + assert found.name == "test" + + def test_get_nonexistent(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + assert svc.get_job("nope") is None + + def test_list_jobs_filters_disabled(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + j1 = svc.add_job( + name="enabled", + schedule=Schedule(kind="every", every_seconds=60), + message="a", + ) + j2 = svc.add_job( + name="disabled", + schedule=Schedule(kind="every", every_seconds=60), + message="b", + ) + j2.enabled = False + assert len(svc.list_jobs(include_disabled=False)) == 1 + assert len(svc.list_jobs(include_disabled=True)) == 2 + + def test_list_jobs_sorted_by_next_run(self): + with tempfile.TemporaryDirectory() as tmpdir: + svc = self._make_service(tmpdir) + j1 = svc.add_job( + name="later", + schedule=Schedule(kind="every", every_seconds=120), + message="a", + ) + j2 = svc.add_job( + name="sooner", + schedule=Schedule(kind="every", every_seconds=10), + message="b", + ) + jobs = svc.list_jobs() + assert jobs[0].name == "sooner" + assert jobs[1].name == "later" + + +class TestCronServicePersistence: + def test_persist_and_reload(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc1 = CronService(store_path=path) + svc1.add_job( + name="persist-test", + schedule=Schedule(kind="every", every_seconds=30), + message="hi", + ) + # Load in a new service instance + svc2 = CronService(store_path=path) + loaded = _load_jobs(path) + assert len(loaded) == 1 + assert loaded[0].name == "persist-test" + + +class TestCronServiceStateDict: + def test_state_dict_roundtrip(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path) + svc.add_job( + name="j1", + schedule=Schedule(kind="every", every_seconds=60), + message="hello", + ) + svc.add_job( + name="j2", + schedule=Schedule(kind="every", every_seconds=120), + message="world", + ) + + state = svc.state_dict() + assert state["version"] == 1 + assert len(state["jobs"]) == 2 + + # Restore into a fresh service + path2 = Path(tmpdir) / "jobs2.json" + svc2 = CronService(store_path=path2) + svc2.load_state_dict(state) + + jobs = svc2.list_jobs(include_disabled=True) + assert len(jobs) == 2 + names = {j.name for j in jobs} + assert "j1" in names + assert "j2" in names + + +# ═══════════════════════════════════════════════════════════════════════ +# ASYNC TIMER EXECUTION (end-to-end) +# ═══════════════════════════════════════════════════════════════════════ + +@pytest.mark.asyncio +class TestCronTimerExecution: + """Test that jobs actually fire via the async timer engine.""" + + async def test_every_job_fires(self): + """Add a 1-second recurring job, verify on_job is called.""" + fired = [] + + async def on_job(job): + fired.append(job.name) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path, on_job=on_job) + await svc.start() + + svc.add_job( + name="fast-repeat", + schedule=Schedule(kind="every", every_seconds=1), + message="tick", + ) + + # Wait long enough for at least 1 fire + # _MIN_REFIRE_GAP_S=2.0, so timer fires at ~2s, then job is due + await asyncio.sleep(5) + svc.stop() + + assert len(fired) >= 1 + assert fired[0] == "fast-repeat" + + async def test_at_job_fires_once(self): + """Add a one-shot 'at' job 2s in the future, verify it fires once.""" + fired = [] + + async def on_job(job): + fired.append(job.name) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path, on_job=on_job) + await svc.start() + + future = ( + datetime.now(timezone.utc) + timedelta(seconds=2) + ).isoformat() + svc.add_job( + name="one-shot", + schedule=Schedule(kind="at", at=future), + message="boom", + delete_after_run=True, + ) + + await asyncio.sleep(5) + svc.stop() + + assert len(fired) == 1 + assert fired[0] == "one-shot" + # Job should be deleted after run + assert len(svc.list_jobs(include_disabled=True)) == 0 + + async def test_job_error_records_state(self): + """A failing on_job should record error state.""" + async def on_job(job): + raise RuntimeError("boom") + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path, on_job=on_job) + await svc.start() + + svc.add_job( + name="fail-job", + schedule=Schedule(kind="every", every_seconds=1), + message="fail", + ) + + await asyncio.sleep(5) + svc.stop() + + job = svc.list_jobs(include_disabled=True)[0] + assert job.state.last_status == "error" + assert job.state.last_error == "boom" + assert job.state.consecutive_errors >= 1 + + async def test_no_on_job_callback(self): + """Service without on_job should not crash when jobs fire.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path, on_job=None) + await svc.start() + + svc.add_job( + name="silent", + schedule=Schedule(kind="every", every_seconds=1), + message="noop", + ) + + await asyncio.sleep(4) + svc.stop() + + job = svc.list_jobs(include_disabled=True)[0] + assert job.state.last_status == "ok" + + async def test_stop_prevents_further_fires(self): + """After stop(), no more jobs should fire.""" + fired = [] + + async def on_job(job): + fired.append(job.name) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "jobs.json" + svc = CronService(store_path=path, on_job=on_job) + await svc.start() + + svc.add_job( + name="stoppable", + schedule=Schedule(kind="every", every_seconds=1), + message="tick", + ) + + await asyncio.sleep(4) + count_before_stop = len(fired) + svc.stop() + + await asyncio.sleep(3) + assert len(fired) == count_before_stop # no new fires + + +if __name__ == "__main__": + async def _run_all(): + t = TestCronTimerExecution() + tests = [ + ("every_job_fires", t.test_every_job_fires), + ("at_job_fires_once", t.test_at_job_fires_once), + ("job_error_records_state", t.test_job_error_records_state), + ("no_on_job_callback", t.test_no_on_job_callback), + ("stop_prevents_further_fires", t.test_stop_prevents_further_fires), + ] + for name, fn in tests: + print(f"Running {name}...", end=" ", flush=True) + try: + await fn() + print("PASSED") + except Exception as e: + print(f"FAILED: {e}") + + asyncio.run(_run_all()) \ No newline at end of file diff --git a/tests/test_services/test_mailbox.py b/tests/test_services/test_mailbox.py new file mode 100644 index 00000000..a714a354 --- /dev/null +++ b/tests/test_services/test_mailbox.py @@ -0,0 +1,319 @@ +"""Unit tests for Mailbox (lagent/services/mailbox.py).""" + +import asyncio +import json +import os +import sys +import tempfile +import types +from pathlib import Path + +import pytest + +# --- bypass circular import in lagent.services.__init__.py --- +_here = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _here not in sys.path: + sys.path.insert(0, _here) +if "lagent.services" not in sys.modules: + _pkg = types.ModuleType("lagent.services") + _pkg.__path__ = [os.path.join(_here, "lagent", "services")] + _pkg.__package__ = "lagent.services" + sys.modules["lagent.services"] = _pkg + +from lagent.schema import AgentMessage +from lagent.services.mailbox import Mailbox + + +def _msg(sender: str, receiver: str, content: str) -> AgentMessage: + return AgentMessage(sender=sender, receiver=receiver, content=content) + + +# ═══════════════════════════════════════════════════════════════════════ +# REGISTRATION +# ═══════════════════════════════════════════════════════════════════════ + +class TestRegister: + def test_register_creates_queue(self): + mb = Mailbox() + mb.register("agent-A") + assert "agent-A" in mb.agents + + def test_register_idempotent(self): + mb = Mailbox() + mb.register("agent-A") + mb.register("agent-A") + assert mb.agents.count("agent-A") == 1 + + def test_agents_empty_initially(self): + mb = Mailbox() + assert mb.agents == [] + + +# ═══════════════════════════════════════════════════════════════════════ +# SEND +# ═══════════════════════════════════════════════════════════════════════ + +class TestSend: + def test_send_basic(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + assert mb.has_messages("B") + assert not mb.has_messages("A") + + def test_send_creates_queue_for_receiver(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + assert "B" in mb.agents + + def test_send_without_receiver_raises(self): + mb = Mailbox() + with pytest.raises(ValueError, match="receiver"): + mb.send(AgentMessage(sender="A", content="oops")) + + def test_send_multiple_messages(self): + mb = Mailbox() + mb.send(_msg("A", "B", "msg1")) + mb.send(_msg("A", "B", "msg2")) + mb.send(_msg("C", "B", "msg3")) + assert mb.message_count("B") == 3 + + def test_broadcast(self): + mb = Mailbox() + mb.register("A") + mb.register("B") + mb.register("C") + mb.send(_msg("A", "*", "hello everyone")) + assert mb.has_messages("B") + assert mb.has_messages("C") + assert not mb.has_messages("A") # sender excluded + + def test_broadcast_to_unregistered_skips(self): + mb = Mailbox() + mb.register("A") + # No B or C registered + mb.send(_msg("A", "*", "hello")) + # Should not crash, just no recipients + assert mb.message_count() == 0 + + +# ═══════════════════════════════════════════════════════════════════════ +# DRAIN +# ═══════════════════════════════════════════════════════════════════════ + +class TestDrain: + def test_drain_returns_messages(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + mb.send(_msg("A", "B", "world")) + msgs = mb.drain("B") + assert len(msgs) == 2 + assert msgs[0].content == "hello" + assert msgs[1].content == "world" + + def test_drain_clears_queue(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + mb.drain("B") + assert not mb.has_messages("B") + assert mb.drain("B") == [] + + def test_drain_empty_returns_empty(self): + mb = Mailbox() + assert mb.drain("nonexistent") == [] + + def test_drain_only_affects_target(self): + mb = Mailbox() + mb.send(_msg("A", "B", "for B")) + mb.send(_msg("A", "C", "for C")) + mb.drain("B") + assert mb.has_messages("C") + + +# ═══════════════════════════════════════════════════════════════════════ +# PEEK +# ═══════════════════════════════════════════════════════════════════════ + +class TestPeek: + def test_peek_returns_messages(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + msgs = mb.peek("B") + assert len(msgs) == 1 + + def test_peek_does_not_remove(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + mb.peek("B") + assert mb.has_messages("B") + + def test_peek_returns_copy(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + peeked = mb.peek("B") + peeked.clear() + assert mb.has_messages("B") # original not affected + + def test_peek_empty(self): + mb = Mailbox() + assert mb.peek("nobody") == [] + + +# ═══════════════════════════════════════════════════════════════════════ +# HAS_MESSAGES / MESSAGE_COUNT +# ═══════════════════════════════════════════════════════════════════════ + +class TestQueryMethods: + def test_has_messages(self): + mb = Mailbox() + assert not mb.has_messages("B") + mb.send(_msg("A", "B", "hello")) + assert mb.has_messages("B") + + def test_message_count_per_agent(self): + mb = Mailbox() + mb.send(_msg("A", "B", "1")) + mb.send(_msg("A", "B", "2")) + mb.send(_msg("A", "C", "3")) + assert mb.message_count("B") == 2 + assert mb.message_count("C") == 1 + + def test_message_count_total(self): + mb = Mailbox() + mb.send(_msg("A", "B", "1")) + mb.send(_msg("A", "C", "2")) + assert mb.message_count() == 2 + + +# ═══════════════════════════════════════════════════════════════════════ +# CLEAR +# ═══════════════════════════════════════════════════════════════════════ + +class TestClear: + def test_clear_one_agent(self): + mb = Mailbox() + mb.send(_msg("A", "B", "1")) + mb.send(_msg("A", "C", "2")) + mb.clear("B") + assert not mb.has_messages("B") + assert mb.has_messages("C") + + def test_clear_all(self): + mb = Mailbox() + mb.send(_msg("A", "B", "1")) + mb.send(_msg("A", "C", "2")) + mb.clear_all() + assert mb.message_count() == 0 + + +# ═══════════════════════════════════════════════════════════════════════ +# WAIT_FOR_MESSAGE (async) +# ═══════════════════════════════════════════════════════════════════════ + +@pytest.mark.asyncio +class TestWaitForMessage: + async def test_wait_receives_signal(self): + mb = Mailbox() + mb.register("B") + + async def sender(): + await asyncio.sleep(0.1) + mb.send(_msg("A", "B", "hello")) + + task = asyncio.create_task(sender()) + result = await mb.wait_for_message("B", timeout=2.0) + assert result is True + assert mb.has_messages("B") + await task + + async def test_wait_timeout(self): + mb = Mailbox() + mb.register("B") + result = await mb.wait_for_message("B", timeout=0.1) + assert result is False + + async def test_wait_auto_registers(self): + mb = Mailbox() + # Don't explicitly register + async def sender(): + await asyncio.sleep(0.1) + mb.send(_msg("A", "B", "hello")) + + task = asyncio.create_task(sender()) + result = await mb.wait_for_message("B", timeout=2.0) + assert result is True + await task + + +# ═══════════════════════════════════════════════════════════════════════ +# STATE_DICT / LOAD_STATE_DICT +# ═══════════════════════════════════════════════════════════════════════ + +class TestStateDict: + def test_roundtrip(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + mb.send(_msg("C", "B", "world")) + mb.send(_msg("A", "D", "test")) + + state = mb.state_dict() + + mb2 = Mailbox() + mb2.load_state_dict(state) + + assert mb2.message_count("B") == 2 + assert mb2.message_count("D") == 1 + msgs = mb2.drain("B") + assert msgs[0].content == "hello" + assert msgs[1].content == "world" + + def test_state_dict_format(self): + mb = Mailbox() + mb.send(_msg("A", "B", "hello")) + state = mb.state_dict() + assert state["version"] == 1 + assert "B" in state["queues"] + assert len(state["queues"]["B"]) == 1 + assert state["queues"]["B"][0]["sender"] == "A" + + +# ═══════════════════════════════════════════════════════════════════════ +# FILE PERSISTENCE +# ═══════════════════════════════════════════════════════════════════════ + +class TestFilePersistence: + def test_persist_and_reload(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "mailbox.json" + + mb = Mailbox(store_path=path) + mb.send(_msg("A", "B", "hello")) + mb.send(_msg("C", "D", "world")) + assert path.exists() + + mb2 = Mailbox(store_path=path) + assert mb2.has_messages("B") + assert mb2.has_messages("D") + msgs = mb2.drain("B") + assert len(msgs) == 1 + assert msgs[0].content == "hello" + + def test_persist_format(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "mailbox.json" + mb = Mailbox(store_path=path) + mb.send(_msg("A", "B", "test")) + + data = json.loads(path.read_text("utf-8")) + assert data["version"] == 1 + assert "B" in data["queues"] + assert data["queues"]["B"][0]["content"] == "test" + + def test_drain_updates_file(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "mailbox.json" + mb = Mailbox(store_path=path) + mb.send(_msg("A", "B", "hello")) + mb.drain("B") + + data = json.loads(path.read_text("utf-8")) + assert data["queues"].get("B", []) == [] diff --git a/tests/test_services/test_task.py b/tests/test_services/test_task.py new file mode 100644 index 00000000..ee567bd8 --- /dev/null +++ b/tests/test_services/test_task.py @@ -0,0 +1,456 @@ +"""Unit tests for TaskBoard (lagent/services/task.py). + +Uses direct file-based import to bypass the circular import in +lagent.services.__init__.py (a pre-existing issue, not introduced by +the task module). +""" + +import json +import sys +import tempfile +import types +import os +from pathlib import Path + +import pytest + +# --- bypass circular import in lagent.services.__init__.py --- +_here = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +if _here not in sys.path: + sys.path.insert(0, _here) +if "lagent.services" not in sys.modules: + _pkg = types.ModuleType("lagent.services") + _pkg.__path__ = [os.path.join(_here, "lagent", "services")] + _pkg.__package__ = "lagent.services" + sys.modules["lagent.services"] = _pkg + +from lagent.services.task import Task, TaskBoard, ClaimResult + + +# ── Helpers ────────────────────────────────────────────────────────── + +def make_board(**kwargs) -> TaskBoard: + return TaskBoard(**kwargs) + + +# ── CRUD ───────────────────────────────────────────────────────────── + +class TestCreate: + def test_basic_create(self): + board = make_board() + task = board.create("Fix bug", "Fix the login bug") + assert task.id == "1" + assert task.subject == "Fix bug" + assert task.description == "Fix the login bug" + assert task.status == "pending" + assert task.blocks == [] + assert task.blocked_by == [] + + def test_auto_increment_ids(self): + board = make_board() + t1 = board.create("Task 1", "desc") + t2 = board.create("Task 2", "desc") + t3 = board.create("Task 3", "desc") + assert t1.id == "1" + assert t2.id == "2" + assert t3.id == "3" + + def test_create_with_active_form(self): + board = make_board() + task = board.create("Run tests", "desc", active_form="Running tests") + assert task.active_form == "Running tests" + + def test_create_with_blocked_by(self): + board = make_board() + t1 = board.create("Setup", "desc") + t2 = board.create("Build", "desc", blocked_by=["1"]) + assert t2.blocked_by == ["1"] + # Bidirectional: t1 should now block t2 + t1_refreshed = board.get("1") + assert "2" in t1_refreshed.blocks + + def test_create_with_metadata(self): + board = make_board() + task = board.create("Task", "desc", metadata={"priority": "high"}) + assert task.metadata == {"priority": "high"} + + +class TestUpdate: + def test_update_status(self): + board = make_board() + board.create("Task", "desc") + updated = board.update("1", status="in_progress") + assert updated is not None + assert updated.status == "in_progress" + + def test_update_subject(self): + board = make_board() + board.create("Old title", "desc") + updated = board.update("1", subject="New title") + assert updated.subject == "New title" + + def test_update_nonexistent_returns_none(self): + board = make_board() + result = board.update("999", status="completed") + assert result is None + + def test_update_add_blocked_by(self): + board = make_board() + board.create("Task A", "desc") + board.create("Task B", "desc") + board.update("2", add_blocked_by=["1"]) + t1 = board.get("1") + t2 = board.get("2") + assert "1" in t2.blocked_by + assert "2" in t1.blocks + + def test_update_add_blocks(self): + board = make_board() + board.create("Task A", "desc") + board.create("Task B", "desc") + board.update("1", add_blocks=["2"]) + t1 = board.get("1") + t2 = board.get("2") + assert "2" in t1.blocks + assert "1" in t2.blocked_by + + def test_update_metadata_merge(self): + board = make_board() + board.create("Task", "desc", metadata={"a": 1, "b": 2}) + board.update("1", metadata={"b": 99, "c": 3}) + task = board.get("1") + assert task.metadata == {"a": 1, "b": 99, "c": 3} + + def test_update_metadata_delete_key(self): + board = make_board() + board.create("Task", "desc", metadata={"a": 1, "b": 2}) + board.update("1", metadata={"b": None}) + task = board.get("1") + assert task.metadata == {"a": 1} + + def test_update_deleted_status(self): + board = make_board() + board.create("Task", "desc") + result = board.update("1", status="deleted") + assert result is None # deleted returns None + assert board.get("1") is None + assert len(board.list()) == 0 + + +class TestDelete: + def test_delete_existing(self): + board = make_board() + board.create("Task", "desc") + assert board.delete("1") is True + assert board.get("1") is None + + def test_delete_nonexistent(self): + board = make_board() + assert board.delete("999") is False + + def test_cascade_cleanup(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + board.create("C", "desc", blocked_by=["1"]) + board.delete("1") + t2 = board.get("2") + t3 = board.get("3") + assert "1" not in t2.blocked_by + assert "1" not in t3.blocked_by + + +class TestGetAndList: + def test_get_existing(self): + board = make_board() + board.create("Task", "desc") + task = board.get("1") + assert task is not None + assert task.subject == "Task" + + def test_get_nonexistent(self): + board = make_board() + assert board.get("999") is None + + def test_list_all(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + assert len(board.list()) == 2 + + def test_list_filtered(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.update("1", status="completed") + assert len(board.list(status="completed")) == 1 + assert len(board.list(status="pending")) == 1 + + +# ── High water mark ────────────────────────────────────────────────── + +class TestHighWaterMark: + def test_ids_never_reused_after_delete(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.create("C", "desc") + board.delete("2") + t4 = board.create("D", "desc") + assert t4.id == "4" + + def test_ids_never_reused_after_delete_all(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.delete("1") + board.delete("2") + t3 = board.create("C", "desc") + assert t3.id == "3" + + +# ── Dependency graph ───────────────────────────────────────────────── + +class TestDependencyGraph: + def test_bidirectional_on_create(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + a = board.get("1") + b = board.get("2") + assert "2" in a.blocks + assert "1" in b.blocked_by + + def test_bidirectional_on_update(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.update("2", add_blocked_by=["1"]) + a = board.get("1") + b = board.get("2") + assert "2" in a.blocks + assert "1" in b.blocked_by + + def test_cascade_delete_cleans_both_directions(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + board.create("C", "desc") + board.update("1", add_blocks=["3"]) + assert "2" in board.get("1").blocks + assert "3" in board.get("1").blocks + board.delete("1") + assert "1" not in board.get("2").blocked_by + assert "1" not in board.get("3").blocked_by + + def test_no_duplicate_deps(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + board.update("2", add_blocked_by=["1"]) + b = board.get("2") + assert b.blocked_by.count("1") == 1 + a = board.get("1") + assert a.blocks.count("2") == 1 + + +# ── Claim ──────────────────────────────────────────────────────────── + +class TestClaim: + def test_claim_success(self): + board = make_board() + board.create("Task", "desc") + result = board.claim("1", "worker-A") + assert result.success is True + assert result.task.owner == "worker-A" + assert result.task.status == "in_progress" + + def test_claim_nonexistent(self): + board = make_board() + result = board.claim("999", "worker-A") + assert result.success is False + assert result.reason == "task_not_found" + + def test_claim_completed_task(self): + board = make_board() + board.create("Task", "desc") + board.update("1", status="completed") + result = board.claim("1", "worker-A") + assert result.success is False + assert result.reason == "already_completed" + + def test_claim_blocked_task(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + result = board.claim("2", "worker-A") + assert result.success is False + assert result.reason == "blocked" + assert "1" in result.blocked_by_tasks + + def test_claim_unblocked_after_completion(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + board.update("1", status="completed") + result = board.claim("2", "worker-A") + assert result.success is True + + def test_claim_agent_busy(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.claim("1", "worker-A") + result = board.claim("2", "worker-A", check_busy=True) + assert result.success is False + assert result.reason == "agent_busy" + + def test_claim_agent_busy_disabled(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.claim("1", "worker-A") + result = board.claim("2", "worker-A", check_busy=False) + assert result.success is True + + +# ── Release agent ──────────────────────────────────────────────────── + +class TestReleaseAgent: + def test_release(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.claim("1", "worker-A", check_busy=False) + board.claim("2", "worker-A", check_busy=False) + released = board.release_agent("worker-A") + assert len(released) == 2 + for t in released: + assert t.owner is None + assert t.status == "pending" + + def test_release_preserves_completed(self): + board = make_board() + board.create("A", "desc") + board.claim("1", "worker-A") + board.update("1", status="completed") + released = board.release_agent("worker-A") + assert len(released) == 0 + assert board.get("1").status == "completed" + + +# ── Query helpers ──────────────────────────────────────────────────── + +class TestQueryHelpers: + def test_all_completed_empty(self): + board = make_board() + assert board.all_completed() is True + + def test_all_completed_false(self): + board = make_board() + board.create("A", "desc") + assert board.all_completed() is False + + def test_all_completed_true(self): + board = make_board() + board.create("A", "desc") + board.update("1", status="completed") + assert board.all_completed() is True + + def test_list_available(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + board.create("C", "desc") + available = board.list_available() + ids = [t.id for t in available] + assert "1" in ids + assert "3" in ids + assert "2" not in ids + + def test_list_available_excludes_owned(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.claim("1", "worker-A") + available = board.list_available() + ids = [t.id for t in available] + assert "1" not in ids + assert "2" in ids + + def test_get_summary(self): + board = make_board() + board.create("Audit code", "desc") + board.create("Write tests", "desc", blocked_by=["1"]) + board.update("1", status="in_progress", owner="coder") + summary = board.get_summary() + assert "#1" in summary + assert "#2" in summary + assert "in_progress" in summary + assert "@coder" in summary + assert "blocked by #1" in summary + + +# ── state_dict / load_state_dict ───────────────────────────────────── + +class TestStateDictRoundTrip: + def test_roundtrip(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + board.update("1", status="completed") + + state = board.state_dict() + board2 = make_board() + board2.load_state_dict(state) + + assert len(board2.list()) == 2 + assert board2.get("1").status == "completed" + assert board2.get("2").blocked_by == ["1"] + t3 = board2.create("C", "desc") + assert t3.id == "3" + + def test_high_water_mark_preserved(self): + board = make_board() + board.create("A", "desc") + board.create("B", "desc") + board.delete("2") + + state = board.state_dict() + board2 = make_board() + board2.load_state_dict(state) + + t = board2.create("C", "desc") + assert t.id == "3" + + +# ── File persistence ───────────────────────────────────────────────── + +class TestFilePersistence: + def test_persist_and_reload(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "tasks.json" + board = TaskBoard(store_path=path) + board.create("A", "desc") + board.create("B", "desc", blocked_by=["1"]) + board.update("1", status="completed") + assert path.exists() + + board2 = TaskBoard(store_path=path) + assert len(board2.list()) == 2 + assert board2.get("1").status == "completed" + assert board2.get("2").blocked_by == ["1"] + t3 = board2.create("C", "desc") + assert t3.id == "3" + + def test_persist_format(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "tasks.json" + board = TaskBoard(store_path=path) + board.create("Test", "desc") + + data = json.loads(path.read_text("utf-8")) + assert data["version"] == 1 + assert data["next_id"] == 2 + assert len(data["tasks"]) == 1 + assert data["tasks"][0]["subject"] == "Test" diff --git a/workspace/AGENTS.md b/workspace/AGENTS.md new file mode 100644 index 00000000..7497b08c --- /dev/null +++ b/workspace/AGENTS.md @@ -0,0 +1,21 @@ +# Agent Instructions + +You are a helpful AI assistant. Be concise, accurate, and friendly. + +## Scheduled Reminders + +Before scheduling reminders, check available skills and follow skill guidance first. +Use the built-in `cron` tool to create/list/remove jobs (do not call `internclaw cron` via `exec`). +Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`). + +**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications. + +## Heartbeat Tasks + +`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks: + +- **Add**: `edit_file` to append new tasks +- **Remove**: `edit_file` to delete completed tasks +- **Rewrite**: `write_file` to replace all tasks + +When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder. diff --git a/workspace/HEARTBEAT.md b/workspace/HEARTBEAT.md new file mode 100644 index 00000000..47f50018 --- /dev/null +++ b/workspace/HEARTBEAT.md @@ -0,0 +1,17 @@ +# Heartbeat Tasks + +This file is checked every 30 minutes by your internclaw agent. +Add tasks below that you want the agent to work on periodically. + +If this file has no tasks (only headers and comments), the agent will skip the heartbeat. + +## Active Tasks + + +- 每天上午10点检查HuggingFace上是否有Agent相关的Paper + + +## Completed + + + diff --git a/workspace/README.md b/workspace/README.md new file mode 100644 index 00000000..663b4072 --- /dev/null +++ b/workspace/README.md @@ -0,0 +1,35 @@ +# Workspace README + +这是我的工作空间的 README 文件。 + +## 目录结构 +- `AGENTS.md` - Agent 说明 +- `HEARTBEAT.md` - 心跳任务配置 +- `SOUL.md` - Agent 个人档案 +- `TOOLS.md` - 工具说明 +- `USER.md` - 用户档案 +- `memory/` - 记忆系统 +- `skills/` - 技能文件夹 +- `README.md` - 本文件 + +## 使用说明 +- 使用 `memory` 技能管理长时记忆 +- 使用 `cron` 技能管理定时任务 +- 使用 `clawhub` 技能安装新技能 +- 使用 `github` 技能管理 GitHub 项目 + +## 技能列表 +当前可用技能: +- cron - 定时任务 +- memory - 记忆管理 +- weather - 天气预报 +- github - GitHub 集成 +- clawhub - 技能仓库 +- skill-creator - 创建技能 +- summarize - 总结工具 +- tmux - Tmux 集成 + +## 注意事项 +- 请在使用前阅读相关技能的说明 +- 重要信息已存储在 MEMORY.md 中 +- 历史记录已保存在 HISTORY.md 中 diff --git a/workspace/SOUL.md b/workspace/SOUL.md new file mode 100644 index 00000000..ce8c230a --- /dev/null +++ b/workspace/SOUL.md @@ -0,0 +1,21 @@ +# Soul + +I am internclaw 🐈, a personal AI assistant. + +## Personality + +- Helpful and friendly +- Concise and to the point +- Curious and eager to learn + +## Values + +- Accuracy over speed +- User privacy and safety +- Transparency in actions + +## Communication Style + +- Be clear and direct +- Explain reasoning when helpful +- Ask clarifying questions when needed diff --git a/workspace/TOOLS.md b/workspace/TOOLS.md new file mode 100644 index 00000000..51c3a2d0 --- /dev/null +++ b/workspace/TOOLS.md @@ -0,0 +1,15 @@ +# Tool Usage Notes + +Tool signatures are provided automatically via function calling. +This file documents non-obvious constraints and usage patterns. + +## exec — Safety Limits + +- Commands have a configurable timeout (default 60s) +- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.) +- Output is truncated at 10,000 characters +- `restrictToWorkspace` config can limit file access to the workspace + +## cron — Scheduled Reminders + +- Please refer to cron skill for usage. diff --git a/workspace/USER.md b/workspace/USER.md new file mode 100644 index 00000000..16245917 --- /dev/null +++ b/workspace/USER.md @@ -0,0 +1,49 @@ +# User Profile + +Information about the user to help personalize interactions. + +## Basic Information + +- **Name**: (your name) +- **Timezone**: (your timezone, e.g., UTC+8) +- **Language**: (preferred language) + +## Preferences + +### Communication Style + +- [ ] Casual +- [ ] Professional +- [ ] Technical + +### Response Length + +- [ ] Brief and concise +- [ ] Detailed explanations +- [ ] Adaptive based on question + +### Technical Level + +- [ ] Beginner +- [ ] Intermediate +- [ ] Expert + +## Work Context + +- **Primary Role**: (your role, e.g., developer, researcher) +- **Main Projects**: (what you're working on) +- **Tools You Use**: (IDEs, languages, frameworks) + +## Topics of Interest + +- +- +- + +## Special Instructions + +(Any specific instructions for how the assistant should behave) + +--- + +*Edit this file to customize internclaw's behavior for your needs.* diff --git a/workspace/agents/default_agent/config.py b/workspace/agents/default_agent/config.py new file mode 100644 index 00000000..688fe31b --- /dev/null +++ b/workspace/agents/default_agent/config.py @@ -0,0 +1,119 @@ +"""Default agent config for AgentDaemon. + +Produces a JSON-serializable dict that ``create_object()`` can instantiate. + +Usage:: + + # As AgentDaemon config file + python -c "from config import agent_config; import json; print(json.dumps(agent_config))" > agent.json + python -m lagent.serving.sandbox.daemon start --mode agent --config agent.json + + # Or load directly in Python + from config import agent_config + from lagent.utils import create_object + agent = create_object(agent_config) +""" + +# ── Model ── +model_name = "sft_interns2_pre_base03_20260413a_lr2e5_128gpu_lmdeploy_fix_parser_retry" +api_base = "http://s-20260104203038-22bhb.ailab-evalservice.pjh-service.org.cn/v1" +api_key = "sk-admin" + +extra_body = {} + +model = dict( + type="lagent.llms.model.AsyncAPIClient", + model=dict(model=model_name, base_url=api_base, api_key=api_key), + sample_params=dict(temperature=0.7, top_p=1.0, top_k=50), + timeout=600, + max_retry=500, + sleep_interval=5, + extra_body=dict(spaces_between_special_tokens=False), +) + +# ── Workspace ── +workspace = "/root/workspace" + +# ── Actions ── +base_actions = [ + dict(type="lagent.actions.filesystem.ReadFileAction", workspace=workspace), + dict(type="lagent.actions.filesystem.WriteFileAction", workspace=workspace), + dict(type="lagent.actions.filesystem.EditFileAction", workspace=workspace), + dict(type="lagent.actions.shell.ShellAction", working_dir=workspace), +] + +save_memory_action = dict( + type="lagent.actions.save_memory.AsyncSaveMemoryAction", + workspace=workspace, +) + +# ── Prompts ── +CONSOLIDATION_PROMPT = ( + "You are a memory consolidation agent. Review the conversation " + "and call the save_memory tool to persist important information.\n\n" + "Extract key facts, decisions, user preferences, and project context. " + "Merge with existing long-term memory. For history_entry, write a " + "grep-searchable summary starting with [YYYY-MM-DD HH:MM]." +) + +# ── Policy Agent ── +policy_agent = dict( + type="lagent.agents.internclaw_agent.AsyncPolicyAgent", + llm=model, + aggregator=dict( + type="lagent.agents.aggregator.context.InternClawContextBuilder", + workspace=workspace, + ), + hooks=[dict(type="lagent.hooks.logger.MessageLogger")], +) + +# ── Env Agent ── +env_agent = dict( + type="lagent.agents.internclaw_agent.AsyncEnvAgent", + actions=base_actions + [save_memory_action], + skills=dict( + type="lagent.skills.skills.SkillsLoader", + workspace=workspace, + ), + long_term_memory=dict( + type="lagent.memory.openclaw_provider.OpenClawMemoryProvider", + workspace=workspace, + ), +) + +# ── Compact Agent ── +compact_agent = dict( + type="lagent.agents.compact_agent.AsyncCompactAgent", + name="compact", + llm=model, + max_context_tokens=65536, + threshold_ratio=0.5, +) + +# ── Consolidate Agent ── +consolidate_agent = dict( + type="lagent.agents.internclaw_agent.InternClawAgent", + policy_agent=dict( + type="lagent.agents.internclaw_agent.AsyncPolicyAgent", + name="consolidate_policy", + llm=model, + template=CONSOLIDATION_PROMPT, + hooks=[dict(type="lagent.hooks.logger.MessageLogger")], + aggregator=dict(type="lagent.agents.aggregator.compact_aggregator.CompactAggregator"), + ), + env_agent=dict( + type="lagent.agents.internclaw_agent.AsyncEnvAgent", + actions=[save_memory_action], + ), + max_turn=1, + finish_condition=None, +) + +# ── Full Agent Config (for AgentDaemon) ── +agent_config = dict( + type="lagent.agents.internclaw_agent.InternClawAgent", + policy_agent=policy_agent, + env_agent=env_agent, + compact_agent=compact_agent, + consolidate_agent=consolidate_agent, +) diff --git a/workspace/memory/HISTORY.md b/workspace/memory/HISTORY.md new file mode 100644 index 00000000..f77961eb --- /dev/null +++ b/workspace/memory/HISTORY.md @@ -0,0 +1,4 @@ +[2024-12-31 23:59] No conversation content provided for memory consolidation + +[2026-04-10 07:52] 用户要求对比 nanobot 和 lagent/workspace 两个项目。nanobot 是超轻量级个人AI助手(OpenClaw的99%精简版),支持多渠道(Telegram、Discord、WhatsApp、微信等)和预设集成;lagent 是基于PyTorch设计的AI代理开发框架,提供模块化组件(actions、agents、llms、memory等)用于构建自定义代理。nanobot侧重生产就绪和易用性,lagent侧重灵活性和研究扩展性。 + diff --git a/workspace/memory/MEMORY.md b/workspace/memory/MEMORY.md new file mode 100644 index 00000000..b587b8a3 --- /dev/null +++ b/workspace/memory/MEMORY.md @@ -0,0 +1 @@ +**nanobot** - Ultra-lightweight personal AI assistant inspired by OpenClaw, 99% smaller codebase. Multi-channel support (Telegram, Discord, WhatsApp, WeChat, Feishu, etc.). Focus: production-ready, deployment, usability. **lagent** - Lightweight open-source framework for building LLM-based agents, PyTorch-inspired architecture. Modular design with components: actions, agents, llms, memory, prompts, hooks, services. Focus: flexibility, research, custom agent development. **Key Difference**: nanobot is an application (ready-to-use AI assistant), lagent is a framework (for building custom agents). nanobot has pre-built integrations, lagent has modular components. nanobot emphasizes deployment and usability, lagent emphasizes research and extensibility. \ No newline at end of file diff --git a/workspace/memory/__init__.py b/workspace/memory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/workspace/skills/agent-browser/SKILL.md b/workspace/skills/agent-browser/SKILL.md new file mode 100644 index 00000000..73bd7848 --- /dev/null +++ b/workspace/skills/agent-browser/SKILL.md @@ -0,0 +1,682 @@ +--- +name: agent-browser +description: Browser automation CLI for AI agents. Use when the user needs to interact with websites, including navigating pages, filling forms, clicking buttons, taking screenshots, extracting data, testing web apps, or automating any browser task. Triggers include requests to "open a website", "fill out a form", "click a button", "take a screenshot", "scrape data from a page", "test this web app", "login to a site", "automate browser actions", or any task requiring programmatic web interaction. +allowed-tools: Bash(npx agent-browser:*), Bash(agent-browser:*) +--- + +# Browser Automation with agent-browser + +The CLI uses Chrome/Chromium via CDP directly. Install via `npm i -g agent-browser`, `brew install agent-browser`, or `cargo install agent-browser`. Run `agent-browser install` to download Chrome. Run `agent-browser upgrade` to update to the latest version. + +## Core Workflow + +Every browser automation follows this pattern: + +1. **Navigate**: `agent-browser open ` +2. **Snapshot**: `agent-browser snapshot -i` (get element refs like `@e1`, `@e2`) +3. **Interact**: Use refs to click, fill, select +4. **Re-snapshot**: After navigation or DOM changes, get fresh refs + +```bash +agent-browser open https://example.com/form +agent-browser snapshot -i +# Output: @e1 [input type="email"], @e2 [input type="password"], @e3 [button] "Submit" + +agent-browser fill @e1 "user@example.com" +agent-browser fill @e2 "password123" +agent-browser click @e3 +agent-browser wait --load networkidle +agent-browser snapshot -i # Check result +``` + +## Command Chaining + +Commands can be chained with `&&` in a single shell invocation. The browser persists between commands via a background daemon, so chaining is safe and more efficient than separate calls. + +```bash +# Chain open + wait + snapshot in one call +agent-browser open https://example.com && agent-browser wait --load networkidle && agent-browser snapshot -i + +# Chain multiple interactions +agent-browser fill @e1 "user@example.com" && agent-browser fill @e2 "password123" && agent-browser click @e3 + +# Navigate and capture +agent-browser open https://example.com && agent-browser wait --load networkidle && agent-browser screenshot page.png +``` + +**When to chain:** Use `&&` when you don't need to read the output of an intermediate command before proceeding (e.g., open + wait + screenshot). Run commands separately when you need to parse the output first (e.g., snapshot to discover refs, then interact using those refs). + +## Handling Authentication + +When automating a site that requires login, choose the approach that fits: + +**Option 1: Import auth from the user's browser (fastest for one-off tasks)** + +```bash +# Connect to the user's running Chrome (they're already logged in) +agent-browser --auto-connect state save ./auth.json +# Use that auth state +agent-browser --state ./auth.json open https://app.example.com/dashboard +``` + +State files contain session tokens in plaintext -- add to `.gitignore` and delete when no longer needed. Set `AGENT_BROWSER_ENCRYPTION_KEY` for encryption at rest. + +**Option 2: Persistent profile (simplest for recurring tasks)** + +```bash +# First run: login manually or via automation +agent-browser --profile ~/.myapp open https://app.example.com/login +# ... fill credentials, submit ... + +# All future runs: already authenticated +agent-browser --profile ~/.myapp open https://app.example.com/dashboard +``` + +**Option 3: Session name (auto-save/restore cookies + localStorage)** + +```bash +agent-browser --session-name myapp open https://app.example.com/login +# ... login flow ... +agent-browser close # State auto-saved + +# Next time: state auto-restored +agent-browser --session-name myapp open https://app.example.com/dashboard +``` + +**Option 4: Auth vault (credentials stored encrypted, login by name)** + +```bash +echo "$PASSWORD" | agent-browser auth save myapp --url https://app.example.com/login --username user --password-stdin +agent-browser auth login myapp +``` + +**Option 5: State file (manual save/load)** + +```bash +# After logging in: +agent-browser state save ./auth.json +# In a future session: +agent-browser state load ./auth.json +agent-browser open https://app.example.com/dashboard +``` + +See [references/authentication.md](references/authentication.md) for OAuth, 2FA, cookie-based auth, and token refresh patterns. + +## Essential Commands + +```bash +# Navigation +agent-browser open # Navigate (aliases: goto, navigate) +agent-browser close # Close browser + +# Snapshot +agent-browser snapshot -i # Interactive elements with refs (recommended) +agent-browser snapshot -i -C # Include cursor-interactive elements (divs with onclick, cursor:pointer) +agent-browser snapshot -s "#selector" # Scope to CSS selector + +# Interaction (use @refs from snapshot) +agent-browser click @e1 # Click element +agent-browser click @e1 --new-tab # Click and open in new tab +agent-browser fill @e2 "text" # Clear and type text +agent-browser type @e2 "text" # Type without clearing +agent-browser select @e1 "option" # Select dropdown option +agent-browser check @e1 # Check checkbox +agent-browser press Enter # Press key +agent-browser keyboard type "text" # Type at current focus (no selector) +agent-browser keyboard inserttext "text" # Insert without key events +agent-browser scroll down 500 # Scroll page +agent-browser scroll down 500 --selector "div.content" # Scroll within a specific container + +# Get information +agent-browser get text @e1 # Get element text +agent-browser get url # Get current URL +agent-browser get title # Get page title +agent-browser get cdp-url # Get CDP WebSocket URL + +# Wait +agent-browser wait @e1 # Wait for element +agent-browser wait --load networkidle # Wait for network idle +agent-browser wait --url "**/page" # Wait for URL pattern +agent-browser wait 2000 # Wait milliseconds +agent-browser wait --text "Welcome" # Wait for text to appear (substring match) +agent-browser wait --fn "!document.body.innerText.includes('Loading...')" # Wait for text to disappear +agent-browser wait "#spinner" --state hidden # Wait for element to disappear + +# Downloads +agent-browser download @e1 ./file.pdf # Click element to trigger download +agent-browser wait --download ./output.zip # Wait for any download to complete +agent-browser --download-path ./downloads open # Set default download directory + +# Network +agent-browser network requests # Inspect tracked requests +agent-browser network route "**/api/*" --abort # Block matching requests +agent-browser network har start # Start HAR recording +agent-browser network har stop ./capture.har # Stop and save HAR file + +# Viewport & Device Emulation +agent-browser set viewport 1920 1080 # Set viewport size (default: 1280x720) +agent-browser set viewport 1920 1080 2 # 2x retina (same CSS size, higher res screenshots) +agent-browser set device "iPhone 14" # Emulate device (viewport + user agent) + +# Capture +agent-browser screenshot # Screenshot to temp dir +agent-browser screenshot --full # Full page screenshot +agent-browser screenshot --annotate # Annotated screenshot with numbered element labels +agent-browser screenshot --screenshot-dir ./shots # Save to custom directory +agent-browser screenshot --screenshot-format jpeg --screenshot-quality 80 +agent-browser pdf output.pdf # Save as PDF + +# Clipboard +agent-browser clipboard read # Read text from clipboard +agent-browser clipboard write "Hello, World!" # Write text to clipboard +agent-browser clipboard copy # Copy current selection +agent-browser clipboard paste # Paste from clipboard + +# Diff (compare page states) +agent-browser diff snapshot # Compare current vs last snapshot +agent-browser diff snapshot --baseline before.txt # Compare current vs saved file +agent-browser diff screenshot --baseline before.png # Visual pixel diff +agent-browser diff url # Compare two pages +agent-browser diff url --wait-until networkidle # Custom wait strategy +agent-browser diff url --selector "#main" # Scope to element +``` + +## Batch Execution + +Execute multiple commands in a single invocation by piping a JSON array of string arrays to `batch`. This avoids per-command process startup overhead when running multi-step workflows. + +```bash +echo '[ + ["open", "https://example.com"], + ["snapshot", "-i"], + ["click", "@e1"], + ["screenshot", "result.png"] +]' | agent-browser batch --json + +# Stop on first error +agent-browser batch --bail < commands.json +``` + +Use `batch` when you have a known sequence of commands that don't depend on intermediate output. Use separate commands or `&&` chaining when you need to parse output between steps (e.g., snapshot to discover refs, then interact). + +## Common Patterns + +### Form Submission + +```bash +agent-browser open https://example.com/signup +agent-browser snapshot -i +agent-browser fill @e1 "Jane Doe" +agent-browser fill @e2 "jane@example.com" +agent-browser select @e3 "California" +agent-browser check @e4 +agent-browser click @e5 +agent-browser wait --load networkidle +``` + +### Authentication with Auth Vault (Recommended) + +```bash +# Save credentials once (encrypted with AGENT_BROWSER_ENCRYPTION_KEY) +# Recommended: pipe password via stdin to avoid shell history exposure +echo "pass" | agent-browser auth save github --url https://github.com/login --username user --password-stdin + +# Login using saved profile (LLM never sees password) +agent-browser auth login github + +# List/show/delete profiles +agent-browser auth list +agent-browser auth show github +agent-browser auth delete github +``` + +### Authentication with State Persistence + +```bash +# Login once and save state +agent-browser open https://app.example.com/login +agent-browser snapshot -i +agent-browser fill @e1 "$USERNAME" +agent-browser fill @e2 "$PASSWORD" +agent-browser click @e3 +agent-browser wait --url "**/dashboard" +agent-browser state save auth.json + +# Reuse in future sessions +agent-browser state load auth.json +agent-browser open https://app.example.com/dashboard +``` + +### Session Persistence + +```bash +# Auto-save/restore cookies and localStorage across browser restarts +agent-browser --session-name myapp open https://app.example.com/login +# ... login flow ... +agent-browser close # State auto-saved to ~/.agent-browser/sessions/ + +# Next time, state is auto-loaded +agent-browser --session-name myapp open https://app.example.com/dashboard + +# Encrypt state at rest +export AGENT_BROWSER_ENCRYPTION_KEY=$(openssl rand -hex 32) +agent-browser --session-name secure open https://app.example.com + +# Manage saved states +agent-browser state list +agent-browser state show myapp-default.json +agent-browser state clear myapp +agent-browser state clean --older-than 7 +``` + +### Working with Iframes + +Iframe content is automatically inlined in snapshots. Refs inside iframes carry frame context, so you can interact with them directly. + +```bash +agent-browser open https://example.com/checkout +agent-browser snapshot -i +# @e1 [heading] "Checkout" +# @e2 [Iframe] "payment-frame" +# @e3 [input] "Card number" +# @e4 [input] "Expiry" +# @e5 [button] "Pay" + +# Interact directly — no frame switch needed +agent-browser fill @e3 "4111111111111111" +agent-browser fill @e4 "12/28" +agent-browser click @e5 + +# To scope a snapshot to one iframe: +agent-browser frame @e2 +agent-browser snapshot -i # Only iframe content +agent-browser frame main # Return to main frame +``` + +### Data Extraction + +```bash +agent-browser open https://example.com/products +agent-browser snapshot -i +agent-browser get text @e5 # Get specific element text +agent-browser get text body > page.txt # Get all page text + +# JSON output for parsing +agent-browser snapshot -i --json +agent-browser get text @e1 --json +``` + +### Parallel Sessions + +```bash +agent-browser --session site1 open https://site-a.com +agent-browser --session site2 open https://site-b.com + +agent-browser --session site1 snapshot -i +agent-browser --session site2 snapshot -i + +agent-browser session list +``` + +### Connect to Existing Chrome + +```bash +# Auto-discover running Chrome with remote debugging enabled +agent-browser --auto-connect open https://example.com +agent-browser --auto-connect snapshot + +# Or with explicit CDP port +agent-browser --cdp 9222 snapshot +``` + +Auto-connect discovers Chrome via `DevToolsActivePort`, common debugging ports (9222, 9229), and falls back to a direct WebSocket connection if HTTP-based CDP discovery fails. + +### Color Scheme (Dark Mode) + +```bash +# Persistent dark mode via flag (applies to all pages and new tabs) +agent-browser --color-scheme dark open https://example.com + +# Or via environment variable +AGENT_BROWSER_COLOR_SCHEME=dark agent-browser open https://example.com + +# Or set during session (persists for subsequent commands) +agent-browser set media dark +``` + +### Viewport & Responsive Testing + +```bash +# Set a custom viewport size (default is 1280x720) +agent-browser set viewport 1920 1080 +agent-browser screenshot desktop.png + +# Test mobile-width layout +agent-browser set viewport 375 812 +agent-browser screenshot mobile.png + +# Retina/HiDPI: same CSS layout at 2x pixel density +# Screenshots stay at logical viewport size, but content renders at higher DPI +agent-browser set viewport 1920 1080 2 +agent-browser screenshot retina.png + +# Device emulation (sets viewport + user agent in one step) +agent-browser set device "iPhone 14" +agent-browser screenshot device.png +``` + +The `scale` parameter (3rd argument) sets `window.devicePixelRatio` without changing CSS layout. Use it when testing retina rendering or capturing higher-resolution screenshots. + +### Visual Browser (Debugging) + +```bash +agent-browser --headed open https://example.com +agent-browser highlight @e1 # Highlight element +agent-browser inspect # Open Chrome DevTools for the active page +agent-browser record start demo.webm # Record session +agent-browser profiler start # Start Chrome DevTools profiling +agent-browser profiler stop trace.json # Stop and save profile (path optional) +``` + +Use `AGENT_BROWSER_HEADED=1` to enable headed mode via environment variable. Browser extensions work in both headed and headless mode. + +### Local Files (PDFs, HTML) + +```bash +# Open local files with file:// URLs +agent-browser --allow-file-access open file:///path/to/document.pdf +agent-browser --allow-file-access open file:///path/to/page.html +agent-browser screenshot output.png +``` + +### iOS Simulator (Mobile Safari) + +```bash +# List available iOS simulators +agent-browser device list + +# Launch Safari on a specific device +agent-browser -p ios --device "iPhone 16 Pro" open https://example.com + +# Same workflow as desktop - snapshot, interact, re-snapshot +agent-browser -p ios snapshot -i +agent-browser -p ios tap @e1 # Tap (alias for click) +agent-browser -p ios fill @e2 "text" +agent-browser -p ios swipe up # Mobile-specific gesture + +# Take screenshot +agent-browser -p ios screenshot mobile.png + +# Close session (shuts down simulator) +agent-browser -p ios close +``` + +**Requirements:** macOS with Xcode, Appium (`npm install -g appium && appium driver install xcuitest`) + +**Real devices:** Works with physical iOS devices if pre-configured. Use `--device ""` where UDID is from `xcrun xctrace list devices`. + +## Security + +All security features are opt-in. By default, agent-browser imposes no restrictions on navigation, actions, or output. + +### Content Boundaries (Recommended for AI Agents) + +Enable `--content-boundaries` to wrap page-sourced output in markers that help LLMs distinguish tool output from untrusted page content: + +```bash +export AGENT_BROWSER_CONTENT_BOUNDARIES=1 +agent-browser snapshot +# Output: +# --- AGENT_BROWSER_PAGE_CONTENT nonce= origin=https://example.com --- +# [accessibility tree] +# --- END_AGENT_BROWSER_PAGE_CONTENT nonce= --- +``` + +### Domain Allowlist + +Restrict navigation to trusted domains. Wildcards like `*.example.com` also match the bare domain `example.com`. Sub-resource requests, WebSocket, and EventSource connections to non-allowed domains are also blocked. Include CDN domains your target pages depend on: + +```bash +export AGENT_BROWSER_ALLOWED_DOMAINS="example.com,*.example.com" +agent-browser open https://example.com # OK +agent-browser open https://malicious.com # Blocked +``` + +### Action Policy + +Use a policy file to gate destructive actions: + +```bash +export AGENT_BROWSER_ACTION_POLICY=./policy.json +``` + +Example `policy.json`: + +```json +{ "default": "deny", "allow": ["navigate", "snapshot", "click", "scroll", "wait", "get"] } +``` + +Auth vault operations (`auth login`, etc.) bypass action policy but domain allowlist still applies. + +### Output Limits + +Prevent context flooding from large pages: + +```bash +export AGENT_BROWSER_MAX_OUTPUT=50000 +``` + +## Diffing (Verifying Changes) + +Use `diff snapshot` after performing an action to verify it had the intended effect. This compares the current accessibility tree against the last snapshot taken in the session. + +```bash +# Typical workflow: snapshot -> action -> diff +agent-browser snapshot -i # Take baseline snapshot +agent-browser click @e2 # Perform action +agent-browser diff snapshot # See what changed (auto-compares to last snapshot) +``` + +For visual regression testing or monitoring: + +```bash +# Save a baseline screenshot, then compare later +agent-browser screenshot baseline.png +# ... time passes or changes are made ... +agent-browser diff screenshot --baseline baseline.png + +# Compare staging vs production +agent-browser diff url https://staging.example.com https://prod.example.com --screenshot +``` + +`diff snapshot` output uses `+` for additions and `-` for removals, similar to git diff. `diff screenshot` produces a diff image with changed pixels highlighted in red, plus a mismatch percentage. + +## Timeouts and Slow Pages + +The default timeout is 25 seconds. This can be overridden with the `AGENT_BROWSER_DEFAULT_TIMEOUT` environment variable (value in milliseconds). For slow websites or large pages, use explicit waits instead of relying on the default timeout: + +```bash +# Wait for network activity to settle (best for slow pages) +agent-browser wait --load networkidle + +# Wait for a specific element to appear +agent-browser wait "#content" +agent-browser wait @e1 + +# Wait for a specific URL pattern (useful after redirects) +agent-browser wait --url "**/dashboard" + +# Wait for a JavaScript condition +agent-browser wait --fn "document.readyState === 'complete'" + +# Wait a fixed duration (milliseconds) as a last resort +agent-browser wait 5000 +``` + +When dealing with consistently slow websites, use `wait --load networkidle` after `open` to ensure the page is fully loaded before taking a snapshot. If a specific element is slow to render, wait for it directly with `wait ` or `wait @ref`. + +## Session Management and Cleanup + +When running multiple agents or automations concurrently, always use named sessions to avoid conflicts: + +```bash +# Each agent gets its own isolated session +agent-browser --session agent1 open site-a.com +agent-browser --session agent2 open site-b.com + +# Check active sessions +agent-browser session list +``` + +Always close your browser session when done to avoid leaked processes: + +```bash +agent-browser close # Close default session +agent-browser --session agent1 close # Close specific session +``` + +If a previous session was not closed properly, the daemon may still be running. Use `agent-browser close` to clean it up before starting new work. + +To auto-shutdown the daemon after a period of inactivity (useful for ephemeral/CI environments): + +```bash +AGENT_BROWSER_IDLE_TIMEOUT_MS=60000 agent-browser open example.com +``` + +## Ref Lifecycle (Important) + +Refs (`@e1`, `@e2`, etc.) are invalidated when the page changes. Always re-snapshot after: + +- Clicking links or buttons that navigate +- Form submissions +- Dynamic content loading (dropdowns, modals) + +```bash +agent-browser click @e5 # Navigates to new page +agent-browser snapshot -i # MUST re-snapshot +agent-browser click @e1 # Use new refs +``` + +## Annotated Screenshots (Vision Mode) + +Use `--annotate` to take a screenshot with numbered labels overlaid on interactive elements. Each label `[N]` maps to ref `@eN`. This also caches refs, so you can interact with elements immediately without a separate snapshot. + +```bash +agent-browser screenshot --annotate +# Output includes the image path and a legend: +# [1] @e1 button "Submit" +# [2] @e2 link "Home" +# [3] @e3 textbox "Email" +agent-browser click @e2 # Click using ref from annotated screenshot +``` + +Use annotated screenshots when: + +- The page has unlabeled icon buttons or visual-only elements +- You need to verify visual layout or styling +- Canvas or chart elements are present (invisible to text snapshots) +- You need spatial reasoning about element positions + +## Semantic Locators (Alternative to Refs) + +When refs are unavailable or unreliable, use semantic locators: + +```bash +agent-browser find text "Sign In" click +agent-browser find label "Email" fill "user@test.com" +agent-browser find role button click --name "Submit" +agent-browser find placeholder "Search" type "query" +agent-browser find testid "submit-btn" click +``` + +## JavaScript Evaluation (eval) + +Use `eval` to run JavaScript in the browser context. **Shell quoting can corrupt complex expressions** -- use `--stdin` or `-b` to avoid issues. + +```bash +# Simple expressions work with regular quoting +agent-browser eval 'document.title' +agent-browser eval 'document.querySelectorAll("img").length' + +# Complex JS: use --stdin with heredoc (RECOMMENDED) +agent-browser eval --stdin <<'EVALEOF' +JSON.stringify( + Array.from(document.querySelectorAll("img")) + .filter(i => !i.alt) + .map(i => ({ src: i.src.split("/").pop(), width: i.width })) +) +EVALEOF + +# Alternative: base64 encoding (avoids all shell escaping issues) +agent-browser eval -b "$(echo -n 'Array.from(document.querySelectorAll("a")).map(a => a.href)' | base64)" +``` + +**Why this matters:** When the shell processes your command, inner double quotes, `!` characters (history expansion), backticks, and `$()` can all corrupt the JavaScript before it reaches agent-browser. The `--stdin` and `-b` flags bypass shell interpretation entirely. + +**Rules of thumb:** + +- Single-line, no nested quotes -> regular `eval 'expression'` with single quotes is fine +- Nested quotes, arrow functions, template literals, or multiline -> use `eval --stdin <<'EVALEOF'` +- Programmatic/generated scripts -> use `eval -b` with base64 + +## Configuration File + +Create `agent-browser.json` in the project root for persistent settings: + +```json +{ + "headed": true, + "proxy": "http://localhost:8080", + "profile": "./browser-data" +} +``` + +Priority (lowest to highest): `~/.agent-browser/config.json` < `./agent-browser.json` < env vars < CLI flags. Use `--config ` or `AGENT_BROWSER_CONFIG` env var for a custom config file (exits with error if missing/invalid). All CLI options map to camelCase keys (e.g., `--executable-path` -> `"executablePath"`). Boolean flags accept `true`/`false` values (e.g., `--headed false` overrides config). Extensions from user and project configs are merged, not replaced. + +## Deep-Dive Documentation + +| Reference | When to Use | +| -------------------------------------------------------------------- | --------------------------------------------------------- | +| [references/commands.md](references/commands.md) | Full command reference with all options | +| [references/snapshot-refs.md](references/snapshot-refs.md) | Ref lifecycle, invalidation rules, troubleshooting | +| [references/session-management.md](references/session-management.md) | Parallel sessions, state persistence, concurrent scraping | +| [references/authentication.md](references/authentication.md) | Login flows, OAuth, 2FA handling, state reuse | +| [references/video-recording.md](references/video-recording.md) | Recording workflows for debugging and documentation | +| [references/profiling.md](references/profiling.md) | Chrome DevTools profiling for performance analysis | +| [references/proxy-support.md](references/proxy-support.md) | Proxy configuration, geo-testing, rotating proxies | + +## Browser Engine Selection + +Use `--engine` to choose a local browser engine. The default is `chrome`. + +```bash +# Use Lightpanda (fast headless browser, requires separate install) +agent-browser --engine lightpanda open example.com + +# Via environment variable +export AGENT_BROWSER_ENGINE=lightpanda +agent-browser open example.com + +# With custom binary path +agent-browser --engine lightpanda --executable-path /path/to/lightpanda open example.com +``` + +Supported engines: +- `chrome` (default) -- Chrome/Chromium via CDP +- `lightpanda` -- Lightpanda headless browser via CDP (10x faster, 10x less memory than Chrome) + +Lightpanda does not support `--extension`, `--profile`, `--state`, or `--allow-file-access`. Install Lightpanda from https://lightpanda.io/docs/open-source/installation. + +## Ready-to-Use Templates + +| Template | Description | +| ------------------------------------------------------------------------ | ----------------------------------- | +| [templates/form-automation.sh](templates/form-automation.sh) | Form filling with validation | +| [templates/authenticated-session.sh](templates/authenticated-session.sh) | Login once, reuse state | +| [templates/capture-workflow.sh](templates/capture-workflow.sh) | Content extraction with screenshots | + +```bash +./templates/form-automation.sh https://example.com/form +./templates/authenticated-session.sh https://app.example.com/login +./templates/capture-workflow.sh https://example.com ./output +``` diff --git a/workspace/skills/agent-browser/references/authentication.md b/workspace/skills/agent-browser/references/authentication.md new file mode 100644 index 00000000..89f4788a --- /dev/null +++ b/workspace/skills/agent-browser/references/authentication.md @@ -0,0 +1,303 @@ +# Authentication Patterns + +Login flows, session persistence, OAuth, 2FA, and authenticated browsing. + +**Related**: [session-management.md](session-management.md) for state persistence details, [SKILL.md](../SKILL.md) for quick start. + +## Contents + +- [Import Auth from Your Browser](#import-auth-from-your-browser) +- [Persistent Profiles](#persistent-profiles) +- [Session Persistence](#session-persistence) +- [Basic Login Flow](#basic-login-flow) +- [Saving Authentication State](#saving-authentication-state) +- [Restoring Authentication](#restoring-authentication) +- [OAuth / SSO Flows](#oauth--sso-flows) +- [Two-Factor Authentication](#two-factor-authentication) +- [HTTP Basic Auth](#http-basic-auth) +- [Cookie-Based Auth](#cookie-based-auth) +- [Token Refresh Handling](#token-refresh-handling) +- [Security Best Practices](#security-best-practices) + +## Import Auth from Your Browser + +The fastest way to authenticate is to reuse cookies from a Chrome session you are already logged into. + +**Step 1: Start Chrome with remote debugging** + +```bash +# macOS +"/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" --remote-debugging-port=9222 + +# Linux +google-chrome --remote-debugging-port=9222 + +# Windows +"C:\Program Files\Google\Chrome\Application\chrome.exe" --remote-debugging-port=9222 +``` + +Log in to your target site(s) in this Chrome window as you normally would. + +> **Security note:** `--remote-debugging-port` exposes full browser control on localhost. Any local process can connect and read cookies, execute JS, etc. Only use on trusted machines and close Chrome when done. + +**Step 2: Grab the auth state** + +```bash +# Auto-discover the running Chrome and save its cookies + localStorage +agent-browser --auto-connect state save ./my-auth.json +``` + +**Step 3: Reuse in automation** + +```bash +# Load auth at launch +agent-browser --state ./my-auth.json open https://app.example.com/dashboard + +# Or load into an existing session +agent-browser state load ./my-auth.json +agent-browser open https://app.example.com/dashboard +``` + +This works for any site, including those with complex OAuth flows, SSO, or 2FA -- as long as Chrome already has valid session cookies. + +> **Security note:** State files contain session tokens in plaintext. Add them to `.gitignore`, delete when no longer needed, and set `AGENT_BROWSER_ENCRYPTION_KEY` for encryption at rest. See [Security Best Practices](#security-best-practices). + +**Tip:** Combine with `--session-name` so the imported auth auto-persists across restarts: + +```bash +agent-browser --session-name myapp state load ./my-auth.json +# From now on, state is auto-saved/restored for "myapp" +``` + +## Persistent Profiles + +Use `--profile` to point agent-browser at a Chrome user data directory. This persists everything (cookies, IndexedDB, service workers, cache) across browser restarts without explicit save/load: + +```bash +# First run: login once +agent-browser --profile ~/.myapp-profile open https://app.example.com/login +# ... complete login flow ... + +# All subsequent runs: already authenticated +agent-browser --profile ~/.myapp-profile open https://app.example.com/dashboard +``` + +Use different paths for different projects or test users: + +```bash +agent-browser --profile ~/.profiles/admin open https://app.example.com +agent-browser --profile ~/.profiles/viewer open https://app.example.com +``` + +Or set via environment variable: + +```bash +export AGENT_BROWSER_PROFILE=~/.myapp-profile +agent-browser open https://app.example.com/dashboard +``` + +## Session Persistence + +Use `--session-name` to auto-save and restore cookies + localStorage by name, without managing files: + +```bash +# Auto-saves state on close, auto-restores on next launch +agent-browser --session-name twitter open https://twitter.com +# ... login flow ... +agent-browser close # state saved to ~/.agent-browser/sessions/ + +# Next time: state is automatically restored +agent-browser --session-name twitter open https://twitter.com +``` + +Encrypt state at rest: + +```bash +export AGENT_BROWSER_ENCRYPTION_KEY=$(openssl rand -hex 32) +agent-browser --session-name secure open https://app.example.com +``` + +## Basic Login Flow + +```bash +# Navigate to login page +agent-browser open https://app.example.com/login +agent-browser wait --load networkidle + +# Get form elements +agent-browser snapshot -i +# Output: @e1 [input type="email"], @e2 [input type="password"], @e3 [button] "Sign In" + +# Fill credentials +agent-browser fill @e1 "user@example.com" +agent-browser fill @e2 "password123" + +# Submit +agent-browser click @e3 +agent-browser wait --load networkidle + +# Verify login succeeded +agent-browser get url # Should be dashboard, not login +``` + +## Saving Authentication State + +After logging in, save state for reuse: + +```bash +# Login first (see above) +agent-browser open https://app.example.com/login +agent-browser snapshot -i +agent-browser fill @e1 "user@example.com" +agent-browser fill @e2 "password123" +agent-browser click @e3 +agent-browser wait --url "**/dashboard" + +# Save authenticated state +agent-browser state save ./auth-state.json +``` + +## Restoring Authentication + +Skip login by loading saved state: + +```bash +# Load saved auth state +agent-browser state load ./auth-state.json + +# Navigate directly to protected page +agent-browser open https://app.example.com/dashboard + +# Verify authenticated +agent-browser snapshot -i +``` + +## OAuth / SSO Flows + +For OAuth redirects: + +```bash +# Start OAuth flow +agent-browser open https://app.example.com/auth/google + +# Handle redirects automatically +agent-browser wait --url "**/accounts.google.com**" +agent-browser snapshot -i + +# Fill Google credentials +agent-browser fill @e1 "user@gmail.com" +agent-browser click @e2 # Next button +agent-browser wait 2000 +agent-browser snapshot -i +agent-browser fill @e3 "password" +agent-browser click @e4 # Sign in + +# Wait for redirect back +agent-browser wait --url "**/app.example.com**" +agent-browser state save ./oauth-state.json +``` + +## Two-Factor Authentication + +Handle 2FA with manual intervention: + +```bash +# Login with credentials +agent-browser open https://app.example.com/login --headed # Show browser +agent-browser snapshot -i +agent-browser fill @e1 "user@example.com" +agent-browser fill @e2 "password123" +agent-browser click @e3 + +# Wait for user to complete 2FA manually +echo "Complete 2FA in the browser window..." +agent-browser wait --url "**/dashboard" --timeout 120000 + +# Save state after 2FA +agent-browser state save ./2fa-state.json +``` + +## HTTP Basic Auth + +For sites using HTTP Basic Authentication: + +```bash +# Set credentials before navigation +agent-browser set credentials username password + +# Navigate to protected resource +agent-browser open https://protected.example.com/api +``` + +## Cookie-Based Auth + +Manually set authentication cookies: + +```bash +# Set auth cookie +agent-browser cookies set session_token "abc123xyz" + +# Navigate to protected page +agent-browser open https://app.example.com/dashboard +``` + +## Token Refresh Handling + +For sessions with expiring tokens: + +```bash +#!/bin/bash +# Wrapper that handles token refresh + +STATE_FILE="./auth-state.json" + +# Try loading existing state +if [[ -f "$STATE_FILE" ]]; then + agent-browser state load "$STATE_FILE" + agent-browser open https://app.example.com/dashboard + + # Check if session is still valid + URL=$(agent-browser get url) + if [[ "$URL" == *"/login"* ]]; then + echo "Session expired, re-authenticating..." + # Perform fresh login + agent-browser snapshot -i + agent-browser fill @e1 "$USERNAME" + agent-browser fill @e2 "$PASSWORD" + agent-browser click @e3 + agent-browser wait --url "**/dashboard" + agent-browser state save "$STATE_FILE" + fi +else + # First-time login + agent-browser open https://app.example.com/login + # ... login flow ... +fi +``` + +## Security Best Practices + +1. **Never commit state files** - They contain session tokens + ```bash + echo "*.auth-state.json" >> .gitignore + ``` + +2. **Use environment variables for credentials** + ```bash + agent-browser fill @e1 "$APP_USERNAME" + agent-browser fill @e2 "$APP_PASSWORD" + ``` + +3. **Clean up after automation** + ```bash + agent-browser cookies clear + rm -f ./auth-state.json + ``` + +4. **Use short-lived sessions for CI/CD** + ```bash + # Don't persist state in CI + agent-browser open https://app.example.com/login + # ... login and perform actions ... + agent-browser close # Session ends, nothing persisted + ``` diff --git a/workspace/skills/agent-browser/references/commands.md b/workspace/skills/agent-browser/references/commands.md new file mode 100644 index 00000000..46de5f15 --- /dev/null +++ b/workspace/skills/agent-browser/references/commands.md @@ -0,0 +1,292 @@ +# Command Reference + +Complete reference for all agent-browser commands. For quick start and common patterns, see SKILL.md. + +## Navigation + +```bash +agent-browser open # Navigate to URL (aliases: goto, navigate) + # Supports: https://, http://, file://, about:, data:// + # Auto-prepends https:// if no protocol given +agent-browser back # Go back +agent-browser forward # Go forward +agent-browser reload # Reload page +agent-browser close # Close browser (aliases: quit, exit) +agent-browser connect 9222 # Connect to browser via CDP port +``` + +## Snapshot (page analysis) + +```bash +agent-browser snapshot # Full accessibility tree +agent-browser snapshot -i # Interactive elements only (recommended) +agent-browser snapshot -c # Compact output +agent-browser snapshot -d 3 # Limit depth to 3 +agent-browser snapshot -s "#main" # Scope to CSS selector +``` + +## Interactions (use @refs from snapshot) + +```bash +agent-browser click @e1 # Click +agent-browser click @e1 --new-tab # Click and open in new tab +agent-browser dblclick @e1 # Double-click +agent-browser focus @e1 # Focus element +agent-browser fill @e2 "text" # Clear and type +agent-browser type @e2 "text" # Type without clearing +agent-browser press Enter # Press key (alias: key) +agent-browser press Control+a # Key combination +agent-browser keydown Shift # Hold key down +agent-browser keyup Shift # Release key +agent-browser hover @e1 # Hover +agent-browser check @e1 # Check checkbox +agent-browser uncheck @e1 # Uncheck checkbox +agent-browser select @e1 "value" # Select dropdown option +agent-browser select @e1 "a" "b" # Select multiple options +agent-browser scroll down 500 # Scroll page (default: down 300px) +agent-browser scrollintoview @e1 # Scroll element into view (alias: scrollinto) +agent-browser drag @e1 @e2 # Drag and drop +agent-browser upload @e1 file.pdf # Upload files +``` + +## Get Information + +```bash +agent-browser get text @e1 # Get element text +agent-browser get html @e1 # Get innerHTML +agent-browser get value @e1 # Get input value +agent-browser get attr @e1 href # Get attribute +agent-browser get title # Get page title +agent-browser get url # Get current URL +agent-browser get cdp-url # Get CDP WebSocket URL +agent-browser get count ".item" # Count matching elements +agent-browser get box @e1 # Get bounding box +agent-browser get styles @e1 # Get computed styles (font, color, bg, etc.) +``` + +## Check State + +```bash +agent-browser is visible @e1 # Check if visible +agent-browser is enabled @e1 # Check if enabled +agent-browser is checked @e1 # Check if checked +``` + +## Screenshots and PDF + +```bash +agent-browser screenshot # Save to temporary directory +agent-browser screenshot path.png # Save to specific path +agent-browser screenshot --full # Full page +agent-browser pdf output.pdf # Save as PDF +``` + +## Video Recording + +```bash +agent-browser record start ./demo.webm # Start recording +agent-browser click @e1 # Perform actions +agent-browser record stop # Stop and save video +agent-browser record restart ./take2.webm # Stop current + start new +``` + +## Wait + +```bash +agent-browser wait @e1 # Wait for element +agent-browser wait 2000 # Wait milliseconds +agent-browser wait --text "Success" # Wait for text (or -t) +agent-browser wait --url "**/dashboard" # Wait for URL pattern (or -u) +agent-browser wait --load networkidle # Wait for network idle (or -l) +agent-browser wait --fn "window.ready" # Wait for JS condition (or -f) +``` + +## Mouse Control + +```bash +agent-browser mouse move 100 200 # Move mouse +agent-browser mouse down left # Press button +agent-browser mouse up left # Release button +agent-browser mouse wheel 100 # Scroll wheel +``` + +## Semantic Locators (alternative to refs) + +```bash +agent-browser find role button click --name "Submit" +agent-browser find text "Sign In" click +agent-browser find text "Sign In" click --exact # Exact match only +agent-browser find label "Email" fill "user@test.com" +agent-browser find placeholder "Search" type "query" +agent-browser find alt "Logo" click +agent-browser find title "Close" click +agent-browser find testid "submit-btn" click +agent-browser find first ".item" click +agent-browser find last ".item" click +agent-browser find nth 2 "a" hover +``` + +## Browser Settings + +```bash +agent-browser set viewport 1920 1080 # Set viewport size +agent-browser set viewport 1920 1080 2 # 2x retina (same CSS size, higher res screenshots) +agent-browser set device "iPhone 14" # Emulate device +agent-browser set geo 37.7749 -122.4194 # Set geolocation (alias: geolocation) +agent-browser set offline on # Toggle offline mode +agent-browser set headers '{"X-Key":"v"}' # Extra HTTP headers +agent-browser set credentials user pass # HTTP basic auth (alias: auth) +agent-browser set media dark # Emulate color scheme +agent-browser set media light reduced-motion # Light mode + reduced motion +``` + +## Cookies and Storage + +```bash +agent-browser cookies # Get all cookies +agent-browser cookies set name value # Set cookie +agent-browser cookies clear # Clear cookies +agent-browser storage local # Get all localStorage +agent-browser storage local key # Get specific key +agent-browser storage local set k v # Set value +agent-browser storage local clear # Clear all +``` + +## Network + +```bash +agent-browser network route # Intercept requests +agent-browser network route --abort # Block requests +agent-browser network route --body '{}' # Mock response +agent-browser network unroute [url] # Remove routes +agent-browser network requests # View tracked requests +agent-browser network requests --filter api # Filter requests +``` + +## Tabs and Windows + +```bash +agent-browser tab # List tabs +agent-browser tab new [url] # New tab +agent-browser tab 2 # Switch to tab by index +agent-browser tab close # Close current tab +agent-browser tab close 2 # Close tab by index +agent-browser window new # New window +``` + +## Frames + +```bash +agent-browser frame "#iframe" # Switch to iframe by CSS selector +agent-browser frame @e3 # Switch to iframe by element ref +agent-browser frame main # Back to main frame +``` + +### Iframe support + +Iframes are detected automatically during snapshots. When the main-frame snapshot runs, `Iframe` nodes are resolved and their content is inlined beneath the iframe element in the output (one level of nesting; iframes within iframes are not expanded). + +```bash +agent-browser snapshot -i +# @e3 [Iframe] "payment-frame" +# @e4 [input] "Card number" +# @e5 [button] "Pay" + +# Interact directly — refs inside iframes already work +agent-browser fill @e4 "4111111111111111" +agent-browser click @e5 + +# Or switch frame context for scoped snapshots +agent-browser frame @e3 # Switch using element ref +agent-browser snapshot -i # Snapshot scoped to that iframe +agent-browser frame main # Return to main frame +``` + +The `frame` command accepts: +- **Element refs** — `frame @e3` resolves the ref to an iframe element +- **CSS selectors** — `frame "#payment-iframe"` finds the iframe by selector +- **Frame name/URL** — matches against the browser's frame tree + +## Dialogs + +```bash +agent-browser dialog accept [text] # Accept dialog +agent-browser dialog dismiss # Dismiss dialog +``` + +## JavaScript + +```bash +agent-browser eval "document.title" # Simple expressions only +agent-browser eval -b "" # Any JavaScript (base64 encoded) +agent-browser eval --stdin # Read script from stdin +``` + +Use `-b`/`--base64` or `--stdin` for reliable execution. Shell escaping with nested quotes and special characters is error-prone. + +```bash +# Base64 encode your script, then: +agent-browser eval -b "ZG9jdW1lbnQucXVlcnlTZWxlY3RvcignW3NyYyo9Il9uZXh0Il0nKQ==" + +# Or use stdin with heredoc for multiline scripts: +cat <<'EOF' | agent-browser eval --stdin +const links = document.querySelectorAll('a'); +Array.from(links).map(a => a.href); +EOF +``` + +## State Management + +```bash +agent-browser state save auth.json # Save cookies, storage, auth state +agent-browser state load auth.json # Restore saved state +``` + +## Global Options + +```bash +agent-browser --session ... # Isolated browser session +agent-browser --json ... # JSON output for parsing +agent-browser --headed ... # Show browser window (not headless) +agent-browser --full ... # Full page screenshot (-f) +agent-browser --cdp ... # Connect via Chrome DevTools Protocol +agent-browser -p ... # Cloud browser provider (--provider) +agent-browser --proxy ... # Use proxy server +agent-browser --proxy-bypass # Hosts to bypass proxy +agent-browser --headers ... # HTTP headers scoped to URL's origin +agent-browser --executable-path

# Custom browser executable +agent-browser --extension ... # Load browser extension (repeatable) +agent-browser --ignore-https-errors # Ignore SSL certificate errors +agent-browser --help # Show help (-h) +agent-browser --version # Show version (-V) +agent-browser --help # Show detailed help for a command +``` + +## Debugging + +```bash +agent-browser --headed open example.com # Show browser window +agent-browser --cdp 9222 snapshot # Connect via CDP port +agent-browser connect 9222 # Alternative: connect command +agent-browser console # View console messages +agent-browser console --clear # Clear console +agent-browser errors # View page errors +agent-browser errors --clear # Clear errors +agent-browser highlight @e1 # Highlight element +agent-browser inspect # Open Chrome DevTools for this session +agent-browser trace start # Start recording trace +agent-browser trace stop trace.zip # Stop and save trace +agent-browser profiler start # Start Chrome DevTools profiling +agent-browser profiler stop trace.json # Stop and save profile +``` + +## Environment Variables + +```bash +AGENT_BROWSER_SESSION="mysession" # Default session name +AGENT_BROWSER_EXECUTABLE_PATH="/path/chrome" # Custom browser path +AGENT_BROWSER_EXTENSIONS="/ext1,/ext2" # Comma-separated extension paths +AGENT_BROWSER_PROVIDER="browserbase" # Cloud browser provider +AGENT_BROWSER_STREAM_PORT="9223" # WebSocket streaming port +AGENT_BROWSER_HOME="/path/to/agent-browser" # Custom install location +``` diff --git a/workspace/skills/agent-browser/references/profiling.md b/workspace/skills/agent-browser/references/profiling.md new file mode 100644 index 00000000..bd47eaa0 --- /dev/null +++ b/workspace/skills/agent-browser/references/profiling.md @@ -0,0 +1,120 @@ +# Profiling + +Capture Chrome DevTools performance profiles during browser automation for performance analysis. + +**Related**: [commands.md](commands.md) for full command reference, [SKILL.md](../SKILL.md) for quick start. + +## Contents + +- [Basic Profiling](#basic-profiling) +- [Profiler Commands](#profiler-commands) +- [Categories](#categories) +- [Use Cases](#use-cases) +- [Output Format](#output-format) +- [Viewing Profiles](#viewing-profiles) +- [Limitations](#limitations) + +## Basic Profiling + +```bash +# Start profiling +agent-browser profiler start + +# Perform actions +agent-browser navigate https://example.com +agent-browser click "#button" +agent-browser wait 1000 + +# Stop and save +agent-browser profiler stop ./trace.json +``` + +## Profiler Commands + +```bash +# Start profiling with default categories +agent-browser profiler start + +# Start with custom trace categories +agent-browser profiler start --categories "devtools.timeline,v8.execute,blink.user_timing" + +# Stop profiling and save to file +agent-browser profiler stop ./trace.json +``` + +## Categories + +The `--categories` flag accepts a comma-separated list of Chrome trace categories. Default categories include: + +- `devtools.timeline` -- standard DevTools performance traces +- `v8.execute` -- time spent running JavaScript +- `blink` -- renderer events +- `blink.user_timing` -- `performance.mark()` / `performance.measure()` calls +- `latencyInfo` -- input-to-latency tracking +- `renderer.scheduler` -- task scheduling and execution +- `toplevel` -- broad-spectrum basic events + +Several `disabled-by-default-*` categories are also included for detailed timeline, call stack, and V8 CPU profiling data. + +## Use Cases + +### Diagnosing Slow Page Loads + +```bash +agent-browser profiler start +agent-browser navigate https://app.example.com +agent-browser wait --load networkidle +agent-browser profiler stop ./page-load-profile.json +``` + +### Profiling User Interactions + +```bash +agent-browser navigate https://app.example.com +agent-browser profiler start +agent-browser click "#submit" +agent-browser wait 2000 +agent-browser profiler stop ./interaction-profile.json +``` + +### CI Performance Regression Checks + +```bash +#!/bin/bash +agent-browser profiler start +agent-browser navigate https://app.example.com +agent-browser wait --load networkidle +agent-browser profiler stop "./profiles/build-${BUILD_ID}.json" +``` + +## Output Format + +The output is a JSON file in Chrome Trace Event format: + +```json +{ + "traceEvents": [ + { "cat": "devtools.timeline", "name": "RunTask", "ph": "X", "ts": 12345, "dur": 100, ... }, + ... + ], + "metadata": { + "clock-domain": "LINUX_CLOCK_MONOTONIC" + } +} +``` + +The `metadata.clock-domain` field is set based on the host platform (Linux or macOS). On Windows it is omitted. + +## Viewing Profiles + +Load the output JSON file in any of these tools: + +- **Chrome DevTools**: Performance panel > Load profile (Ctrl+Shift+I > Performance) +- **Perfetto UI**: https://ui.perfetto.dev/ -- drag and drop the JSON file +- **Trace Viewer**: `chrome://tracing` in any Chromium browser + +## Limitations + +- Only works with Chromium-based browsers (Chrome, Edge). Not supported on Firefox or WebKit. +- Trace data accumulates in memory while profiling is active (capped at 5 million events). Stop profiling promptly after the area of interest. +- Data collection on stop has a 30-second timeout. If the browser is unresponsive, the stop command may fail. diff --git a/workspace/skills/agent-browser/references/proxy-support.md b/workspace/skills/agent-browser/references/proxy-support.md new file mode 100644 index 00000000..e86a8fe3 --- /dev/null +++ b/workspace/skills/agent-browser/references/proxy-support.md @@ -0,0 +1,194 @@ +# Proxy Support + +Proxy configuration for geo-testing, rate limiting avoidance, and corporate environments. + +**Related**: [commands.md](commands.md) for global options, [SKILL.md](../SKILL.md) for quick start. + +## Contents + +- [Basic Proxy Configuration](#basic-proxy-configuration) +- [Authenticated Proxy](#authenticated-proxy) +- [SOCKS Proxy](#socks-proxy) +- [Proxy Bypass](#proxy-bypass) +- [Common Use Cases](#common-use-cases) +- [Verifying Proxy Connection](#verifying-proxy-connection) +- [Troubleshooting](#troubleshooting) +- [Best Practices](#best-practices) + +## Basic Proxy Configuration + +Use the `--proxy` flag or set proxy via environment variable: + +```bash +# Via CLI flag +agent-browser --proxy "http://proxy.example.com:8080" open https://example.com + +# Via environment variable +export HTTP_PROXY="http://proxy.example.com:8080" +agent-browser open https://example.com + +# HTTPS proxy +export HTTPS_PROXY="https://proxy.example.com:8080" +agent-browser open https://example.com + +# Both +export HTTP_PROXY="http://proxy.example.com:8080" +export HTTPS_PROXY="http://proxy.example.com:8080" +agent-browser open https://example.com +``` + +## Authenticated Proxy + +For proxies requiring authentication: + +```bash +# Include credentials in URL +export HTTP_PROXY="http://username:password@proxy.example.com:8080" +agent-browser open https://example.com +``` + +## SOCKS Proxy + +```bash +# SOCKS5 proxy +export ALL_PROXY="socks5://proxy.example.com:1080" +agent-browser open https://example.com + +# SOCKS5 with auth +export ALL_PROXY="socks5://user:pass@proxy.example.com:1080" +agent-browser open https://example.com +``` + +## Proxy Bypass + +Skip proxy for specific domains using `--proxy-bypass` or `NO_PROXY`: + +```bash +# Via CLI flag +agent-browser --proxy "http://proxy.example.com:8080" --proxy-bypass "localhost,*.internal.com" open https://example.com + +# Via environment variable +export NO_PROXY="localhost,127.0.0.1,.internal.company.com" +agent-browser open https://internal.company.com # Direct connection +agent-browser open https://external.com # Via proxy +``` + +## Common Use Cases + +### Geo-Location Testing + +```bash +#!/bin/bash +# Test site from different regions using geo-located proxies + +PROXIES=( + "http://us-proxy.example.com:8080" + "http://eu-proxy.example.com:8080" + "http://asia-proxy.example.com:8080" +) + +for proxy in "${PROXIES[@]}"; do + export HTTP_PROXY="$proxy" + export HTTPS_PROXY="$proxy" + + region=$(echo "$proxy" | grep -oP '^\w+-\w+') + echo "Testing from: $region" + + agent-browser --session "$region" open https://example.com + agent-browser --session "$region" screenshot "./screenshots/$region.png" + agent-browser --session "$region" close +done +``` + +### Rotating Proxies for Scraping + +```bash +#!/bin/bash +# Rotate through proxy list to avoid rate limiting + +PROXY_LIST=( + "http://proxy1.example.com:8080" + "http://proxy2.example.com:8080" + "http://proxy3.example.com:8080" +) + +URLS=( + "https://site.com/page1" + "https://site.com/page2" + "https://site.com/page3" +) + +for i in "${!URLS[@]}"; do + proxy_index=$((i % ${#PROXY_LIST[@]})) + export HTTP_PROXY="${PROXY_LIST[$proxy_index]}" + export HTTPS_PROXY="${PROXY_LIST[$proxy_index]}" + + agent-browser open "${URLS[$i]}" + agent-browser get text body > "output-$i.txt" + agent-browser close + + sleep 1 # Polite delay +done +``` + +### Corporate Network Access + +```bash +#!/bin/bash +# Access internal sites via corporate proxy + +export HTTP_PROXY="http://corpproxy.company.com:8080" +export HTTPS_PROXY="http://corpproxy.company.com:8080" +export NO_PROXY="localhost,127.0.0.1,.company.com" + +# External sites go through proxy +agent-browser open https://external-vendor.com + +# Internal sites bypass proxy +agent-browser open https://intranet.company.com +``` + +## Verifying Proxy Connection + +```bash +# Check your apparent IP +agent-browser open https://httpbin.org/ip +agent-browser get text body +# Should show proxy's IP, not your real IP +``` + +## Troubleshooting + +### Proxy Connection Failed + +```bash +# Test proxy connectivity first +curl -x http://proxy.example.com:8080 https://httpbin.org/ip + +# Check if proxy requires auth +export HTTP_PROXY="http://user:pass@proxy.example.com:8080" +``` + +### SSL/TLS Errors Through Proxy + +Some proxies perform SSL inspection. If you encounter certificate errors: + +```bash +# For testing only - not recommended for production +agent-browser open https://example.com --ignore-https-errors +``` + +### Slow Performance + +```bash +# Use proxy only when necessary +export NO_PROXY="*.cdn.com,*.static.com" # Direct CDN access +``` + +## Best Practices + +1. **Use environment variables** - Don't hardcode proxy credentials +2. **Set NO_PROXY appropriately** - Avoid routing local traffic through proxy +3. **Test proxy before automation** - Verify connectivity with simple requests +4. **Handle proxy failures gracefully** - Implement retry logic for unstable proxies +5. **Rotate proxies for large scraping jobs** - Distribute load and avoid bans diff --git a/workspace/skills/agent-browser/references/session-management.md b/workspace/skills/agent-browser/references/session-management.md new file mode 100644 index 00000000..bb5312db --- /dev/null +++ b/workspace/skills/agent-browser/references/session-management.md @@ -0,0 +1,193 @@ +# Session Management + +Multiple isolated browser sessions with state persistence and concurrent browsing. + +**Related**: [authentication.md](authentication.md) for login patterns, [SKILL.md](../SKILL.md) for quick start. + +## Contents + +- [Named Sessions](#named-sessions) +- [Session Isolation Properties](#session-isolation-properties) +- [Session State Persistence](#session-state-persistence) +- [Common Patterns](#common-patterns) +- [Default Session](#default-session) +- [Session Cleanup](#session-cleanup) +- [Best Practices](#best-practices) + +## Named Sessions + +Use `--session` flag to isolate browser contexts: + +```bash +# Session 1: Authentication flow +agent-browser --session auth open https://app.example.com/login + +# Session 2: Public browsing (separate cookies, storage) +agent-browser --session public open https://example.com + +# Commands are isolated by session +agent-browser --session auth fill @e1 "user@example.com" +agent-browser --session public get text body +``` + +## Session Isolation Properties + +Each session has independent: +- Cookies +- LocalStorage / SessionStorage +- IndexedDB +- Cache +- Browsing history +- Open tabs + +## Session State Persistence + +### Save Session State + +```bash +# Save cookies, storage, and auth state +agent-browser state save /path/to/auth-state.json +``` + +### Load Session State + +```bash +# Restore saved state +agent-browser state load /path/to/auth-state.json + +# Continue with authenticated session +agent-browser open https://app.example.com/dashboard +``` + +### State File Contents + +```json +{ + "cookies": [...], + "localStorage": {...}, + "sessionStorage": {...}, + "origins": [...] +} +``` + +## Common Patterns + +### Authenticated Session Reuse + +```bash +#!/bin/bash +# Save login state once, reuse many times + +STATE_FILE="/tmp/auth-state.json" + +# Check if we have saved state +if [[ -f "$STATE_FILE" ]]; then + agent-browser state load "$STATE_FILE" + agent-browser open https://app.example.com/dashboard +else + # Perform login + agent-browser open https://app.example.com/login + agent-browser snapshot -i + agent-browser fill @e1 "$USERNAME" + agent-browser fill @e2 "$PASSWORD" + agent-browser click @e3 + agent-browser wait --load networkidle + + # Save for future use + agent-browser state save "$STATE_FILE" +fi +``` + +### Concurrent Scraping + +```bash +#!/bin/bash +# Scrape multiple sites concurrently + +# Start all sessions +agent-browser --session site1 open https://site1.com & +agent-browser --session site2 open https://site2.com & +agent-browser --session site3 open https://site3.com & +wait + +# Extract from each +agent-browser --session site1 get text body > site1.txt +agent-browser --session site2 get text body > site2.txt +agent-browser --session site3 get text body > site3.txt + +# Cleanup +agent-browser --session site1 close +agent-browser --session site2 close +agent-browser --session site3 close +``` + +### A/B Testing Sessions + +```bash +# Test different user experiences +agent-browser --session variant-a open "https://app.com?variant=a" +agent-browser --session variant-b open "https://app.com?variant=b" + +# Compare +agent-browser --session variant-a screenshot /tmp/variant-a.png +agent-browser --session variant-b screenshot /tmp/variant-b.png +``` + +## Default Session + +When `--session` is omitted, commands use the default session: + +```bash +# These use the same default session +agent-browser open https://example.com +agent-browser snapshot -i +agent-browser close # Closes default session +``` + +## Session Cleanup + +```bash +# Close specific session +agent-browser --session auth close + +# List active sessions +agent-browser session list +``` + +## Best Practices + +### 1. Name Sessions Semantically + +```bash +# GOOD: Clear purpose +agent-browser --session github-auth open https://github.com +agent-browser --session docs-scrape open https://docs.example.com + +# AVOID: Generic names +agent-browser --session s1 open https://github.com +``` + +### 2. Always Clean Up + +```bash +# Close sessions when done +agent-browser --session auth close +agent-browser --session scrape close +``` + +### 3. Handle State Files Securely + +```bash +# Don't commit state files (contain auth tokens!) +echo "*.auth-state.json" >> .gitignore + +# Delete after use +rm /tmp/auth-state.json +``` + +### 4. Timeout Long Sessions + +```bash +# Set timeout for automated scripts +timeout 60 agent-browser --session long-task get text body +``` diff --git a/workspace/skills/agent-browser/references/snapshot-refs.md b/workspace/skills/agent-browser/references/snapshot-refs.md new file mode 100644 index 00000000..3cc0feaf --- /dev/null +++ b/workspace/skills/agent-browser/references/snapshot-refs.md @@ -0,0 +1,219 @@ +# Snapshot and Refs + +Compact element references that reduce context usage dramatically for AI agents. + +**Related**: [commands.md](commands.md) for full command reference, [SKILL.md](../SKILL.md) for quick start. + +## Contents + +- [How Refs Work](#how-refs-work) +- [Snapshot Command](#the-snapshot-command) +- [Using Refs](#using-refs) +- [Ref Lifecycle](#ref-lifecycle) +- [Best Practices](#best-practices) +- [Ref Notation Details](#ref-notation-details) +- [Troubleshooting](#troubleshooting) + +## How Refs Work + +Traditional approach: +``` +Full DOM/HTML → AI parses → CSS selector → Action (~3000-5000 tokens) +``` + +agent-browser approach: +``` +Compact snapshot → @refs assigned → Direct interaction (~200-400 tokens) +``` + +## The Snapshot Command + +```bash +# Basic snapshot (shows page structure) +agent-browser snapshot + +# Interactive snapshot (-i flag) - RECOMMENDED +agent-browser snapshot -i +``` + +### Snapshot Output Format + +``` +Page: Example Site - Home +URL: https://example.com + +@e1 [header] + @e2 [nav] + @e3 [a] "Home" + @e4 [a] "Products" + @e5 [a] "About" + @e6 [button] "Sign In" + +@e7 [main] + @e8 [h1] "Welcome" + @e9 [form] + @e10 [input type="email"] placeholder="Email" + @e11 [input type="password"] placeholder="Password" + @e12 [button type="submit"] "Log In" + +@e13 [footer] + @e14 [a] "Privacy Policy" +``` + +## Using Refs + +Once you have refs, interact directly: + +```bash +# Click the "Sign In" button +agent-browser click @e6 + +# Fill email input +agent-browser fill @e10 "user@example.com" + +# Fill password +agent-browser fill @e11 "password123" + +# Submit the form +agent-browser click @e12 +``` + +## Ref Lifecycle + +**IMPORTANT**: Refs are invalidated when the page changes! + +```bash +# Get initial snapshot +agent-browser snapshot -i +# @e1 [button] "Next" + +# Click triggers page change +agent-browser click @e1 + +# MUST re-snapshot to get new refs! +agent-browser snapshot -i +# @e1 [h1] "Page 2" ← Different element now! +``` + +## Best Practices + +### 1. Always Snapshot Before Interacting + +```bash +# CORRECT +agent-browser open https://example.com +agent-browser snapshot -i # Get refs first +agent-browser click @e1 # Use ref + +# WRONG +agent-browser open https://example.com +agent-browser click @e1 # Ref doesn't exist yet! +``` + +### 2. Re-Snapshot After Navigation + +```bash +agent-browser click @e5 # Navigates to new page +agent-browser snapshot -i # Get new refs +agent-browser click @e1 # Use new refs +``` + +### 3. Re-Snapshot After Dynamic Changes + +```bash +agent-browser click @e1 # Opens dropdown +agent-browser snapshot -i # See dropdown items +agent-browser click @e7 # Select item +``` + +### 4. Snapshot Specific Regions + +For complex pages, snapshot specific areas: + +```bash +# Snapshot just the form +agent-browser snapshot @e9 +``` + +## Ref Notation Details + +``` +@e1 [tag type="value"] "text content" placeholder="hint" +│ │ │ │ │ +│ │ │ │ └─ Additional attributes +│ │ │ └─ Visible text +│ │ └─ Key attributes shown +│ └─ HTML tag name +└─ Unique ref ID +``` + +### Common Patterns + +``` +@e1 [button] "Submit" # Button with text +@e2 [input type="email"] # Email input +@e3 [input type="password"] # Password input +@e4 [a href="/page"] "Link Text" # Anchor link +@e5 [select] # Dropdown +@e6 [textarea] placeholder="Message" # Text area +@e7 [div class="modal"] # Container (when relevant) +@e8 [img alt="Logo"] # Image +@e9 [checkbox] checked # Checked checkbox +@e10 [radio] selected # Selected radio +``` + +## Iframes + +Snapshots automatically detect and inline iframe content. When the main-frame snapshot runs, each `Iframe` node is resolved and its child accessibility tree is included directly beneath it in the output. Refs assigned to elements inside iframes carry frame context, so interactions like `click`, `fill`, and `type` work without manually switching frames. + +```bash +agent-browser snapshot -i +# @e1 [heading] "Checkout" +# @e2 [Iframe] "payment-frame" +# @e3 [input] "Card number" +# @e4 [input] "Expiry" +# @e5 [button] "Pay" +# @e6 [button] "Cancel" + +# Interact with iframe elements directly using their refs +agent-browser fill @e3 "4111111111111111" +agent-browser fill @e4 "12/28" +agent-browser click @e5 +``` + +**Key details:** +- Only one level of iframe nesting is expanded (iframes within iframes are not recursed) +- Cross-origin iframes that block accessibility tree access are silently skipped +- Empty iframes or iframes with no interactive content are omitted from the output +- To scope a snapshot to a single iframe, use `frame @ref` then `snapshot -i` + +## Troubleshooting + +### "Ref not found" Error + +```bash +# Ref may have changed - re-snapshot +agent-browser snapshot -i +``` + +### Element Not Visible in Snapshot + +```bash +# Scroll down to reveal element +agent-browser scroll down 1000 +agent-browser snapshot -i + +# Or wait for dynamic content +agent-browser wait 1000 +agent-browser snapshot -i +``` + +### Too Many Elements + +```bash +# Snapshot specific container +agent-browser snapshot @e5 + +# Or use get text for content-only extraction +agent-browser get text @e5 +``` diff --git a/workspace/skills/agent-browser/references/video-recording.md b/workspace/skills/agent-browser/references/video-recording.md new file mode 100644 index 00000000..e6a9fb4e --- /dev/null +++ b/workspace/skills/agent-browser/references/video-recording.md @@ -0,0 +1,173 @@ +# Video Recording + +Capture browser automation as video for debugging, documentation, or verification. + +**Related**: [commands.md](commands.md) for full command reference, [SKILL.md](../SKILL.md) for quick start. + +## Contents + +- [Basic Recording](#basic-recording) +- [Recording Commands](#recording-commands) +- [Use Cases](#use-cases) +- [Best Practices](#best-practices) +- [Output Format](#output-format) +- [Limitations](#limitations) + +## Basic Recording + +```bash +# Start recording +agent-browser record start ./demo.webm + +# Perform actions +agent-browser open https://example.com +agent-browser snapshot -i +agent-browser click @e1 +agent-browser fill @e2 "test input" + +# Stop and save +agent-browser record stop +``` + +## Recording Commands + +```bash +# Start recording to file +agent-browser record start ./output.webm + +# Stop current recording +agent-browser record stop + +# Restart with new file (stops current + starts new) +agent-browser record restart ./take2.webm +``` + +## Use Cases + +### Debugging Failed Automation + +```bash +#!/bin/bash +# Record automation for debugging + +agent-browser record start ./debug-$(date +%Y%m%d-%H%M%S).webm + +# Run your automation +agent-browser open https://app.example.com +agent-browser snapshot -i +agent-browser click @e1 || { + echo "Click failed - check recording" + agent-browser record stop + exit 1 +} + +agent-browser record stop +``` + +### Documentation Generation + +```bash +#!/bin/bash +# Record workflow for documentation + +agent-browser record start ./docs/how-to-login.webm + +agent-browser open https://app.example.com/login +agent-browser wait 1000 # Pause for visibility + +agent-browser snapshot -i +agent-browser fill @e1 "demo@example.com" +agent-browser wait 500 + +agent-browser fill @e2 "password" +agent-browser wait 500 + +agent-browser click @e3 +agent-browser wait --load networkidle +agent-browser wait 1000 # Show result + +agent-browser record stop +``` + +### CI/CD Test Evidence + +```bash +#!/bin/bash +# Record E2E test runs for CI artifacts + +TEST_NAME="${1:-e2e-test}" +RECORDING_DIR="./test-recordings" +mkdir -p "$RECORDING_DIR" + +agent-browser record start "$RECORDING_DIR/$TEST_NAME-$(date +%s).webm" + +# Run test +if run_e2e_test; then + echo "Test passed" +else + echo "Test failed - recording saved" +fi + +agent-browser record stop +``` + +## Best Practices + +### 1. Add Pauses for Clarity + +```bash +# Slow down for human viewing +agent-browser click @e1 +agent-browser wait 500 # Let viewer see result +``` + +### 2. Use Descriptive Filenames + +```bash +# Include context in filename +agent-browser record start ./recordings/login-flow-2024-01-15.webm +agent-browser record start ./recordings/checkout-test-run-42.webm +``` + +### 3. Handle Recording in Error Cases + +```bash +#!/bin/bash +set -e + +cleanup() { + agent-browser record stop 2>/dev/null || true + agent-browser close 2>/dev/null || true +} +trap cleanup EXIT + +agent-browser record start ./automation.webm +# ... automation steps ... +``` + +### 4. Combine with Screenshots + +```bash +# Record video AND capture key frames +agent-browser record start ./flow.webm + +agent-browser open https://example.com +agent-browser screenshot ./screenshots/step1-homepage.png + +agent-browser click @e1 +agent-browser screenshot ./screenshots/step2-after-click.png + +agent-browser record stop +``` + +## Output Format + +- Default format: WebM (VP8/VP9 codec) +- Compatible with all modern browsers and video players +- Compressed but high quality + +## Limitations + +- Recording adds slight overhead to automation +- Large recordings can consume significant disk space +- Some headless environments may have codec limitations diff --git a/workspace/skills/agent-browser/templates/authenticated-session.sh b/workspace/skills/agent-browser/templates/authenticated-session.sh new file mode 100644 index 00000000..b66c9289 --- /dev/null +++ b/workspace/skills/agent-browser/templates/authenticated-session.sh @@ -0,0 +1,105 @@ +#!/bin/bash +# Template: Authenticated Session Workflow +# Purpose: Login once, save state, reuse for subsequent runs +# Usage: ./authenticated-session.sh [state-file] +# +# RECOMMENDED: Use the auth vault instead of this template: +# echo "" | agent-browser auth save myapp --url --username --password-stdin +# agent-browser auth login myapp +# The auth vault stores credentials securely and the LLM never sees passwords. +# +# Environment variables: +# APP_USERNAME - Login username/email +# APP_PASSWORD - Login password +# +# Two modes: +# 1. Discovery mode (default): Shows form structure so you can identify refs +# 2. Login mode: Performs actual login after you update the refs +# +# Setup steps: +# 1. Run once to see form structure (discovery mode) +# 2. Update refs in LOGIN FLOW section below +# 3. Set APP_USERNAME and APP_PASSWORD +# 4. Delete the DISCOVERY section + +set -euo pipefail + +LOGIN_URL="${1:?Usage: $0 [state-file]}" +STATE_FILE="${2:-./auth-state.json}" + +echo "Authentication workflow: $LOGIN_URL" + +# ================================================================ +# SAVED STATE: Skip login if valid saved state exists +# ================================================================ +if [[ -f "$STATE_FILE" ]]; then + echo "Loading saved state from $STATE_FILE..." + if agent-browser --state "$STATE_FILE" open "$LOGIN_URL" 2>/dev/null; then + agent-browser wait --load networkidle + + CURRENT_URL=$(agent-browser get url) + if [[ "$CURRENT_URL" != *"login"* ]] && [[ "$CURRENT_URL" != *"signin"* ]]; then + echo "Session restored successfully" + agent-browser snapshot -i + exit 0 + fi + echo "Session expired, performing fresh login..." + agent-browser close 2>/dev/null || true + else + echo "Failed to load state, re-authenticating..." + fi + rm -f "$STATE_FILE" +fi + +# ================================================================ +# DISCOVERY MODE: Shows form structure (delete after setup) +# ================================================================ +echo "Opening login page..." +agent-browser open "$LOGIN_URL" +agent-browser wait --load networkidle + +echo "" +echo "Login form structure:" +echo "---" +agent-browser snapshot -i +echo "---" +echo "" +echo "Next steps:" +echo " 1. Note the refs: username=@e?, password=@e?, submit=@e?" +echo " 2. Update the LOGIN FLOW section below with your refs" +echo " 3. Set: export APP_USERNAME='...' APP_PASSWORD='...'" +echo " 4. Delete this DISCOVERY MODE section" +echo "" +agent-browser close +exit 0 + +# ================================================================ +# LOGIN FLOW: Uncomment and customize after discovery +# ================================================================ +# : "${APP_USERNAME:?Set APP_USERNAME environment variable}" +# : "${APP_PASSWORD:?Set APP_PASSWORD environment variable}" +# +# agent-browser open "$LOGIN_URL" +# agent-browser wait --load networkidle +# agent-browser snapshot -i +# +# # Fill credentials (update refs to match your form) +# agent-browser fill @e1 "$APP_USERNAME" +# agent-browser fill @e2 "$APP_PASSWORD" +# agent-browser click @e3 +# agent-browser wait --load networkidle +# +# # Verify login succeeded +# FINAL_URL=$(agent-browser get url) +# if [[ "$FINAL_URL" == *"login"* ]] || [[ "$FINAL_URL" == *"signin"* ]]; then +# echo "Login failed - still on login page" +# agent-browser screenshot /tmp/login-failed.png +# agent-browser close +# exit 1 +# fi +# +# # Save state for future runs +# echo "Saving state to $STATE_FILE" +# agent-browser state save "$STATE_FILE" +# echo "Login successful" +# agent-browser snapshot -i diff --git a/workspace/skills/agent-browser/templates/capture-workflow.sh b/workspace/skills/agent-browser/templates/capture-workflow.sh new file mode 100644 index 00000000..3bc93ad0 --- /dev/null +++ b/workspace/skills/agent-browser/templates/capture-workflow.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Template: Content Capture Workflow +# Purpose: Extract content from web pages (text, screenshots, PDF) +# Usage: ./capture-workflow.sh [output-dir] +# +# Outputs: +# - page-full.png: Full page screenshot +# - page-structure.txt: Page element structure with refs +# - page-text.txt: All text content +# - page.pdf: PDF version +# +# Optional: Load auth state for protected pages + +set -euo pipefail + +TARGET_URL="${1:?Usage: $0 [output-dir]}" +OUTPUT_DIR="${2:-.}" + +echo "Capturing: $TARGET_URL" +mkdir -p "$OUTPUT_DIR" + +# Optional: Load authentication state +# if [[ -f "./auth-state.json" ]]; then +# echo "Loading authentication state..." +# agent-browser state load "./auth-state.json" +# fi + +# Navigate to target +agent-browser open "$TARGET_URL" +agent-browser wait --load networkidle + +# Get metadata +TITLE=$(agent-browser get title) +URL=$(agent-browser get url) +echo "Title: $TITLE" +echo "URL: $URL" + +# Capture full page screenshot +agent-browser screenshot --full "$OUTPUT_DIR/page-full.png" +echo "Saved: $OUTPUT_DIR/page-full.png" + +# Get page structure with refs +agent-browser snapshot -i > "$OUTPUT_DIR/page-structure.txt" +echo "Saved: $OUTPUT_DIR/page-structure.txt" + +# Extract all text content +agent-browser get text body > "$OUTPUT_DIR/page-text.txt" +echo "Saved: $OUTPUT_DIR/page-text.txt" + +# Save as PDF +agent-browser pdf "$OUTPUT_DIR/page.pdf" +echo "Saved: $OUTPUT_DIR/page.pdf" + +# Optional: Extract specific elements using refs from structure +# agent-browser get text @e5 > "$OUTPUT_DIR/main-content.txt" + +# Optional: Handle infinite scroll pages +# for i in {1..5}; do +# agent-browser scroll down 1000 +# agent-browser wait 1000 +# done +# agent-browser screenshot --full "$OUTPUT_DIR/page-scrolled.png" + +# Cleanup +agent-browser close + +echo "" +echo "Capture complete:" +ls -la "$OUTPUT_DIR" diff --git a/workspace/skills/agent-browser/templates/form-automation.sh b/workspace/skills/agent-browser/templates/form-automation.sh new file mode 100644 index 00000000..6784fcd3 --- /dev/null +++ b/workspace/skills/agent-browser/templates/form-automation.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Template: Form Automation Workflow +# Purpose: Fill and submit web forms with validation +# Usage: ./form-automation.sh +# +# This template demonstrates the snapshot-interact-verify pattern: +# 1. Navigate to form +# 2. Snapshot to get element refs +# 3. Fill fields using refs +# 4. Submit and verify result +# +# Customize: Update the refs (@e1, @e2, etc.) based on your form's snapshot output + +set -euo pipefail + +FORM_URL="${1:?Usage: $0 }" + +echo "Form automation: $FORM_URL" + +# Step 1: Navigate to form +agent-browser open "$FORM_URL" +agent-browser wait --load networkidle + +# Step 2: Snapshot to discover form elements +echo "" +echo "Form structure:" +agent-browser snapshot -i + +# Step 3: Fill form fields (customize these refs based on snapshot output) +# +# Common field types: +# agent-browser fill @e1 "John Doe" # Text input +# agent-browser fill @e2 "user@example.com" # Email input +# agent-browser fill @e3 "SecureP@ss123" # Password input +# agent-browser select @e4 "Option Value" # Dropdown +# agent-browser check @e5 # Checkbox +# agent-browser click @e6 # Radio button +# agent-browser fill @e7 "Multi-line text" # Textarea +# agent-browser upload @e8 /path/to/file.pdf # File upload +# +# Uncomment and modify: +# agent-browser fill @e1 "Test User" +# agent-browser fill @e2 "test@example.com" +# agent-browser click @e3 # Submit button + +# Step 4: Wait for submission +# agent-browser wait --load networkidle +# agent-browser wait --url "**/success" # Or wait for redirect + +# Step 5: Verify result +echo "" +echo "Result:" +agent-browser get url +agent-browser snapshot -i + +# Optional: Capture evidence +agent-browser screenshot /tmp/form-result.png +echo "Screenshot saved: /tmp/form-result.png" + +# Cleanup +agent-browser close +echo "Done" diff --git a/workspace/skills/clawhub/SKILL.md b/workspace/skills/clawhub/SKILL.md new file mode 100644 index 00000000..b28a1de0 --- /dev/null +++ b/workspace/skills/clawhub/SKILL.md @@ -0,0 +1,53 @@ +--- +name: clawhub +description: Search and install agent skills from ClawHub, the public skill registry. +homepage: https://clawhub.ai +metadata: {"internclaw":{"emoji":"🦞"}} +--- + +# ClawHub + +Public skill registry for AI agents. Search by natural language (vector search). + +## When to use + +Use this skill when the user asks any of: +- "find a skill for …" +- "search for skills" +- "install a skill" +- "what skills are available?" +- "update my skills" + +## Search + +```bash +npx --yes clawhub@latest search "web scraping" --limit 5 +``` + +## Install + +```bash +npx --yes clawhub@latest install --workdir ~/.internclaw/workspace +``` + +Replace `` with the skill name from search results. This places the skill into `~/.internclaw/workspace/skills/`, where internclaw loads workspace skills from. Always include `--workdir`. + +## Update + +```bash +npx --yes clawhub@latest update --all --workdir ~/.internclaw/workspace +``` + +## List installed + +```bash +npx --yes clawhub@latest list --workdir ~/.internclaw/workspace +``` + +## Notes + +- Requires Node.js (`npx` comes with it). +- No API key needed for search and install. +- Login (`npx --yes clawhub@latest login`) is only required for publishing. +- `--workdir ~/.internclaw/workspace` is critical — without it, skills install to the current directory instead of the internclaw workspace. +- After install, remind the user to start a new session to load the skill. diff --git a/workspace/skills/cron/SKILL.md b/workspace/skills/cron/SKILL.md new file mode 100644 index 00000000..48eaa805 --- /dev/null +++ b/workspace/skills/cron/SKILL.md @@ -0,0 +1,57 @@ +--- +name: cron +description: Schedule reminders and recurring tasks. +--- + +# Cron + +Use the `cron` tool to schedule reminders or recurring tasks. + +## Three Modes + +1. **Reminder** - message is sent directly to user +2. **Task** - message is a task description, agent executes and sends result +3. **One-time** - runs once at a specific time, then auto-deletes + +## Examples + +Fixed reminder: +``` +cron(action="add", message="Time to take a break!", every_seconds=1200) +``` + +Dynamic task (agent executes each time): +``` +cron(action="add", message="Check HKUDS/internclaw GitHub stars and report", every_seconds=600) +``` + +One-time scheduled task (compute ISO datetime from current time): +``` +cron(action="add", message="Remind me about the meeting", at="") +``` + +Timezone-aware cron: +``` +cron(action="add", message="Morning standup", cron_expr="0 9 * * 1-5", tz="America/Vancouver") +``` + +List/remove: +``` +cron(action="list") +cron(action="remove", job_id="abc123") +``` + +## Time Expressions + +| User says | Parameters | +|-----------|------------| +| every 20 minutes | every_seconds: 1200 | +| every hour | every_seconds: 3600 | +| every day at 8am | cron_expr: "0 8 * * *" | +| weekdays at 5pm | cron_expr: "0 17 * * 1-5" | +| 9am Vancouver time daily | cron_expr: "0 9 * * *", tz: "America/Vancouver" | +| at a specific time | at: ISO datetime string (compute from current time) | + +## Timezone + +Use `tz` with `cron_expr` to schedule in a specific IANA timezone. Without `tz`, the server's local timezone is used. diff --git a/workspace/skills/github/SKILL.md b/workspace/skills/github/SKILL.md new file mode 100644 index 00000000..6fa00ecb --- /dev/null +++ b/workspace/skills/github/SKILL.md @@ -0,0 +1,48 @@ +--- +name: github +description: "Interact with GitHub using the `gh` CLI. Use `gh issue`, `gh pr`, `gh run`, and `gh api` for issues, PRs, CI runs, and advanced queries." +metadata: {"internclaw":{"emoji":"🐙","requires":{"bins":["gh"]},"install":[{"id":"brew","kind":"brew","formula":"gh","bins":["gh"],"label":"Install GitHub CLI (brew)"},{"id":"apt","kind":"apt","package":"gh","bins":["gh"],"label":"Install GitHub CLI (apt)"}]}} +--- + +# GitHub Skill + +Use the `gh` CLI to interact with GitHub. Always specify `--repo owner/repo` when not in a git directory, or use URLs directly. + +## Pull Requests + +Check CI status on a PR: +```bash +gh pr checks 55 --repo owner/repo +``` + +List recent workflow runs: +```bash +gh run list --repo owner/repo --limit 10 +``` + +View a run and see which steps failed: +```bash +gh run view --repo owner/repo +``` + +View logs for failed steps only: +```bash +gh run view --repo owner/repo --log-failed +``` + +## API for Advanced Queries + +The `gh api` command is useful for accessing data not available through other subcommands. + +Get PR with specific fields: +```bash +gh api repos/owner/repo/pulls/55 --jq '.title, .state, .user.login' +``` + +## JSON Output + +Most commands support `--json` for structured output. You can use `--jq` to filter: + +```bash +gh issue list --repo owner/repo --json number,title --jq '.[] | "\(.number): \(.title)"' +``` diff --git a/workspace/skills/memory/SKILL.md b/workspace/skills/memory/SKILL.md new file mode 100644 index 00000000..529a02d5 --- /dev/null +++ b/workspace/skills/memory/SKILL.md @@ -0,0 +1,31 @@ +--- +name: memory +description: Two-layer memory system with grep-based recall. +always: true +--- + +# Memory + +## Structure + +- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context. +- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM]. + +## Search Past Events + +```bash +grep -i "keyword" memory/HISTORY.md +``` + +Use the `exec` tool to run grep. Combine patterns: `grep -iE "meeting|deadline" memory/HISTORY.md` + +## When to Update MEMORY.md + +Write important facts immediately using `edit_file` or `write_file`: +- User preferences ("I prefer dark mode") +- Project context ("The API uses OAuth2") +- Relationships ("Alice is the project lead") + +## Auto-consolidation + +Old conversations are automatically summarized and appended to HISTORY.md when the session grows large. Long-term facts are extracted to MEMORY.md. You don't need to manage this. diff --git a/workspace/skills/skill-creator/SKILL.md b/workspace/skills/skill-creator/SKILL.md new file mode 100644 index 00000000..9b5eb6fe --- /dev/null +++ b/workspace/skills/skill-creator/SKILL.md @@ -0,0 +1,371 @@ +--- +name: skill-creator +description: Create or update AgentSkills. Use when designing, structuring, or packaging skills with scripts, references, and assets. +--- + +# Skill Creator + +This skill provides guidance for creating effective skills. + +## About Skills + +Skills are modular, self-contained packages that extend the agent's capabilities by providing +specialized knowledge, workflows, and tools. Think of them as "onboarding guides" for specific +domains or tasks—they transform the agent from a general-purpose agent into a specialized agent +equipped with procedural knowledge that no model can fully possess. + +### What Skills Provide + +1. Specialized workflows - Multi-step procedures for specific domains +2. Tool integrations - Instructions for working with specific file formats or APIs +3. Domain expertise - Company-specific knowledge, schemas, business logic +4. Bundled resources - Scripts, references, and assets for complex and repetitive tasks + +## Core Principles + +### Concise is Key + +The context window is a public good. Skills share the context window with everything else the agent needs: system prompt, conversation history, other Skills' metadata, and the actual user request. + +**Default assumption: the agent is already very smart.** Only add context the agent doesn't already have. Challenge each piece of information: "Does the agent really need this explanation?" and "Does this paragraph justify its token cost?" + +Prefer concise examples over verbose explanations. + +### Set Appropriate Degrees of Freedom + +Match the level of specificity to the task's fragility and variability: + +**High freedom (text-based instructions)**: Use when multiple approaches are valid, decisions depend on context, or heuristics guide the approach. + +**Medium freedom (pseudocode or scripts with parameters)**: Use when a preferred pattern exists, some variation is acceptable, or configuration affects behavior. + +**Low freedom (specific scripts, few parameters)**: Use when operations are fragile and error-prone, consistency is critical, or a specific sequence must be followed. + +Think of the agent as exploring a path: a narrow bridge with cliffs needs specific guardrails (low freedom), while an open field allows many routes (high freedom). + +### Anatomy of a Skill + +Every skill consists of a required SKILL.md file and optional bundled resources: + +``` +skill-name/ +├── SKILL.md (required) +│ ├── YAML frontmatter metadata (required) +│ │ ├── name: (required) +│ │ └── description: (required) +│ └── Markdown instructions (required) +└── Bundled Resources (optional) + ├── scripts/ - Executable code (Python/Bash/etc.) + ├── references/ - Documentation intended to be loaded into context as needed + └── assets/ - Files used in output (templates, icons, fonts, etc.) +``` + +#### SKILL.md (required) + +Every SKILL.md consists of: + +- **Frontmatter** (YAML): Contains `name` and `description` fields. These are the only fields that the agent reads to determine when the skill gets used, thus it is very important to be clear and comprehensive in describing what the skill is, and when it should be used. +- **Body** (Markdown): Instructions and guidance for using the skill. Only loaded AFTER the skill triggers (if at all). + +#### Bundled Resources (optional) + +##### Scripts (`scripts/`) + +Executable code (Python/Bash/etc.) for tasks that require deterministic reliability or are repeatedly rewritten. + +- **When to include**: When the same code is being rewritten repeatedly or deterministic reliability is needed +- **Example**: `scripts/rotate_pdf.py` for PDF rotation tasks +- **Benefits**: Token efficient, deterministic, may be executed without loading into context +- **Note**: Scripts may still need to be read by the agent for patching or environment-specific adjustments + +##### References (`references/`) + +Documentation and reference material intended to be loaded as needed into context to inform the agent's process and thinking. + +- **When to include**: For documentation that the agent should reference while working +- **Examples**: `references/finance.md` for financial schemas, `references/mnda.md` for company NDA template, `references/policies.md` for company policies, `references/api_docs.md` for API specifications +- **Use cases**: Database schemas, API documentation, domain knowledge, company policies, detailed workflow guides +- **Benefits**: Keeps SKILL.md lean, loaded only when the agent determines it's needed +- **Best practice**: If files are large (>10k words), include grep search patterns in SKILL.md +- **Avoid duplication**: Information should live in either SKILL.md or references files, not both. Prefer references files for detailed information unless it's truly core to the skill—this keeps SKILL.md lean while making information discoverable without hogging the context window. Keep only essential procedural instructions and workflow guidance in SKILL.md; move detailed reference material, schemas, and examples to references files. + +##### Assets (`assets/`) + +Files not intended to be loaded into context, but rather used within the output the agent produces. + +- **When to include**: When the skill needs files that will be used in the final output +- **Examples**: `assets/logo.png` for brand assets, `assets/slides.pptx` for PowerPoint templates, `assets/frontend-template/` for HTML/React boilerplate, `assets/font.ttf` for typography +- **Use cases**: Templates, images, icons, boilerplate code, fonts, sample documents that get copied or modified +- **Benefits**: Separates output resources from documentation, enables the agent to use files without loading them into context + +#### What to Not Include in a Skill + +A skill should only contain essential files that directly support its functionality. Do NOT create extraneous documentation or auxiliary files, including: + +- README.md +- INSTALLATION_GUIDE.md +- QUICK_REFERENCE.md +- CHANGELOG.md +- etc. + +The skill should only contain the information needed for an AI agent to do the job at hand. It should not contain auxiliary context about the process that went into creating it, setup and testing procedures, user-facing documentation, etc. Creating additional documentation files just adds clutter and confusion. + +### Progressive Disclosure Design Principle + +Skills use a three-level loading system to manage context efficiently: + +1. **Metadata (name + description)** - Always in context (~100 words) +2. **SKILL.md body** - When skill triggers (<5k words) +3. **Bundled resources** - As needed by the agent (Unlimited because scripts can be executed without reading into context window) + +#### Progressive Disclosure Patterns + +Keep SKILL.md body to the essentials and under 500 lines to minimize context bloat. Split content into separate files when approaching this limit. When splitting out content into other files, it is very important to reference them from SKILL.md and describe clearly when to read them, to ensure the reader of the skill knows they exist and when to use them. + +**Key principle:** When a skill supports multiple variations, frameworks, or options, keep only the core workflow and selection guidance in SKILL.md. Move variant-specific details (patterns, examples, configuration) into separate reference files. + +**Pattern 1: High-level guide with references** + +```markdown +# PDF Processing + +## Quick start + +Extract text with pdfplumber: +[code example] + +## Advanced features + +- **Form filling**: See [FORMS.md](FORMS.md) for complete guide +- **API reference**: See [REFERENCE.md](REFERENCE.md) for all methods +- **Examples**: See [EXAMPLES.md](EXAMPLES.md) for common patterns +``` + +the agent loads FORMS.md, REFERENCE.md, or EXAMPLES.md only when needed. + +**Pattern 2: Domain-specific organization** + +For Skills with multiple domains, organize content by domain to avoid loading irrelevant context: + +``` +bigquery-skill/ +├── SKILL.md (overview and navigation) +└── reference/ + ├── finance.md (revenue, billing metrics) + ├── sales.md (opportunities, pipeline) + ├── product.md (API usage, features) + └── marketing.md (campaigns, attribution) +``` + +When a user asks about sales metrics, the agent only reads sales.md. + +Similarly, for skills supporting multiple frameworks or variants, organize by variant: + +``` +cloud-deploy/ +├── SKILL.md (workflow + provider selection) +└── references/ + ├── aws.md (AWS deployment patterns) + ├── gcp.md (GCP deployment patterns) + └── azure.md (Azure deployment patterns) +``` + +When the user chooses AWS, the agent only reads aws.md. + +**Pattern 3: Conditional details** + +Show basic content, link to advanced content: + +```markdown +# DOCX Processing + +## Creating documents + +Use docx-js for new documents. See [DOCX-JS.md](DOCX-JS.md). + +## Editing documents + +For simple edits, modify the XML directly. + +**For tracked changes**: See [REDLINING.md](REDLINING.md) +**For OOXML details**: See [OOXML.md](OOXML.md) +``` + +the agent reads REDLINING.md or OOXML.md only when the user needs those features. + +**Important guidelines:** + +- **Avoid deeply nested references** - Keep references one level deep from SKILL.md. All reference files should link directly from SKILL.md. +- **Structure longer reference files** - For files longer than 100 lines, include a table of contents at the top so the agent can see the full scope when previewing. + +## Skill Creation Process + +Skill creation involves these steps: + +1. Understand the skill with concrete examples +2. Plan reusable skill contents (scripts, references, assets) +3. Initialize the skill (run init_skill.py) +4. Edit the skill (implement resources and write SKILL.md) +5. Package the skill (run package_skill.py) +6. Iterate based on real usage + +Follow these steps in order, skipping only if there is a clear reason why they are not applicable. + +### Skill Naming + +- Use lowercase letters, digits, and hyphens only; normalize user-provided titles to hyphen-case (e.g., "Plan Mode" -> `plan-mode`). +- When generating names, generate a name under 64 characters (letters, digits, hyphens). +- Prefer short, verb-led phrases that describe the action. +- Namespace by tool when it improves clarity or triggering (e.g., `gh-address-comments`, `linear-address-issue`). +- Name the skill folder exactly after the skill name. + +### Step 1: Understanding the Skill with Concrete Examples + +Skip this step only when the skill's usage patterns are already clearly understood. It remains valuable even when working with an existing skill. + +To create an effective skill, clearly understand concrete examples of how the skill will be used. This understanding can come from either direct user examples or generated examples that are validated with user feedback. + +For example, when building an image-editor skill, relevant questions include: + +- "What functionality should the image-editor skill support? Editing, rotating, anything else?" +- "Can you give some examples of how this skill would be used?" +- "I can imagine users asking for things like 'Remove the red-eye from this image' or 'Rotate this image'. Are there other ways you imagine this skill being used?" +- "What would a user say that should trigger this skill?" + +To avoid overwhelming users, avoid asking too many questions in a single message. Start with the most important questions and follow up as needed for better effectiveness. + +Conclude this step when there is a clear sense of the functionality the skill should support. + +### Step 2: Planning the Reusable Skill Contents + +To turn concrete examples into an effective skill, analyze each example by: + +1. Considering how to execute on the example from scratch +2. Identifying what scripts, references, and assets would be helpful when executing these workflows repeatedly + +Example: When building a `pdf-editor` skill to handle queries like "Help me rotate this PDF," the analysis shows: + +1. Rotating a PDF requires re-writing the same code each time +2. A `scripts/rotate_pdf.py` script would be helpful to store in the skill + +Example: When designing a `frontend-webapp-builder` skill for queries like "Build me a todo app" or "Build me a dashboard to track my steps," the analysis shows: + +1. Writing a frontend webapp requires the same boilerplate HTML/React each time +2. An `assets/hello-world/` template containing the boilerplate HTML/React project files would be helpful to store in the skill + +Example: When building a `big-query` skill to handle queries like "How many users have logged in today?" the analysis shows: + +1. Querying BigQuery requires re-discovering the table schemas and relationships each time +2. A `references/schema.md` file documenting the table schemas would be helpful to store in the skill + +To establish the skill's contents, analyze each concrete example to create a list of the reusable resources to include: scripts, references, and assets. + +### Step 3: Initializing the Skill + +At this point, it is time to actually create the skill. + +Skip this step only if the skill being developed already exists, and iteration or packaging is needed. In this case, continue to the next step. + +When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable. + +Usage: + +```bash +scripts/init_skill.py --path [--resources scripts,references,assets] [--examples] +``` + +Examples: + +```bash +scripts/init_skill.py my-skill --path skills/public +scripts/init_skill.py my-skill --path skills/public --resources scripts,references +scripts/init_skill.py my-skill --path skills/public --resources scripts --examples +``` + +The script: + +- Creates the skill directory at the specified path +- Generates a SKILL.md template with proper frontmatter and TODO placeholders +- Optionally creates resource directories based on `--resources` +- Optionally adds example files when `--examples` is set + +After initialization, customize the SKILL.md and add resources as needed. If you used `--examples`, replace or delete placeholder files. + +### Step 4: Edit the Skill + +When editing the (newly-generated or existing) skill, remember that the skill is being created for another instance of the agent to use. Include information that would be beneficial and non-obvious to the agent. Consider what procedural knowledge, domain-specific details, or reusable assets would help another the agent instance execute these tasks more effectively. + +#### Learn Proven Design Patterns + +Consult these helpful guides based on your skill's needs: + +- **Multi-step processes**: See references/workflows.md for sequential workflows and conditional logic +- **Specific output formats or quality standards**: See references/output-patterns.md for template and example patterns + +These files contain established best practices for effective skill design. + +#### Start with Reusable Skill Contents + +To begin implementation, start with the reusable resources identified above: `scripts/`, `references/`, and `assets/` files. Note that this step may require user input. For example, when implementing a `brand-guidelines` skill, the user may need to provide brand assets or templates to store in `assets/`, or documentation to store in `references/`. + +Added scripts must be tested by actually running them to ensure there are no bugs and that the output matches what is expected. If there are many similar scripts, only a representative sample needs to be tested to ensure confidence that they all work while balancing time to completion. + +If you used `--examples`, delete any placeholder files that are not needed for the skill. Only create resource directories that are actually required. + +#### Update SKILL.md + +**Writing Guidelines:** Always use imperative/infinitive form. + +##### Frontmatter + +Write the YAML frontmatter with `name` and `description`: + +- `name`: The skill name +- `description`: This is the primary triggering mechanism for your skill, and helps the agent understand when to use the skill. + - Include both what the Skill does and specific triggers/contexts for when to use it. + - Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent. + - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks" + +Do not include any other fields in YAML frontmatter. + +##### Body + +Write instructions for using the skill and its bundled resources. + +### Step 5: Packaging a Skill + +Once development of the skill is complete, it must be packaged into a distributable .skill file that gets shared with the user. The packaging process automatically validates the skill first to ensure it meets all requirements: + +```bash +scripts/package_skill.py +``` + +Optional output directory specification: + +```bash +scripts/package_skill.py ./dist +``` + +The packaging script will: + +1. **Validate** the skill automatically, checking: + + - YAML frontmatter format and required fields + - Skill naming conventions and directory structure + - Description completeness and quality + - File organization and resource references + +2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension. + +If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again. + +### Step 6: Iterate + +After testing the skill, users may request improvements. Often this happens right after using the skill, with fresh context of how the skill performed. + +**Iteration workflow:** + +1. Use the skill on real tasks +2. Notice struggles or inefficiencies +3. Identify how SKILL.md or bundled resources should be updated +4. Implement changes and test again diff --git a/workspace/skills/summarize/SKILL.md b/workspace/skills/summarize/SKILL.md new file mode 100644 index 00000000..01805765 --- /dev/null +++ b/workspace/skills/summarize/SKILL.md @@ -0,0 +1,67 @@ +--- +name: summarize +description: Summarize or extract text/transcripts from URLs, podcasts, and local files (great fallback for “transcribe this YouTube/video”). +homepage: https://summarize.sh +metadata: {"internclaw":{"emoji":"🧾","requires":{"bins":["summarize"]},"install":[{"id":"brew","kind":"brew","formula":"steipete/tap/summarize","bins":["summarize"],"label":"Install summarize (brew)"}]}} +--- + +# Summarize + +Fast CLI to summarize URLs, local files, and YouTube links. + +## When to use (trigger phrases) + +Use this skill immediately when the user asks any of: +- “use summarize.sh” +- “what’s this link/video about?” +- “summarize this URL/article” +- “transcribe this YouTube/video” (best-effort transcript extraction; no `yt-dlp` needed) + +## Quick start + +```bash +summarize "https://example.com" --model google/gemini-3-flash-preview +summarize "/path/to/file.pdf" --model google/gemini-3-flash-preview +summarize "https://youtu.be/dQw4w9WgXcQ" --youtube auto +``` + +## YouTube: summary vs transcript + +Best-effort transcript (URLs only): + +```bash +summarize "https://youtu.be/dQw4w9WgXcQ" --youtube auto --extract-only +``` + +If the user asked for a transcript but it’s huge, return a tight summary first, then ask which section/time range to expand. + +## Model + keys + +Set the API key for your chosen provider: +- OpenAI: `OPENAI_API_KEY` +- Anthropic: `ANTHROPIC_API_KEY` +- xAI: `XAI_API_KEY` +- Google: `GEMINI_API_KEY` (aliases: `GOOGLE_GENERATIVE_AI_API_KEY`, `GOOGLE_API_KEY`) + +Default model is `google/gemini-3-flash-preview` if none is set. + +## Useful flags + +- `--length short|medium|long|xl|xxl|` +- `--max-output-tokens ` +- `--extract-only` (URLs only) +- `--json` (machine readable) +- `--firecrawl auto|off|always` (fallback extraction) +- `--youtube auto` (Apify fallback if `APIFY_API_TOKEN` set) + +## Config + +Optional config file: `~/.summarize/config.json` + +```json +{ "model": "openai/gpt-5.2" } +``` + +Optional services: +- `FIRECRAWL_API_KEY` for blocked sites +- `APIFY_API_TOKEN` for YouTube fallback diff --git a/workspace/skills/task-synthesis/SKILL.md b/workspace/skills/task-synthesis/SKILL.md new file mode 100644 index 00000000..83f1098d --- /dev/null +++ b/workspace/skills/task-synthesis/SKILL.md @@ -0,0 +1,229 @@ +--- +name: task-synthesis +description: Synthesize complete agent evaluation task packages from an OpenClaw skill. Use when given a skill (SKILL.md path, skill name, or skill content) and asked to generate a benchmark task — including task.md (with frontmatter, Prompt, Expected Behavior, Assets Preparation, Toolset Preparation, Environment Specification, Grading Criteria, Automated Checks, LLM Judge Rubric), an assets/ folder, a tools/ folder with OpenClaw-loadable tool definitions, and a scripts/ folder containing env_spec.sh and any helper scripts. Triggers on phrases like "generate a task from this skill", "synthesize a task for skill X", "create a benchmark task", "draft a task spec". +--- + +# Task Synthesis + +## Overview + +This skill produces a complete, ready-to-run evaluation task package derived from an existing OpenClaw skill. The input is a skill (SKILL.md, skill name, or a path to the skill); the output is a self-contained task directory: + +| Output | Description | +|---|---| +| `task.md` | Task definition following the standard template | +| `assets/` | All input files and directories the agent needs at runtime | +| `tools/` | OpenClaw-loadable tool definitions the agent can call | +| `scripts/` | Executable scripts: `env_spec.sh` (required) + any helpers | + +**`tools/` vs `scripts/` distinction:** +- `tools/` — tool definitions that OpenClaw loads and exposes to the agent as callable tools (e.g., `get_video_meta`, `query_database`). These are invoked by the agent during the task. +- `scripts/` — shell or Python scripts that run in the execution environment, not called by the agent directly. Always includes `env_spec.sh` (even if empty). May include helper scripts like `init_db.py`, `mock_server.py`, etc. + +For the canonical task.md schema, field rules, and blank template — see `references/task-schema.md` and `references/task-blank-template.md`. +See `references/synthesis-patterns.md` for asset, tool, script, and grading patterns. + +--- + +## Workflow + +### Step 1 — Analyze the Skill + +Read the skill's SKILL.md (and any referenced files), then establish three things: + +**1. Classify family and archetype** + +**Skill Family** — where does this skill primarily operate? Choose one: + +`Documents & Files` / `Filesystem & Workspace` / `Code / Repo / IDE` / `Browser & Web` / `SaaS / API Integrations` / `Local Apps / OS Automation` / `Media` / `DevOps / System / Cloud` / `Search & Research` / `Agent Infra / Orchestration` / `Hybrid` + +**Skill Archetype** — what is the primary action this skill performs? Choose one: + +`retrieve` / `extract` / `transform` / `organize` / `analyze` / `generate` / `edit` / `execute` / `validate` / `monitor` / `automate` / `hybrid` + +Record both as `family` and `archetype` — they feed into the task.md frontmatter and inform query realism and grading focus in later steps. + +**2. Map capability boundary (Provides / Requires)** + +| Set | Question | Examples | +|---|---|---| +| **Provides** | What tools, APIs, or capabilities does this skill itself expose? | Custom OpenClaw tools it ships, helper scripts it installs, services it starts, APIs it wraps | +| **Requires** | What must already exist for this skill to function? | CLI binaries (`ffmpeg`, `git`), language runtimes, API keys, OS packages, specific file formats | + +- **Provides** → comes with the skill; the task author does not need to prepare these +- **Requires** → what the task author must prepare: install in `env_spec.sh`, synthesize in `tools/`, declare in `## Toolset Preparation` + +**3. Extract baseline information** + +- **Trigger scenarios**: what kinds of user requests does this skill handle? +- **Required inputs**: what must the user provide for the skill to work? +- **Optional / defaultable inputs**: what can be inferred or has a sensible default? +- **Failure modes**: what can go wrong if inputs are missing or malformed? + +With these three established, you have everything needed to write a realistic query and design the execution environment and evaluation rules. + +### Step 2 — Write the Query + +With classification and capability boundary established, write the `## Prompt` as a real user message. Two principles govern both the prompt and its relationship to conditions in Step 3: + +1. **Natural first.** Write from the user's perspective — goal-oriented, at whatever level of detail feels realistic for the task. A simple request can be one line; a complex one may include steps or constraints. What it should never read like is a task spec or evaluation brief written by the task author. + +2. **Query and conditions are a pair.** Write the prompt with awareness of what conditions will (and will not) be provided. What the conditions omit shapes what the agent must do — ask, infer, or explore — and that shapes grading just as much as what the query says. + +Additional rules: +- **Language: Chinese.** Write the prompt in Chinese (Simplified). +- **No absolute paths, no `assets/`、`task/`、`workspace/` prefix.** Use only filenames or workspace-relative paths (e.g., `data.csv`, `reports/summary.md`). All assets are copied into the agent's workspace at runtime; `assets/` is invisible to the agent. + +### Step 3 — Prepare Environment + +Define the minimal runtime dependencies and write `scripts/env_spec.sh` to fulfill them. + +**Design:** Record what the runtime needs — +- OS / base image assumptions (if any) +- Language runtimes and version constraints +- Required packages (system, pip, npm, etc.) +- Environment variables (API keys, config values) +- Background services and their initial state (e.g., "SQLite DB with schema X, seeded with N rows") + +**Implement:** Refer to `references/synthesis-patterns.md` (scripts/ section) for the base template; adapt to the dependencies above. Add auxiliary scripts as needed (e.g., `init_db.py`, `seed_data.py`, `start_server.py`). + +**Verify (static checklist):** +- Every dependency checks for existence before installing (e.g., `command -v` for CLI tools, `python3 -c "import X"` for pip packages) +- Language packages use user-local install (`pip install --user`, local `node_modules`) +- System CLI tools (`ffmpeg`, `jq`, etc.) check existence first, then fall back to `sudo apt-get` only if missing, and fail with a clear error if sudo is unavailable +- Every environment variable has a corresponding `export` line +- Every background service has a start command and a readiness check +- Every one-time setup step (DB init, data seeding) is present and runs after its dependency is ready +- Script uses only `$WORKSPACE`, `$TASK_PATH`, `$HOME` — no hardcoded absolute paths +- Script exits non-zero on any failure (`set -e` or explicit error handling) + +Do not proceed to Step 4 until verification passes. + +### Step 4 — Prepare Tools + +Decide which tools the agent can call, then synthesize and verify them. + +For each tool, record: +- **Name**, **description**, **input schema**, **output schema** +- **Type** — one of: + - `built-in`: already provided by OpenClaw; check `assets/built_in_tools.json` first; no plugin needed + - `static mock`: return value is hardcoded regardless of input; asset files for this tool do not need real content + - `input-derived mock`: return value is computed from the actual input file; asset files **must** be real (synthesized in Step 5) + +> **Format fidelity:** mock return values must look like what a real implementation would return — correct field names, envelope structure, data types. A bare string where structured JSON is expected misleads the agent. + +**Implement:** For each `static mock` / `input-derived mock` tool, synthesize a plugin under `tools/{tool_name}/` following the **`openclaw-plugin-creator`** skill. `built-in` tools need no plugin. + +**Verify (static checklist):** +- Every `static mock` / `input-derived mock` tool has a plugin directory under `tools/{tool_name}/` + +Do not proceed to Step 5 until verification passes. + +### Step 5 — Prepare Assets + +Decide which input files the agent needs in its workspace, then create them. + +Keep assets consistent with the environment (Step 3) and tools (Step 4): exported files should match seeded DB records; files consumed by `input-derived mock` tools must be real files whose content drives those tools' return values. + +Create each file at `assets/{path}`. Generate realistic, self-consistent content. Fall back to a placeholder only when real content cannot be produced and the file will be consumed exclusively by `static mock` tools. + +### Step 6 — Expected Behavior & Grading + +**Expected behavior:** With query (Step 2) and conditions (Steps 3–5) fixed, reason through what a correct agent execution looks like. Document: +- **Actions**: tool calls, file operations, decisions the agent should make +- **Intermediate outputs**: transient artifacts needed for subsequent steps +- **Final outputs**: the concrete deliverable — file(s) written, message sent, state changed — and expected format/content + +**Grading:** Convert expected behaviors into a complete grading spec (must sum to 100%). + +For any criterion whose expected value depends on synthesized assets, compute it now from the actual files (run the calculation, compute the hash, etc.). + +Classify each criterion: +- *Can a script check this without ambiguity?* → `automated` +- *Requires understanding intent, quality, or semantic correctness?* → `llm-judge` + +When the conditions lock down the output to a deterministic value, grade the value. When the output is non-deterministic or the task evaluates process/reasoning, grade behavior or quality instead. + +**Part 1 — Automated** *(omit if none)* +- List each criterion as a bullet with percentage weight +- Provide a fully implemented `grade()` function — no `...` or `pass`; criterion key strings must match bullet text exactly + +**Part 2 — LLM Judge** *(omit if none)* +- List each criterion as a bullet with percentage weight +- Provide rubrics: full-score, zero-score, and any partial-credit gradations + +### Step 7 — Write task.md + +With all assets, tools, environment, and grading values known, produce `task.md` following the schema in `references/task-schema.md`. Use `references/task-blank-template.md` as the starting template — fill in every section; leave no placeholders. + +Inputs: +- Frontmatter values: status, difficulty, timeout, grading_mode, family, archetype +- Prompt (Step 2) +- Expected Behavior and Grading Criteria (Step 6) +- Assets Preparation, Toolset Preparation, Environment Specification (Steps 3–5) + +**Do not proceed to Step 8 until `task.md` is written to disk.** + +### Step 8 — Validate + +**Hard gate. Do not deliver until every item passes. For each failure: fix it, then re-check before continuing.** + +- [ ] `task.md` exists at the task package root +- [ ] All frontmatter fields present; no `TODO` placeholders remain +- [ ] `workspace_files` lists every asset path correctly +- [ ] Grading percentages sum to 100 +- [ ] Grading Criteria match query type (output-focused vs behavior-focused) +- [ ] `grade()` is fully implemented — no `...` or `pass` +- [ ] `grade()` criterion key strings match Grading Criteria bullets exactly +- [ ] Every asset in **Assets Preparation** exists in `assets/` +- [ ] Every tool in **Toolset Preparation** exists in `tools/` +- [ ] `scripts/env_spec.sh` exists and runs without errors +- [ ] All environment dependencies in **Environment Specification** are covered by `scripts/env_spec.sh` + +--- + +## Output Directory Layout + +``` +{task-id}/ +├── task.md +├── assets/ +│ ├── {file-or-dir-1} +│ └── {file-or-dir-2} +├── tools/ +│ └── {tool_name} ← OpenClaw-loadable tool definitions +└── scripts/ + ├── env_spec.sh ← always present + └── {helper_script}.py ← optional +``` + +--- + +## Quick Reference + +**Difficulty heuristics:** +- `low` — single-step, unambiguous output, < 60 s timeout +- `medium` — multi-step, some judgment or environment inspection, 60–180 s timeout +- `high` — open-ended, complex reasoning or multi-turn clarification, > 180 s timeout + +**Grading mode:** +- `automated` — all criteria deterministic (file check, regex, unit test, transcript pattern match) +- `llm-judge` — all criteria semantic (quality, reasoning, clarification behavior) +- `hybrid` — mix; default for most tasks + +**Grading derives from query + conditions together:** +- If conditions provide everything needed → grading is output-correctness → prefer automated +- If conditions omit something the query implies → grading is agent behavior (did it ask? inspect? propose?) → prefer llm-judge + transcript checks +- If query leaves the approach open → grading is reasoning quality and option coverage → prefer llm-judge + +**Step ordering rationale:** +- Step 1: understand skill → 1.1 classify (family + archetype) → 1.2 capability boundary (Provides / Requires) +- Step 2: write prompt (natural, goal-oriented; query and conditions designed as a pair) +- Step 3: prepare environment — design deps, write `env_spec.sh`, verify +- Step 4: prepare tools — design sufficient/actual sets, synthesize plugins, verify +- Step 5: prepare assets — design sufficient/actual sets, synthesize files +- Step 6: expected behavior + grading (precompute values, write `grade()`, LLM judge rubrics) +- Step 7: write `task.md` via task-markdown-editor skill +- Step 8: validate — hard gate, fix and re-check before delivering + +**Status lifecycle:** `preparing` → `ready` → `deprecated` diff --git a/workspace/skills/task-synthesis/assets/built_in_tools.json b/workspace/skills/task-synthesis/assets/built_in_tools.json new file mode 100644 index 00000000..55a7c4e1 --- /dev/null +++ b/workspace/skills/task-synthesis/assets/built_in_tools.json @@ -0,0 +1,519 @@ +{ + "_source": "openclaw built-in tools extracted from dist/tool-catalog-BjSY4C4F.js", + "_version": "2026.3.23-2", + "sections": [ + { + "id": "fs", + "label": "Files", + "tools": [ + { + "id": "read", + "label": "read", + "description": "Read file contents", + "profiles": [ + "coding" + ] + }, + { + "id": "write", + "label": "write", + "description": "Create or overwrite files", + "profiles": [ + "coding" + ] + }, + { + "id": "edit", + "label": "edit", + "description": "Make precise edits", + "profiles": [ + "coding" + ] + }, + { + "id": "apply_patch", + "label": "apply_patch", + "description": "Patch files (OpenAI)", + "profiles": [ + "coding" + ] + } + ] + }, + { + "id": "runtime", + "label": "Runtime", + "tools": [ + { + "id": "exec", + "label": "exec", + "description": "Run shell commands", + "profiles": [ + "coding" + ] + }, + { + "id": "process", + "label": "process", + "description": "Manage background processes", + "profiles": [ + "coding" + ] + } + ] + }, + { + "id": "web", + "label": "Web", + "tools": [ + { + "id": "web_search", + "label": "web_search", + "description": "Search the web", + "profiles": [ + "coding" + ] + }, + { + "id": "web_fetch", + "label": "web_fetch", + "description": "Fetch web content", + "profiles": [ + "coding" + ] + } + ] + }, + { + "id": "memory", + "label": "Memory", + "tools": [ + { + "id": "memory_search", + "label": "memory_search", + "description": "Semantic search", + "profiles": [ + "coding" + ] + }, + { + "id": "memory_get", + "label": "memory_get", + "description": "Read memory files", + "profiles": [ + "coding" + ] + } + ] + }, + { + "id": "sessions", + "label": "Sessions", + "tools": [ + { + "id": "sessions_list", + "label": "sessions_list", + "description": "List sessions", + "profiles": [ + "coding", + "messaging" + ] + }, + { + "id": "sessions_history", + "label": "sessions_history", + "description": "Session history", + "profiles": [ + "coding", + "messaging" + ] + }, + { + "id": "sessions_send", + "label": "sessions_send", + "description": "Send to session", + "profiles": [ + "coding", + "messaging" + ] + }, + { + "id": "sessions_spawn", + "label": "sessions_spawn", + "description": "Spawn sub-agent", + "profiles": [ + "coding" + ] + }, + { + "id": "sessions_yield", + "label": "sessions_yield", + "description": "End turn to receive sub-agent results", + "profiles": [ + "coding" + ] + }, + { + "id": "subagents", + "label": "subagents", + "description": "Manage sub-agents", + "profiles": [ + "coding" + ] + }, + { + "id": "session_status", + "label": "session_status", + "description": "Session status", + "profiles": [ + "minimal", + "coding", + "messaging" + ] + } + ] + }, + { + "id": "ui", + "label": "UI", + "tools": [ + { + "id": "browser", + "label": "browser", + "description": "Control web browser", + "profiles": [] + }, + { + "id": "canvas", + "label": "canvas", + "description": "Control canvases", + "profiles": [] + } + ] + }, + { + "id": "messaging", + "label": "Messaging", + "tools": [ + { + "id": "message", + "label": "message", + "description": "Send messages", + "profiles": [ + "messaging" + ] + } + ] + }, + { + "id": "automation", + "label": "Automation", + "tools": [ + { + "id": "cron", + "label": "cron", + "description": "Schedule tasks", + "profiles": [ + "coding" + ] + }, + { + "id": "gateway", + "label": "gateway", + "description": "Gateway control", + "profiles": [] + } + ] + }, + { + "id": "nodes", + "label": "Nodes", + "tools": [ + { + "id": "nodes", + "label": "nodes", + "description": "Nodes + devices", + "profiles": [] + } + ] + }, + { + "id": "agents", + "label": "Agents", + "tools": [ + { + "id": "agents_list", + "label": "agents_list", + "description": "List agents", + "profiles": [] + } + ] + }, + { + "id": "media", + "label": "Media", + "tools": [ + { + "id": "image", + "label": "image", + "description": "Image understanding", + "profiles": [ + "coding" + ] + }, + { + "id": "image_generate", + "label": "image_generate", + "description": "Image generation", + "profiles": [ + "coding" + ] + }, + { + "id": "tts", + "label": "tts", + "description": "Text-to-speech conversion", + "profiles": [] + } + ] + } + ], + "tools": [ + { + "id": "read", + "label": "read", + "description": "Read file contents", + "section": "fs", + "profiles": [ + "coding" + ] + }, + { + "id": "write", + "label": "write", + "description": "Create or overwrite files", + "section": "fs", + "profiles": [ + "coding" + ] + }, + { + "id": "edit", + "label": "edit", + "description": "Make precise edits", + "section": "fs", + "profiles": [ + "coding" + ] + }, + { + "id": "apply_patch", + "label": "apply_patch", + "description": "Patch files (OpenAI)", + "section": "fs", + "profiles": [ + "coding" + ] + }, + { + "id": "exec", + "label": "exec", + "description": "Run shell commands", + "section": "runtime", + "profiles": [ + "coding" + ] + }, + { + "id": "process", + "label": "process", + "description": "Manage background processes", + "section": "runtime", + "profiles": [ + "coding" + ] + }, + { + "id": "web_search", + "label": "web_search", + "description": "Search the web", + "section": "web", + "profiles": [ + "coding" + ] + }, + { + "id": "web_fetch", + "label": "web_fetch", + "description": "Fetch web content", + "section": "web", + "profiles": [ + "coding" + ] + }, + { + "id": "memory_search", + "label": "memory_search", + "description": "Semantic search", + "section": "memory", + "profiles": [ + "coding" + ] + }, + { + "id": "memory_get", + "label": "memory_get", + "description": "Read memory files", + "section": "memory", + "profiles": [ + "coding" + ] + }, + { + "id": "sessions_list", + "label": "sessions_list", + "description": "List sessions", + "section": "sessions", + "profiles": [ + "coding", + "messaging" + ] + }, + { + "id": "sessions_history", + "label": "sessions_history", + "description": "Session history", + "section": "sessions", + "profiles": [ + "coding", + "messaging" + ] + }, + { + "id": "sessions_send", + "label": "sessions_send", + "description": "Send to session", + "section": "sessions", + "profiles": [ + "coding", + "messaging" + ] + }, + { + "id": "sessions_spawn", + "label": "sessions_spawn", + "description": "Spawn sub-agent", + "section": "sessions", + "profiles": [ + "coding" + ] + }, + { + "id": "sessions_yield", + "label": "sessions_yield", + "description": "End turn to receive sub-agent results", + "section": "sessions", + "profiles": [ + "coding" + ] + }, + { + "id": "subagents", + "label": "subagents", + "description": "Manage sub-agents", + "section": "sessions", + "profiles": [ + "coding" + ] + }, + { + "id": "session_status", + "label": "session_status", + "description": "Session status", + "section": "sessions", + "profiles": [ + "minimal", + "coding", + "messaging" + ] + }, + { + "id": "browser", + "label": "browser", + "description": "Control web browser", + "section": "ui", + "profiles": [] + }, + { + "id": "canvas", + "label": "canvas", + "description": "Control canvases", + "section": "ui", + "profiles": [] + }, + { + "id": "message", + "label": "message", + "description": "Send messages", + "section": "messaging", + "profiles": [ + "messaging" + ] + }, + { + "id": "cron", + "label": "cron", + "description": "Schedule tasks", + "section": "automation", + "profiles": [ + "coding" + ] + }, + { + "id": "gateway", + "label": "gateway", + "description": "Gateway control", + "section": "automation", + "profiles": [] + }, + { + "id": "nodes", + "label": "nodes", + "description": "Nodes + devices", + "section": "nodes", + "profiles": [] + }, + { + "id": "agents_list", + "label": "agents_list", + "description": "List agents", + "section": "agents", + "profiles": [] + }, + { + "id": "image", + "label": "image", + "description": "Image understanding", + "section": "media", + "profiles": [ + "coding" + ] + }, + { + "id": "image_generate", + "label": "image_generate", + "description": "Image generation", + "section": "media", + "profiles": [ + "coding" + ] + }, + { + "id": "tts", + "label": "tts", + "description": "Text-to-speech conversion", + "section": "media", + "profiles": [] + } + ] +} diff --git a/workspace/skills/task-synthesis/references/synthesis-patterns.md b/workspace/skills/task-synthesis/references/synthesis-patterns.md new file mode 100644 index 00000000..6cfe15b9 --- /dev/null +++ b/workspace/skills/task-synthesis/references/synthesis-patterns.md @@ -0,0 +1,473 @@ +# Synthesis Patterns + +Patterns for generating scripts/, tools/, assets/, and grading logic. + +## Table of Contents +1. [scripts/ Template and Patterns](#scripts-template-and-patterns) +2. [Tool Synthesis Patterns](#tool-synthesis-patterns) +3. [Asset Synthesis Patterns](#asset-synthesis-patterns) +4. [Automated Grader Code Patterns](#automated-grader-code-patterns) + - [Output-focused graders](#output-focused-graders) + - [Behavior-focused graders (agent asked / inspected / proposed)](#behavior-focused-graders) +5. [LLM Judge Rubric Anchors](#llm-judge-rubric-anchors) + - [Output quality](#output-quality-content-correctness--completeness--format) + - [Reasoning / approach quality](#reasoning--approach-quality) + - [Clarification behavior](#clarification-behavior-underspecified--ambiguous-tasks) + +--- + +## scripts/ Template and Patterns + +`scripts/env_spec.sh` is always required. Additional helper scripts (init_db.py, seed_data.py, start_server.py, etc.) also go in `scripts/`. + +### env_spec.sh base template + +> ⚠️ Replace all `{PLACEHOLDER}` values before using. +> +> **Installation rules (strictly enforced):** +> - **Check before install.** Don't blindly install — first check if the binary/package already exists. Common runtimes (python3, node, git, curl) are usually pre-installed. +> - **User-local first.** For language packages, always install to user-local paths: `pip install --user`, `npm install` (local `node_modules`), `$HOME/.local/bin`, etc. +> - **System packages as fallback.** For tools that can only be installed system-wide (ffmpeg, imagemagick, etc.), check existence first, then try `sudo apt-get` only if missing. If sudo is unavailable, fail with a clear error message — never silently skip. +> +> **Path rules (strictly enforced):** +> - Use `$WORKSPACE`, `$TASK_PATH`, `$HOME` — never hardcode absolute paths like `/tmp/`, `/workspace/`, `/usr/local/bin/` +> - Output files and generated artifacts go under `$WORKSPACE`, not `/tmp/` or any fixed system path +> - To make a mock script available on PATH, use `export PATH="$HOME/.local/bin:$PATH"` rather than copying to `/usr/local/bin/` +> - Mock user-environment dependencies (e.g., a CLI tool that requires GPU/manual install) in `env_spec.sh` when the evaluation environment cannot satisfy them — but keep mocks path-clean using the rules above + +```bash +#!/usr/bin/env bash +# scripts/env_spec.sh — Task execution environment initialization +# Task: {task-id} +set -euo pipefail + +WORKSPACE="${WORKSPACE_PATH:?Error: WORKSPACE_PATH is not set}" +TASK_PATH="${TASK_PATH:?Error: TASK_PATH is not set}" + +# User-local bin on PATH +export PATH="$HOME/.local/bin:$PATH" + +echo "=== Copying assets to workspace ===" +mkdir -p "$WORKSPACE" +cp -r "$TASK_PATH/assets/." "$WORKSPACE/" + +echo "=== Checking & installing system CLI tools ===" +# Check first; install only if missing and sudo is available +require_cli() { + if command -v "$1" &>/dev/null; then + echo " ✓ $1 already installed" + elif sudo -n true 2>/dev/null; then + echo " ⟳ Installing $1 via apt-get..." + sudo apt-get update -qq && sudo apt-get install -y --no-install-recommends "$1" + else + echo " ✗ $1 not found and sudo not available" >&2; exit 1 + fi +} +# require_cli "ffmpeg" +# require_cli "jq" + +echo "=== Checking & installing Python packages ===" +# Check before install; user-local only +install_pip_pkg() { + python3 -c "import $1" 2>/dev/null || pip install --user --quiet "$2" +} +# install_pip_pkg "pandas" "pandas==2.1.0" +# install_pip_pkg "requests" "requests==2.31.0" + +echo "=== Installing custom tool plugins ===" +bash "$TASK_PATH/scripts/plugin_install.sh" + +echo "=== Environment variables ===" +# export {ENV_VAR_NAME}="{value}" + +echo "=== One-time setup ===" +# python3 "$TASK_PATH/scripts/init_db.py" + +echo "=== Environment ready ===" +``` + +### Patterns for common environments + +**System CLI tools (check → sudo fallback → fail):** +```bash +require_cli() { + if command -v "$1" &>/dev/null; then + echo " ✓ $1 already installed" + elif sudo -n true 2>/dev/null; then + sudo apt-get update -qq && sudo apt-get install -y --no-install-recommends "$1" + else + echo " ✗ $1 not found and sudo not available" >&2; exit 1 + fi +} +require_cli "ffmpeg" +require_cli "jq" +``` + +**Python packages (user-local, check first):** +```bash +install_pip_pkg() { + python3 -c "import $1" 2>/dev/null || pip install --user --quiet "$2" +} +install_pip_pkg "pandas" "pandas==2.1.0" +install_pip_pkg "numpy" "numpy" +install_pip_pkg "sklearn" "scikit-learn" +``` + +**Node.js packages (local node_modules):** +```bash +cd "$WORKSPACE" && npm install +``` + +**SQLite database init:** +```bash +python3 "$TASK_PATH/scripts/init_db.py" +``` + +**No network access (default):** +```bash +# No network setup needed — task runs offline +``` + +**Service startup (e.g., local API mock):** +```bash +python3 "$TASK_PATH/scripts/start_server.py" & +sleep 1 # wait for server to start +``` + +--- + +## Tool Synthesis Patterns + +Tool synthesis is governed by the **`openclaw-plugin-creator`** skill — load it for the complete structure (`openclaw.plugin.json`, `package.json`, `src/index.ts`) and implementation guidance. + +Task-specific note: tools in `tools/` are invoked **by the agent during evaluation**. When the task has a fixed expected output, implement a deterministic mock rather than a live integration — the mock should return realistic, precomputed data consistent with the synthesized assets. + +### plugin_install.sh + +After synthesizing all tool plugins, create `scripts/plugin_install.sh` to install them. This script is referenced in the Step 4 static checklist. + +```bash +#!/usr/bin/env bash +# scripts/plugin_install.sh — Install all custom tool plugins for this task +# Task: {task-id} +set -euo pipefail + +TASK_PATH="${TASK_PATH:?Error: TASK_PATH is not set}" + +echo "=== Installing custom tool plugins ===" +for plugin_dir in "$TASK_PATH/tools"/*/; do + if [ -f "$plugin_dir/openclaw.plugin.json" ]; then + echo "Installing plugin: $(basename "$plugin_dir")" + cd "$plugin_dir" && openclaw plugins install . + fi +done + +echo "=== Plugin installation complete ===" +``` + +--- + +## Asset Synthesis Patterns + +### Principle: Realistic, Self-Consistent Content + +Assets must be realistic enough that an agent can complete the task without confusion. Avoid Lorem Ipsum unless the task is explicitly about text corpora. + +### Text / Markdown files + +Generate content matching the task topic. Include realistic structure (headings, lists, code blocks) if the task involves document processing. + +### CSV / JSON data files + +- Define a clear schema in the task's **Assets Preparation** section first +- Generate 20–100 rows; more for data analysis tasks, fewer for simple parsing tasks +- Ensure the "expected answer" in the grader is precomputed from the generated data +- Use consistent column names and types; avoid nulls unless the task tests null handling + +```python +# Example: generating a CSV asset programmatically +import csv, random, io + +rows = [{"id": i, "name": f"Student_{i}", "score": round(random.uniform(40, 100), 2)} + for i in range(1, 51)] +out = io.StringIO() +writer = csv.DictWriter(out, fieldnames=["id", "name", "score"]) +writer.writeheader() +writer.writerows(rows) +csv_content = out.getvalue() +``` + +### Code / Project directories + +Scaffold a realistic project structure. Include: +- A main entry point (e.g., `main.py`, `index.js`) +- At least one module with meaningful stub functions +- A README or docstring explaining what the project is supposed to do +- Any config files the task requires (e.g., `requirements.txt`, `package.json`) + +### Database files (SQLite) + +Provide a Python script at `scripts/init_db.py` that creates and populates the database, and call it from `env_spec.sh`. Do not commit binary `.sqlite` files directly. + +```python +# scripts/init_db.py +import sqlite3, os + +# Use WORKSPACE_PATH env var (injected at runtime); fall back to cwd for local testing +workspace = os.environ.get("WORKSPACE_PATH", os.getcwd()) +db_path = os.environ.get("DB_PATH", os.path.join(workspace, "db.sqlite")) +conn = sqlite3.connect(db_path) +conn.execute("""CREATE TABLE IF NOT EXISTS records ( + id INTEGER PRIMARY KEY, + name TEXT, + value REAL +)""") +conn.executemany("INSERT INTO records VALUES (?, ?, ?)", [ + (1, "alpha", 42.0), + (2, "beta", 17.5), + (3, "gamma", 99.1), +]) +conn.commit() +conn.close() +print(f"Database initialized at {db_path}") +``` + +--- + +## Automated Grader Code Patterns + +Two sub-types, pick based on what the task actually grades: + +- **Output-focused** — check files or values the agent produced (existence, content, format, numeric accuracy). Primary evidence is `workspace_path`. +- **Behavior-focused** — check how the agent acted (which tools it called, what it said, whether it asked before proceeding). Primary evidence is `transcript`. + +Both types can coexist in a single `grade()` function; just mix snippets as needed. + +### Output-focused graders + +### File existence + +```python +exists = os.path.isfile(os.path.join(workspace_path, "output.md")) +scores["File `output.md` created"] = 1.0 if exists else 0.0 +``` + +### Non-empty content + +```python +content = open(path).read().strip() if os.path.isfile(path) else "" +scores["Output is non-empty"] = 1.0 if content else 0.0 +``` + +### Regex match + +```python +import re +scores["Output contains a Markdown heading"] = ( + 1.0 if re.search(r"^#{1,6} \w", content, re.MULTILINE) else 0.0 +) +``` + +### JSON validity + schema check + +```python +import json +try: + data = json.loads(content) + scores["Output is valid JSON"] = 1.0 + scores["JSON has required keys"] = 1.0 if {"name", "value"} <= data.keys() else 0.0 +except (json.JSONDecodeError, AttributeError): + scores["Output is valid JSON"] = 0.0 + scores["JSON has required keys"] = 0.0 +``` + +### Script execution + +```python +import subprocess +result = subprocess.run( + ["python3", os.path.join(workspace_path, "solution.py")], + capture_output=True, timeout=30, cwd=workspace_path +) +scores["Script runs without error"] = 1.0 if result.returncode == 0 else 0.0 +``` + +### Numeric accuracy + +```python +try: + val = float(open(os.path.join(workspace_path, "output.txt")).read().strip()) + expected = 73.42 + scores["Value within ±0.01 of expected"] = 1.0 if abs(val - expected) <= 0.01 else 0.0 +except (ValueError, FileNotFoundError): + scores["Value within ±0.01 of expected"] = 0.0 +``` + +### Transcript-based check + +```python +tool_calls = [m.get("tool_name") for m in transcript if m.get("type") == "tool_call"] +scores["Agent used query_database tool"] = 1.0 if "query_database" in tool_calls else 0.0 +``` + +### Partial credit + +```python +lines = [l for l in content.splitlines() if l.strip()] +scores["Output has sufficient lines"] = min(1.0, len(lines) / 10) +``` + +### Safe wrapper (always use in grade()) + +```python +def grade(transcript: list, workspace_path: str) -> dict: + import os, re + scores = {} + try: + # ... all grading logic here — no `pass` or `...` placeholders ... + scores["criterion text"] = 1.0 # replace with real logic + except Exception: + # Never let grade() raise — default unscored criteria to 0.0 + for key in ["criterion text"]: # list all expected keys + scores.setdefault(key, 0.0) + return scores +``` + +--- + +### Behavior-focused graders (agent asked / inspected / proposed) + +For tasks where conditions omit something the query implies, the agent's **transcript** is the primary evidence. Correct behavior is about how the agent responds to the gap, not what file it produces. + +#### Helper: extract all agent text turns + +```python +def _agent_text(transcript: list) -> str: + """Concatenate all assistant/agent text messages into one searchable string.""" + parts = [] + for m in transcript: + if m.get("role") in ("assistant", "agent") and isinstance(m.get("content"), str): + parts.append(m["content"]) + return "\n".join(parts).lower() +``` + +#### Check: agent named the missing parameter + +```python +agent_text = _agent_text(transcript) +# Replace "output format" with the actual missing parameter name +named_gap = any(kw in agent_text for kw in ["output format", "format", "which format"]) +scores["Agent identified the missing parameter by name"] = 1.0 if named_gap else 0.0 +``` + +#### Check: agent inspected the environment before responding + +```python +tool_calls = [m.get("tool_name", "") for m in transcript if m.get("type") == "tool_call"] +env_tools = {"list_files", "read_file", "read_config", "ls", "cat"} # adjust to task toolset +inspected = bool(env_tools & set(tool_calls)) +scores["Agent inspected environment before responding"] = 1.0 if inspected else 0.0 +``` + +#### Check: agent proposed concrete options (not a vague "please clarify") + +```python +import re +agent_text = _agent_text(transcript) +# Look for enumeration patterns: "1.", "- option", "either X or Y", etc. +has_options = bool(re.search( + r"(\b(option|choice|either|or)\b.*\b(option|choice|either|or)\b" + r"|\b1[\.\)]\s+\w|\b-\s+\w.*\n\s*-\s+\w)", + agent_text, re.IGNORECASE | re.DOTALL +)) +scores["Agent proposed 2+ concrete options with reasoning"] = 1.0 if has_options else 0.0 +``` + +#### Check: agent did NOT silently proceed to produce output + +```python +import os +# The task's expected output file should NOT exist if agent correctly withheld action +output_exists = os.path.isfile(os.path.join(workspace_path, "output.txt")) +scores["Agent did not silently produce output without resolving ambiguity"] = ( + 0.0 if output_exists else 1.0 +) +``` + +#### Check: agent asked a targeted clarifying question + +```python +agent_text = _agent_text(transcript) +# Look for question marks and parameter-related keywords in the same sentence +sentences = re.split(r"[.!]\s+", agent_text) +asked = any("?" in s and any(kw in s for kw in ["format", "output", "destination", "version"]) + for s in sentences) +scores["Agent asked a targeted clarifying question"] = 1.0 if asked else 0.0 +``` + +*For LLM judge rubric templates matching these graders, see [LLM Judge Rubric Anchors → Clarification behavior](#clarification-behavior-underspecified--ambiguous-tasks) below.* + +--- + +## LLM Judge Rubric Anchors + +Use LLM judge criteria whenever a script cannot capture the requirement — output quality, reasoning depth, or clarification behavior. The three most common scenarios: + +### Output quality (content correctness / completeness / format) + +Use when the output is a natural-language artifact (summary, explanation, report, code review, plan) and correctness requires understanding intent, not just pattern matching. + +```markdown +**{Criterion name} (0–1):** +- **1.0** — {describe what a fully correct, complete, well-formed response looks like} +- **0.5** — {describe a partially correct response: present but incomplete, slightly off, or minor format issue} +- **0.0** — {describe a clearly wrong response: missing, hallucinated, or structurally broken} +``` + +Example — "Summary covers all key points": +```markdown +**Summary covers all key points (0–1):** +- **1.0** — Summary mentions all {N} major topics from the source document with accurate facts +- **0.5** — Summary covers most topics but omits 1–2 important points, or includes minor inaccuracies +- **0.0** — Summary is missing, covers fewer than half the topics, or contains significant hallucinations +``` + +### Reasoning / approach quality + +Use when the task is open-ended and the path matters as much as (or more than) the final answer — e.g., debugging steps, investigation approach, multi-step planning. + +```markdown +**{Criterion name} (0–1):** +- **1.0** — Agent followed a logical, systematic approach: {describe the ideal sequence of steps or tool calls} +- **0.5** — Agent reached the correct conclusion but via an inefficient or partially incorrect route +- **0.0** — Agent skipped key steps, jumped to conclusions without evidence, or produced an incorrect result +``` + +Example — "Agent diagnosed root cause correctly": +```markdown +**Agent diagnosed root cause correctly (0–1):** +- **1.0** — Agent inspected relevant logs/files, identified the specific error and its cause, and stated it clearly +- **0.5** — Agent found the symptom but misidentified the root cause, or identified the cause without citing evidence +- **0.0** — Agent did not investigate, guessed without evidence, or identified an unrelated issue +``` + +### Clarification behavior (underspecified / ambiguous tasks) + +Use when the query intentionally omits required information and the correct behavior is to ask, not to proceed silently. + +```markdown +**Agent identified the missing parameter (0–1):** +- **1.0** — Agent explicitly named "{parameter}" as missing or unknown, before taking any action +- **0.5** — Agent expressed uncertainty but did not name the specific parameter +- **0.0** — Agent proceeded without acknowledging the gap, or asked a completely generic question + +**Agent proposed concrete options (0–1):** +- **1.0** — Agent listed 2+ specific, named options with a brief rationale for each +- **0.5** — Agent suggested options exist but did not name them, or listed only one +- **0.0** — Agent gave no options; asked only "please clarify" with no guidance + +**Agent did not proceed without required information (0–1):** +- **1.0** — Agent withheld all output-producing actions until the gap was resolved +- **0.5** — Agent produced partial output with an explicit caveat about the assumption made +- **0.0** — Agent silently assumed a value and produced output without flagging the assumption +``` diff --git a/workspace/skills/task-synthesis/references/task-blank-template.md b/workspace/skills/task-synthesis/references/task-blank-template.md new file mode 100644 index 00000000..7f413f15 --- /dev/null +++ b/workspace/skills/task-synthesis/references/task-blank-template.md @@ -0,0 +1,91 @@ +--- +id: {task-id} +name: {task-name} +scene: {scene} +category: {category} +family: {family} +archetype: {archetype} +skill_set: + - {skill-id} +grading_mode: automated|llm-judge|hybrid +timeout_seconds: 180 +workspace_files: + - assets/{filename-or-dir} +difficulty: low|medium|high +status: preparing|ready|deprecated +--- + +## Prompt + +{task prompt — natural language, goal-oriented user message} + +## Expected Behavior + +The agent should: + +- [ ] {action or intermediate step} +- [ ] {action or intermediate step} +- [ ] {final output produced} + +## Assets Preparation + +- `assets/{filename}` — {format; content description including schema, size, key values} + +## Toolset Preparation + +- `{tool_name}({param}: {type}, ...) -> {return_type}` — {description} + - Input: `{param}` ({type}) — {description} + - Output: `{structure description}` + - Type: built-in | static mock | input-derived mock + +## Environment Specification + +- Platform: Ubuntu 22.04 +- Python: 3.11 +- pip: {package==version, ...} +- System: none +- Network: none +- Environment variables: none +- Services: none + +## Grading Criteria + +### Automated Criteria ({N}%) + + + +- [ ] {deterministically checkable criterion} + +### LLM Judge Criteria ({N}%) + + + +- [ ] {semantically evaluated criterion} + +## Automated Checks + +```python +def grade(transcript: list, workspace_path: str) -> dict: + """ + Args: + transcript: Parsed JSONL conversation transcript as list of dicts. + workspace_path: Absolute path to the task's isolated workspace directory. + Returns: + Dict mapping each Automated Criteria bullet (exact text) to score 0.0–1.0. + """ + import os + + scores = {} + + # --- implement grading logic here --- + + return scores +``` + +## LLM Judge Rubric + +**{Criterion name} (0–1):** {what is being evaluated} + +- **1.0** — {excellent / fully correct} +- **0.5** — {partial / acceptable} +- **0.0** — {failure / missing} diff --git a/workspace/skills/task-synthesis/references/task-schema.md b/workspace/skills/task-synthesis/references/task-schema.md new file mode 100644 index 00000000..b7554b71 --- /dev/null +++ b/workspace/skills/task-synthesis/references/task-schema.md @@ -0,0 +1,194 @@ +# task.md Schema Reference + +Complete field-by-field reference for the task.md format. + +## Frontmatter Fields + +```yaml +--- +id: {task-id} # Provided in prompt — use as-is; do NOT modify +name: {task-name} # GENERATED: short, hyphen-linked, content-focused (see naming rules) +scene: {scene} # Provided in prompt — use as-is +category: {category} # Provided in prompt — use as-is +family: {family} # GENERATED: see allowed values below +archetype: {archetype} # GENERATED: see allowed values below +skill_set: + - {skill-id} # Provided in prompt — use as-is +grading_mode: automated|llm-judge|hybrid # Determined during synthesis +timeout_seconds: 180 # Adjust: low=60, medium=180, high=300+ +workspace_files: # List of ALL files provided to the agent + - assets/{filename-or-dir} +difficulty: low|medium|high +status: preparing|ready|deprecated +--- +``` + +### Field Rules + +**Fields provided in the prompt — use as-is, do NOT alter:** +- `id`, `scene`, `category`, `skill_set` + +**Fields always generated during task synthesis:** + +#### `name` +- Short, descriptive, hyphen-linked slug +- 3–6 words max; omit filler words ("the", "a", "task") +- Must reflect the concrete action and subject +- Examples: `extract-invoice-totals`, `rename-files-by-date`, `summarize-pr-diff`, `query-sqlite-schema` + +#### `family` +Must be exactly one of: + +`Documents & Files` / `Filesystem & Workspace` / `Code / Repo / IDE` / `Browser & Web` / `SaaS / API Integrations` / `Local Apps / OS Automation` / `Media` / `DevOps / System / Cloud` / `Search & Research` / `Agent Infra / Orchestration` / `Hybrid` + +#### `archetype` +Must be exactly one of: + +`retrieve` / `extract` / `transform` / `organize` / `analyze` / `generate` / `edit` / `execute` / `validate` / `monitor` / `automate` / `hybrid` + +#### `grading_mode` +- `automated` — all criteria are deterministically checkable via `grade()` +- `llm-judge` — all criteria require semantic evaluation +- `hybrid` — mix of both; default for most tasks + +#### `difficulty` +- `low` — single-step, unambiguous output, ≤60s +- `medium` — multi-step, some judgment or env inspection, 60–180s +- `high` — open-ended, complex reasoning or multi-turn clarification, >180s + +--- + +## Section Reference + +### `## Prompt` + +Real user message — natural language, goal-oriented. NOT a spec or step list. +- Write as the user would type it: terse, target-focused +- When referring to files, use only the filename or a relative path (e.g., `data.csv`, `reports/summary.csv`) — never absolute paths like `/workspace/data.csv` or path placeholders like `{file}`; all assets are placed in the agent's current workspace at runtime + +### `## Expected Behavior` + +Checklist for evaluators. Covers: key actions → intermediate outputs → final deliverable. + +Format: checkbox list (`- [ ]`), ordered by execution flow. + +Each item must be concrete and verifiable (not "does a good job"). +Each item should map to a Grading Criteria bullet. + +```markdown +The agent should: + +- [ ] Read `assets/input.csv` and parse its rows +- [ ] Compute the sum of the `amount` column +- [ ] Write a JSON file to `output.json` containing `{"total": }` +``` + +### `## Assets Preparation` + +All files the agent can read at runtime. Listed as bullet items. + +Each entry must include: +- File path (relative to workspace root, e.g., `assets/data.csv`) +- Format / file type +- Content description: schema, size, topic, key values, any constraints + +```markdown +- `assets/data.csv` — CSV, 50 rows, columns: `id` (int), `name` (str), `amount` (float); total of `amount` column = 1234.56 +- `assets/config.json` — JSON, fields: `output_dir` (str), `format` (str, one of "json"|"csv") +``` + +### `## Toolset Preparation` + +All custom plugins the agent may call. Listed as bullet items. + +Each entry must include: +- Tool name (snake_case) +- One-line description +- Input parameters: `name: type — description` +- Return value: type and structure +- Mock type: deterministic mock or live stub + +```markdown +- `get_file_meta(path: str) -> dict` — Returns file metadata (size, mtime, mime_type). Deterministic mock returning fixed values for known paths. + - Input: `path` (str) — absolute or workspace-relative file path + - Output: `{ "size": int, "mtime": str, "mime_type": str }` +``` + +Omit this section entirely if no custom tools are needed. + +### `## Environment Specification` + +Minimal runtime dependencies. Include only what is strictly required. + +```markdown +- Platform: Ubuntu 22.04 +- Python: 3.11 +- pip: pandas==2.1.0, requests==2.31.0 +- System: none +- Network: none +- Environment variables: none +``` + +### `## Grading Criteria` + +Two subsections summing to exactly 100%. +- `automated` tasks: 100% Automated, remove LLM section +- `llm-judge` tasks: 100% LLM, remove Automated section +- `hybrid` tasks: split as appropriate + +Each Automated bullet → key in `grade()` return dict (text must match exactly). +Each LLM bullet → dimension in LLM Judge Rubric. + +```markdown +### Automated Criteria (60%) + +- [ ] Output file `output.json` exists +- [ ] `total` field equals 1234.56 + +### LLM Judge Criteria (40%) + +- [ ] Agent explained its computation steps clearly +``` + +### `## Automated Checks` + +Python function `grade(transcript, workspace_path) -> dict`. + +Rules: +- Keys must exactly match Automated Criteria bullet text +- Scores: 0.0 (fail) to 1.0 (pass); partial credit allowed +- Never raise exceptions — catch all errors, default to 0.0 +- Standard library only unless package is guaranteed by Environment Specification +- No `...` or `pass` placeholders — fully implemented + +```python +def grade(transcript: list, workspace_path: str) -> dict: + import os, json + + scores = {} + + # criterion: "Output file `output.json` exists" + out_path = os.path.join(workspace_path, "output.json") + scores["Output file `output.json` exists"] = 1.0 if os.path.isfile(out_path) else 0.0 + + # criterion: "`total` field equals 1234.56" + try: + with open(out_path) as f: + data = json.load(f) + scores["`total` field equals 1234.56"] = 1.0 if abs(data.get("total", 0) - 1234.56) < 0.01 else 0.0 + except Exception: + scores["`total` field equals 1234.56"] = 0.0 + + return scores +``` + +### `## LLM Judge Rubric` + +One rubric block per LLM Judge Criteria bullet. Each block has explicit 1.0 / 0.5 / 0.0 anchors describing concrete, observable agent behaviors. + +```markdown +**Computation explanation (0–1):** Whether the agent explained how it computed the total. +- **1.0** — Agent explicitly stated the column used, the operation performed, and the result value +- **0.5** — Agent mentioned the computation but omitted one key detail (column name or result) +- **0.0** — Agent produced no explanation or explanation was factually wrong +``` diff --git a/workspace/skills/task-synthesis/scripts/plugin_install.sh b/workspace/skills/task-synthesis/scripts/plugin_install.sh new file mode 100755 index 00000000..a1b17138 --- /dev/null +++ b/workspace/skills/task-synthesis/scripts/plugin_install.sh @@ -0,0 +1,150 @@ +#!/usr/bin/env bash +# task_initialize.sh — Initialize a synthesized task package +# +# Usage: +# ./task_initialize.sh [skill_install_dir] +# +# Arguments: +# task_dir Path to the synthesized task directory (contains task.md, assets/, tools/, scripts/) +# skill_install_dir (optional) Directory to install the source skill into +# Default: /mnt/shared-storage-user/llmit/user/sunyanan/openclaw/runtime_skills +# +# What this script does: +# 1. Run scripts/env_spec.sh — set up the task execution environment +# 2. Install tools/ — install each custom tool plugin into OpenClaw via `openclaw plugins install --link` +# 3. Install the source skill — copy the skill directory to + +set -euo pipefail + +# ── Args ────────────────────────────────────────────────────────────────────── + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [skill_install_dir]" >&2 + exit 1 +fi + +TASK_DIR="$(realpath "$1")" +SKILL_INSTALL_DIR="$(realpath "${2:-/mnt/shared-storage-user/llmit/user/sunyanan/openclaw/runtime_skills}")" + +if [[ ! -d "$TASK_DIR" ]]; then + echo "Error: task_dir '$TASK_DIR' does not exist." >&2 + exit 1 +fi + +echo "=== task_initialize ===" +echo "Task dir : $TASK_DIR" +echo "Skill dir : $SKILL_INSTALL_DIR" +echo + +# ── Step 1: Run env_spec.sh ─────────────────────────────────────────────────── + +ENV_SCRIPT="$TASK_DIR/scripts/env_spec.sh" + +echo "--- Step 1: Running env_spec.sh ---" +if [[ -f "$ENV_SCRIPT" ]]; then + bash "$ENV_SCRIPT" + echo "env_spec.sh completed." +else + echo "WARNING: $ENV_SCRIPT not found, skipping environment setup." +fi +echo + +# ── Step 2: Install tools/ plugins into OpenClaw ───────────────────────────── + +TOOLS_DIR="$TASK_DIR/tools" + +echo "--- Step 2: Installing tool plugins ---" +if [[ -d "$TOOLS_DIR" ]]; then + TOOL_COUNT=0 + for tool_path in "$TOOLS_DIR"/*/; do + if [[ -d "$tool_path" ]]; then + tool_name="$(basename "$tool_path")" + echo " Installing plugin: $tool_name ($tool_path)" + openclaw plugins install --link "$tool_path" + TOOL_COUNT=$((TOOL_COUNT + 1)) + fi + done + if [[ $TOOL_COUNT -eq 0 ]]; then + echo " No custom tool plugins found in $TOOLS_DIR — skipping." + else + echo " Installed $TOOL_COUNT plugin(s)." + fi +else + echo " tools/ directory not found — skipping plugin installation." +fi +echo + +# ── Step 3: Install the source skill ───────────────────────────────────────── + +# Looks up the skill's relative_path from the skill catalog JSONL, then copies +# the skill directory (parent of SKILL.md) to . + +SKILL_CATALOG="/mnt/shared-storage-user/llmit/user/tangyinhao/skills/data/awesome_openclaw_skills_and_skillshtop100_tag_v01.jsonl" +SKILL_ROOT="/mnt/shared-storage-user/llmit/user/tangyinhao/skills" + +echo "--- Step 3: Installing source skill ---" + +# Extract skill_id from task.md frontmatter (skill_set: [skill-id]) +TASK_MD="$TASK_DIR/task.md" +if [[ ! -f "$TASK_MD" ]]; then + echo " WARNING: task.md not found at $TASK_MD — skipping skill installation." +else + SKILL_ID=$(grep -A1 'skill_set:' "$TASK_MD" | grep '^\s*-' | head -1 | sed 's/.*-\s*//' | tr -d '[:space:]') + echo " Detected skill_id from task.md: '$SKILL_ID'" + + if [[ -z "$SKILL_ID" || "$SKILL_ID" == "null" ]]; then + echo " WARNING: Could not parse skill_id from task.md — skipping skill installation." + elif [[ ! -f "$SKILL_CATALOG" ]]; then + echo " WARNING: Skill catalog not found at $SKILL_CATALOG — skipping skill installation." + else + # Look up relative_path by id in the JSONL catalog + RELATIVE_PATH=$(python3 - < $DEST" + mkdir -p "$SKILL_INSTALL_DIR" + cp -r "$SKILL_SOURCE_DIR" "$DEST" + echo " Skill installed to: $DEST" + fi + fi + fi +fi +echo + +echo "=== task_initialize complete ===" diff --git a/workspace/skills/tmux/SKILL.md b/workspace/skills/tmux/SKILL.md new file mode 100644 index 00000000..19d7f710 --- /dev/null +++ b/workspace/skills/tmux/SKILL.md @@ -0,0 +1,121 @@ +--- +name: tmux +description: Remote-control tmux sessions for interactive CLIs by sending keystrokes and scraping pane output. +metadata: {"internclaw":{"emoji":"🧵","os":["darwin","linux"],"requires":{"bins":["tmux"]}}} +--- + +# tmux Skill + +Use tmux only when you need an interactive TTY. Prefer exec background mode for long-running, non-interactive tasks. + +## Quickstart (isolated socket, exec tool) + +```bash +SOCKET_DIR="${internclaw_TMUX_SOCKET_DIR:-${TMPDIR:-/tmp}/internclaw-tmux-sockets}" +mkdir -p "$SOCKET_DIR" +SOCKET="$SOCKET_DIR/internclaw.sock" +SESSION=internclaw-python + +tmux -S "$SOCKET" new -d -s "$SESSION" -n shell +tmux -S "$SOCKET" send-keys -t "$SESSION":0.0 -- 'PYTHON_BASIC_REPL=1 python3 -q' Enter +tmux -S "$SOCKET" capture-pane -p -J -t "$SESSION":0.0 -S -200 +``` + +After starting a session, always print monitor commands: + +``` +To monitor: + tmux -S "$SOCKET" attach -t "$SESSION" + tmux -S "$SOCKET" capture-pane -p -J -t "$SESSION":0.0 -S -200 +``` + +## Socket convention + +- Use `internclaw_TMUX_SOCKET_DIR` environment variable. +- Default socket path: `"$internclaw_TMUX_SOCKET_DIR/internclaw.sock"`. + +## Targeting panes and naming + +- Target format: `session:window.pane` (defaults to `:0.0`). +- Keep names short; avoid spaces. +- Inspect: `tmux -S "$SOCKET" list-sessions`, `tmux -S "$SOCKET" list-panes -a`. + +## Finding sessions + +- List sessions on your socket: `{baseDir}/scripts/find-sessions.sh -S "$SOCKET"`. +- Scan all sockets: `{baseDir}/scripts/find-sessions.sh --all` (uses `internclaw_TMUX_SOCKET_DIR`). + +## Sending input safely + +- Prefer literal sends: `tmux -S "$SOCKET" send-keys -t target -l -- "$cmd"`. +- Control keys: `tmux -S "$SOCKET" send-keys -t target C-c`. + +## Watching output + +- Capture recent history: `tmux -S "$SOCKET" capture-pane -p -J -t target -S -200`. +- Wait for prompts: `{baseDir}/scripts/wait-for-text.sh -t session:0.0 -p 'pattern'`. +- Attaching is OK; detach with `Ctrl+b d`. + +## Spawning processes + +- For python REPLs, set `PYTHON_BASIC_REPL=1` (non-basic REPL breaks send-keys flows). + +## Windows / WSL + +- tmux is supported on macOS/Linux. On Windows, use WSL and install tmux inside WSL. +- This skill is gated to `darwin`/`linux` and requires `tmux` on PATH. + +## Orchestrating Coding Agents (Codex, Claude Code) + +tmux excels at running multiple coding agents in parallel: + +```bash +SOCKET="${TMPDIR:-/tmp}/codex-army.sock" + +# Create multiple sessions +for i in 1 2 3 4 5; do + tmux -S "$SOCKET" new-session -d -s "agent-$i" +done + +# Launch agents in different workdirs +tmux -S "$SOCKET" send-keys -t agent-1 "cd /tmp/project1 && codex --yolo 'Fix bug X'" Enter +tmux -S "$SOCKET" send-keys -t agent-2 "cd /tmp/project2 && codex --yolo 'Fix bug Y'" Enter + +# Poll for completion (check if prompt returned) +for sess in agent-1 agent-2; do + if tmux -S "$SOCKET" capture-pane -p -t "$sess" -S -3 | grep -q "❯"; then + echo "$sess: DONE" + else + echo "$sess: Running..." + fi +done + +# Get full output from completed session +tmux -S "$SOCKET" capture-pane -p -t agent-1 -S -500 +``` + +**Tips:** +- Use separate git worktrees for parallel fixes (no branch conflicts) +- `pnpm install` first before running codex in fresh clones +- Check for shell prompt (`❯` or `$`) to detect completion +- Codex needs `--yolo` or `--full-auto` for non-interactive fixes + +## Cleanup + +- Kill a session: `tmux -S "$SOCKET" kill-session -t "$SESSION"`. +- Kill all sessions on a socket: `tmux -S "$SOCKET" list-sessions -F '#{session_name}' | xargs -r -n1 tmux -S "$SOCKET" kill-session -t`. +- Remove everything on the private socket: `tmux -S "$SOCKET" kill-server`. + +## Helper: wait-for-text.sh + +`{baseDir}/scripts/wait-for-text.sh` polls a pane for a regex (or fixed string) with a timeout. + +```bash +{baseDir}/scripts/wait-for-text.sh -t session:0.0 -p 'pattern' [-F] [-T 20] [-i 0.5] [-l 2000] +``` + +- `-t`/`--target` pane target (required) +- `-p`/`--pattern` regex to match (required); add `-F` for fixed string +- `-T` timeout seconds (integer, default 15) +- `-i` poll interval seconds (default 0.5) +- `-l` history lines to search (integer, default 1000) diff --git a/workspace/skills/tmux/scripts/find-sessions.sh b/workspace/skills/tmux/scripts/find-sessions.sh new file mode 100755 index 00000000..2789c3be --- /dev/null +++ b/workspace/skills/tmux/scripts/find-sessions.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash +set -euo pipefail + +usage() { + cat <<'USAGE' +Usage: find-sessions.sh [-L socket-name|-S socket-path|-A] [-q pattern] + +List tmux sessions on a socket (default tmux socket if none provided). + +Options: + -L, --socket tmux socket name (passed to tmux -L) + -S, --socket-path tmux socket path (passed to tmux -S) + -A, --all scan all sockets under internclaw_TMUX_SOCKET_DIR + -q, --query case-insensitive substring to filter session names + -h, --help show this help +USAGE +} + +socket_name="" +socket_path="" +query="" +scan_all=false +socket_dir="${internclaw_TMUX_SOCKET_DIR:-${TMPDIR:-/tmp}/internclaw-tmux-sockets}" + +while [[ $# -gt 0 ]]; do + case "$1" in + -L|--socket) socket_name="${2-}"; shift 2 ;; + -S|--socket-path) socket_path="${2-}"; shift 2 ;; + -A|--all) scan_all=true; shift ;; + -q|--query) query="${2-}"; shift 2 ;; + -h|--help) usage; exit 0 ;; + *) echo "Unknown option: $1" >&2; usage; exit 1 ;; + esac +done + +if [[ "$scan_all" == true && ( -n "$socket_name" || -n "$socket_path" ) ]]; then + echo "Cannot combine --all with -L or -S" >&2 + exit 1 +fi + +if [[ -n "$socket_name" && -n "$socket_path" ]]; then + echo "Use either -L or -S, not both" >&2 + exit 1 +fi + +if ! command -v tmux >/dev/null 2>&1; then + echo "tmux not found in PATH" >&2 + exit 1 +fi + +list_sessions() { + local label="$1"; shift + local tmux_cmd=(tmux "$@") + + if ! sessions="$("${tmux_cmd[@]}" list-sessions -F '#{session_name}\t#{session_attached}\t#{session_created_string}' 2>/dev/null)"; then + echo "No tmux server found on $label" >&2 + return 1 + fi + + if [[ -n "$query" ]]; then + sessions="$(printf '%s\n' "$sessions" | grep -i -- "$query" || true)" + fi + + if [[ -z "$sessions" ]]; then + echo "No sessions found on $label" + return 0 + fi + + echo "Sessions on $label:" + printf '%s\n' "$sessions" | while IFS=$'\t' read -r name attached created; do + attached_label=$([[ "$attached" == "1" ]] && echo "attached" || echo "detached") + printf ' - %s (%s, started %s)\n' "$name" "$attached_label" "$created" + done +} + +if [[ "$scan_all" == true ]]; then + if [[ ! -d "$socket_dir" ]]; then + echo "Socket directory not found: $socket_dir" >&2 + exit 1 + fi + + shopt -s nullglob + sockets=("$socket_dir"/*) + shopt -u nullglob + + if [[ "${#sockets[@]}" -eq 0 ]]; then + echo "No sockets found under $socket_dir" >&2 + exit 1 + fi + + exit_code=0 + for sock in "${sockets[@]}"; do + if [[ ! -S "$sock" ]]; then + continue + fi + list_sessions "socket path '$sock'" -S "$sock" || exit_code=$? + done + exit "$exit_code" +fi + +tmux_cmd=(tmux) +socket_label="default socket" + +if [[ -n "$socket_name" ]]; then + tmux_cmd+=(-L "$socket_name") + socket_label="socket name '$socket_name'" +elif [[ -n "$socket_path" ]]; then + tmux_cmd+=(-S "$socket_path") + socket_label="socket path '$socket_path'" +fi + +list_sessions "$socket_label" "${tmux_cmd[@]:1}" diff --git a/workspace/skills/tmux/scripts/wait-for-text.sh b/workspace/skills/tmux/scripts/wait-for-text.sh new file mode 100755 index 00000000..56354be8 --- /dev/null +++ b/workspace/skills/tmux/scripts/wait-for-text.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +set -euo pipefail + +usage() { + cat <<'USAGE' +Usage: wait-for-text.sh -t target -p pattern [options] + +Poll a tmux pane for text and exit when found. + +Options: + -t, --target tmux target (session:window.pane), required + -p, --pattern regex pattern to look for, required + -F, --fixed treat pattern as a fixed string (grep -F) + -T, --timeout seconds to wait (integer, default: 15) + -i, --interval poll interval in seconds (default: 0.5) + -l, --lines number of history lines to inspect (integer, default: 1000) + -h, --help show this help +USAGE +} + +target="" +pattern="" +grep_flag="-E" +timeout=15 +interval=0.5 +lines=1000 + +while [[ $# -gt 0 ]]; do + case "$1" in + -t|--target) target="${2-}"; shift 2 ;; + -p|--pattern) pattern="${2-}"; shift 2 ;; + -F|--fixed) grep_flag="-F"; shift ;; + -T|--timeout) timeout="${2-}"; shift 2 ;; + -i|--interval) interval="${2-}"; shift 2 ;; + -l|--lines) lines="${2-}"; shift 2 ;; + -h|--help) usage; exit 0 ;; + *) echo "Unknown option: $1" >&2; usage; exit 1 ;; + esac +done + +if [[ -z "$target" || -z "$pattern" ]]; then + echo "target and pattern are required" >&2 + usage + exit 1 +fi + +if ! [[ "$timeout" =~ ^[0-9]+$ ]]; then + echo "timeout must be an integer number of seconds" >&2 + exit 1 +fi + +if ! [[ "$lines" =~ ^[0-9]+$ ]]; then + echo "lines must be an integer" >&2 + exit 1 +fi + +if ! command -v tmux >/dev/null 2>&1; then + echo "tmux not found in PATH" >&2 + exit 1 +fi + +# End time in epoch seconds (integer, good enough for polling) +start_epoch=$(date +%s) +deadline=$((start_epoch + timeout)) + +while true; do + # -J joins wrapped lines, -S uses negative index to read last N lines + pane_text="$(tmux capture-pane -p -J -t "$target" -S "-${lines}" 2>/dev/null || true)" + + if printf '%s\n' "$pane_text" | grep $grep_flag -- "$pattern" >/dev/null 2>&1; then + exit 0 + fi + + now=$(date +%s) + if (( now >= deadline )); then + echo "Timed out after ${timeout}s waiting for pattern: $pattern" >&2 + echo "Last ${lines} lines from $target:" >&2 + printf '%s\n' "$pane_text" >&2 + exit 1 + fi + + sleep "$interval" +done diff --git a/workspace/skills/weather/SKILL.md b/workspace/skills/weather/SKILL.md new file mode 100644 index 00000000..6a5a0569 --- /dev/null +++ b/workspace/skills/weather/SKILL.md @@ -0,0 +1,49 @@ +--- +name: weather +description: Get current weather and forecasts (no API key required). +homepage: https://wttr.in/:help +metadata: {"internclaw":{"emoji":"🌤️","requires":{"bins":["curl"]}}} +--- + +# Weather + +Two free services, no API keys needed. + +## wttr.in (primary) + +Quick one-liner: +```bash +curl -s "wttr.in/London?format=3" +# Output: London: ⛅️ +8°C +``` + +Compact format: +```bash +curl -s "wttr.in/London?format=%l:+%c+%t+%h+%w" +# Output: London: ⛅️ +8°C 71% ↙5km/h +``` + +Full forecast: +```bash +curl -s "wttr.in/London?T" +``` + +Format codes: `%c` condition · `%t` temp · `%h` humidity · `%w` wind · `%l` location · `%m` moon + +Tips: +- URL-encode spaces: `wttr.in/New+York` +- Airport codes: `wttr.in/JFK` +- Units: `?m` (metric) `?u` (USCS) +- Today only: `?1` · Current only: `?0` +- PNG: `curl -s "wttr.in/Berlin.png" -o /tmp/weather.png` + +## Open-Meteo (fallback, JSON) + +Free, no key, good for programmatic use: +```bash +curl -s "https://api.open-meteo.com/v1/forecast?latitude=51.5&longitude=-0.12¤t_weather=true" +``` + +Find coordinates for a city, then query. Returns JSON with temp, windspeed, weathercode. + +Docs: https://open-meteo.com/en/docs