From b762c23346521bd281ebe229a78641c14d757f8f Mon Sep 17 00:00:00 2001 From: Nikhil Rao Date: Thu, 9 Oct 2025 10:22:58 +0530 Subject: [PATCH] refactor: Clean up browser action handling and fix f-string bug - Refactored browser actions into dictionary-based mapping for better organization - Modularized action handling functions for improved maintainability - Fixed f-string where safety['decision'] variable wasn't inside curly braces --- agent.py | 182 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 107 insertions(+), 75 deletions(-) diff --git a/agent.py b/agent.py index 2df602e..618a25d 100644 --- a/agent.py +++ b/agent.py @@ -115,84 +115,114 @@ def __init__( ], ) + self._action_handlers = { + "open_web_browser": self._handle_open_web_browser, + "click_at": self._handle_click_at, + "hover_at": self._handle_hover_at, + "type_text_at": self._handle_type_text_at, + "scroll_document": self._handle_scroll_document, + "scroll_at": self._handle_scroll_at, + "wait_5_seconds": self._handle_wait_5_seconds, + "go_back": self._handle_go_back, + "go_forward": self._handle_go_forward, + "search": self._handle_search, + "navigate": self._handle_navigate, + "key_combination": self._handle_key_combination, + "drag_and_drop": self._handle_drag_and_drop, + multiply_numbers.__name__: self._handle_multiply_numbers, + } + def handle_action(self, action: types.FunctionCall) -> FunctionResponseT: """Handles the action and returns the environment state.""" - if action.name == "open_web_browser": - return self._browser_computer.open_web_browser() - elif action.name == "click_at": - x = self.denormalize_x(action.args["x"]) - y = self.denormalize_y(action.args["y"]) - return self._browser_computer.click_at( - x=x, - y=y, - ) - elif action.name == "hover_at": - x = self.denormalize_x(action.args["x"]) - y = self.denormalize_y(action.args["y"]) - return self._browser_computer.hover_at( - x=x, - y=y, - ) - elif action.name == "type_text_at": - x = self.denormalize_x(action.args["x"]) - y = self.denormalize_y(action.args["y"]) - press_enter = action.args.get("press_enter", False) - clear_before_typing = action.args.get("clear_before_typing", True) - return self._browser_computer.type_text_at( - x=x, - y=y, - text=action.args["text"], - press_enter=press_enter, - clear_before_typing=clear_before_typing, - ) - elif action.name == "scroll_document": - return self._browser_computer.scroll_document(action.args["direction"]) - elif action.name == "scroll_at": - x = self.denormalize_x(action.args["x"]) - y = self.denormalize_y(action.args["y"]) - magnitude = action.args.get("magnitude", 800) - direction = action.args["direction"] - - if direction in ("up", "down"): - magnitude = self.denormalize_y(magnitude) - elif direction in ("left", "right"): - magnitude = self.denormalize_x(magnitude) - else: - raise ValueError("Unknown direction: ", direction) - return self._browser_computer.scroll_at( - x=x, y=y, direction=direction, magnitude=magnitude - ) - elif action.name == "wait_5_seconds": - return self._browser_computer.wait_5_seconds() - elif action.name == "go_back": - return self._browser_computer.go_back() - elif action.name == "go_forward": - return self._browser_computer.go_forward() - elif action.name == "search": - return self._browser_computer.search() - elif action.name == "navigate": - return self._browser_computer.navigate(action.args["url"]) - elif action.name == "key_combination": - return self._browser_computer.key_combination( - action.args["keys"].split("+") - ) - elif action.name == "drag_and_drop": - x = self.denormalize_x(action.args["x"]) - y = self.denormalize_y(action.args["y"]) - destination_x = self.denormalize_x(action.args["destination_x"]) - destination_y = self.denormalize_y(action.args["destination_y"]) - return self._browser_computer.drag_and_drop( - x=x, - y=y, - destination_x=destination_x, - destination_y=destination_y, - ) - # Handle the custom function declarations here. - elif action.name == multiply_numbers.__name__: - return multiply_numbers(x=action.args["x"], y=action.args["y"]) + if handler := self._action_handlers.get(action.name): + return handler(action) else: raise ValueError(f"Unsupported function: {action}") + def _handle_open_web_browser(self, action: types.FunctionCall) -> FunctionResponseT: + return self._browser_computer.open_web_browser() + + def _handle_click_at(self, action: types.FunctionCall) -> FunctionResponseT: + x = self.denormalize_x(action.args["x"]) + y = self.denormalize_y(action.args["y"]) + return self._browser_computer.click_at(x=x, y=y) + + def _handle_hover_at(self, action: types.FunctionCall) -> FunctionResponseT: + x = self.denormalize_x(action.args["x"]) + y = self.denormalize_y(action.args["y"]) + return self._browser_computer.hover_at(x=x, y=y) + + def _handle_type_text_at(self, action: types.FunctionCall) -> FunctionResponseT: + x = self.denormalize_x(action.args["x"]) + y = self.denormalize_y(action.args["y"]) + press_enter = action.args.get("press_enter", False) + clear_before_typing = action.args.get("clear_before_typing", True) + return self._browser_computer.type_text_at( + x=x, + y=y, + text=action.args["text"], + press_enter=press_enter, + clear_before_typing=clear_before_typing, + ) + + def _handle_scroll_document(self, action: types.FunctionCall) -> FunctionResponseT: + return self._browser_computer.scroll_document(action.args["direction"]) + + def _handle_scroll_at(self, action: types.FunctionCall) -> FunctionResponseT: + x = self.denormalize_x(action.args["x"]) + y = self.denormalize_y(action.args["y"]) + magnitude = action.args.get("magnitude", 800) + direction = action.args["direction"] + + if direction in ("up", "down"): + magnitude = self.denormalize_y(magnitude) + elif direction in ("left", "right"): + magnitude = self.denormalize_x(magnitude) + else: + raise ValueError(f"Unknown direction: {direction}") + return self._browser_computer.scroll_at( + x=x, y=y, direction=direction, magnitude=magnitude + ) + + def _handle_wait_5_seconds( + self, action: types.FunctionCall + ) -> FunctionResponseT: + return self._browser_computer.wait_5_seconds() + + def _handle_go_back(self, action: types.FunctionCall) -> FunctionResponseT: + return self._browser_computer.go_back() + + def _handle_go_forward(self, action: types.FunctionCall) -> FunctionResponseT: + return self._browser_computer.go_forward() + + def _handle_search(self, action: types.FunctionCall) -> FunctionResponseT: + return self._browser_computer.search() + + def _handle_navigate(self, action: types.FunctionCall) -> FunctionResponseT: + return self._browser_computer.navigate(action.args["url"]) + + def _handle_key_combination( + self, action: types.FunctionCall + ) -> FunctionResponseT: + return self._browser_computer.key_combination(action.args["keys"].split("+")) + + def _handle_drag_and_drop(self, action: types.FunctionCall) -> FunctionResponseT: + x = self.denormalize_x(action.args["x"]) + y = self.denormalize_y(action.args["y"]) + destination_x = self.denormalize_x(action.args["destination_x"]) + destination_y = self.denormalize_y(action.args["destination_y"]) + return self._browser_computer.drag_and_drop( + x=x, + y=y, + destination_x=destination_x, + destination_y=destination_y, + ) + + def _handle_multiply_numbers( + self, action: types.FunctionCall + ) -> FunctionResponseT: + return multiply_numbers(x=action.args["x"], y=action.args["y"]) + def get_model_response( self, max_retries=5, base_delay_s=1 ) -> types.GenerateContentResponse: @@ -253,11 +283,13 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]: try: response = self.get_model_response() except Exception as e: + print(e) return "COMPLETE" else: try: response = self.get_model_response() except Exception as e: + print(e) return "COMPLETE" if not response.candidates: @@ -292,7 +324,7 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]: # Print the function call and any reasoning. function_call_str = f"Name: {function_call.name}" if function_call.args: - function_call_str += f"\nArgs:" + function_call_str += "\nArgs:" for key, value in function_call.args.items(): function_call_str += f"\n {key}: {value}" function_call_strs.append(function_call_str) @@ -390,7 +422,7 @@ def _get_safety_confirmation( self, safety: dict[str, Any] ) -> Literal["CONTINUE", "TERMINATE"]: if safety["decision"] != "require_confirmation": - raise ValueError(f"Unknown safety decision: safety['decision']") + raise ValueError(f"Unknown safety decision: {safety['decision']}") termcolor.cprint( "Safety service requires explicit confirmation!", color="yellow",