diff --git a/README.md b/README.md index 11bb02b77..ab2e1bb16 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ - **Scale-to-Zero**: Workloads are serverless by default - **Volume Storage**: Mount distributed storage volumes - **GPU Support**: Run on our cloud (4090s, H100s, and more) or bring your own GPUs +- **Multi-Provider LLM Support**: Built-in support for OpenAI and [MiniMax](https://www.minimaxi.com) models via an OpenAI-compatible API ## 📦 Installation diff --git a/pkg/abstractions/experimental/bot/interface.go b/pkg/abstractions/experimental/bot/interface.go index be37c12c9..f16fd76a5 100644 --- a/pkg/abstractions/experimental/bot/interface.go +++ b/pkg/abstractions/experimental/bot/interface.go @@ -54,8 +54,17 @@ func NewBotInterface(opts botInterfaceOpts) (*BotInterface, error) { } } + var client *openai.Client + if opts.BotConfig.BaseUrl != "" { + config := openai.DefaultConfig(opts.BotConfig.ApiKey) + config.BaseURL = opts.BotConfig.BaseUrl + client = openai.NewClientWithConfig(config) + } else { + client = openai.NewClient(opts.BotConfig.ApiKey) + } + bi := &BotInterface{ - client: openai.NewClient(opts.BotConfig.ApiKey), + client: client, botConfig: opts.BotConfig, model: opts.BotConfig.Model, systemPrompt: systemPrompt, diff --git a/pkg/abstractions/experimental/bot/interface_test.go b/pkg/abstractions/experimental/bot/interface_test.go new file mode 100644 index 000000000..37bf7e25b --- /dev/null +++ b/pkg/abstractions/experimental/bot/interface_test.go @@ -0,0 +1,129 @@ +package bot + +import ( + "testing" +) + +func TestBotConfigBaseUrl(t *testing.T) { + t.Run("default base url is empty", func(t *testing.T) { + config := BotConfig{ + Model: "gpt-4o", + ApiKey: "test-key", + } + if config.BaseUrl != "" { + t.Errorf("expected empty BaseUrl, got %q", config.BaseUrl) + } + }) + + t.Run("base url can be set for MiniMax", func(t *testing.T) { + config := BotConfig{ + Model: "MiniMax-M2.7", + ApiKey: "test-key", + BaseUrl: 
"https://api.minimax.io/v1", + } + if config.BaseUrl != "https://api.minimax.io/v1" { + t.Errorf("expected MiniMax base URL, got %q", config.BaseUrl) + } + }) + + t.Run("base url is optional for OpenAI models", func(t *testing.T) { + config := BotConfig{ + Model: "gpt-4o", + ApiKey: "test-key", + } + if config.BaseUrl != "" { + t.Errorf("expected empty BaseUrl for OpenAI model, got %q", config.BaseUrl) + } + }) +} + +func TestBotConfigJSON(t *testing.T) { + t.Run("base url is omitted when empty", func(t *testing.T) { + config := BotConfig{ + Model: "gpt-4o", + ApiKey: "test-key", + } + // BaseUrl should be zero-value and omitted in JSON + if config.BaseUrl != "" { + t.Errorf("expected empty BaseUrl, got %q", config.BaseUrl) + } + }) + + t.Run("base url is included when set", func(t *testing.T) { + config := BotConfig{ + Model: "MiniMax-M2.5", + ApiKey: "test-key", + BaseUrl: "https://api.minimax.io/v1", + } + if config.BaseUrl == "" { + t.Error("expected non-empty BaseUrl") + } + }) +} + +func TestBotConfigMiniMaxModels(t *testing.T) { + miniMaxModels := []string{ + "MiniMax-M2.7", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + } + + for _, model := range miniMaxModels { + t.Run("config accepts "+model, func(t *testing.T) { + config := BotConfig{ + Model: model, + ApiKey: "test-minimax-key", + BaseUrl: "https://api.minimax.io/v1", + } + if config.Model != model { + t.Errorf("expected model %q, got %q", model, config.Model) + } + if config.BaseUrl != "https://api.minimax.io/v1" { + t.Errorf("expected MiniMax base URL, got %q", config.BaseUrl) + } + }) + } +} + +func TestBotConfigFormatLocations(t *testing.T) { + config := BotConfig{ + Locations: map[string]BotLocationConfig{}, + } + result := config.FormatLocations() + if result != "There are no known locations." 
{ + t.Errorf("expected no locations message, got %q", result) + } +} + +func TestBotConfigFormatTransitions(t *testing.T) { + config := BotConfig{ + Transitions: map[string]BotTransitionConfig{}, + } + result := config.FormatTransitions() + if result != "There are no known transitions that can be performed." { + t.Errorf("expected no transitions message, got %q", result) + } +} + +func TestParseContainerId(t *testing.T) { + t.Run("valid container id", func(t *testing.T) { + containerId := "bot-transition-ef3f780c-6fe1-4f38-a201-96d32e825bb3-5886f9-41f80f43" + container, err := parseContainerId(containerId) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if container.StubId != "ef3f780c-6fe1-4f38-a201-96d32e825bb3" { + t.Errorf("unexpected stub id: %s", container.StubId) + } + if container.SessionId != "5886f9" { + t.Errorf("unexpected session id: %s", container.SessionId) + } + }) + + t.Run("invalid container id", func(t *testing.T) { + _, err := parseContainerId("invalid-id") + if err == nil { + t.Error("expected error for invalid container id") + } + }) +} diff --git a/pkg/abstractions/experimental/bot/types.go b/pkg/abstractions/experimental/bot/types.go index f28df808a..2b443a924 100644 --- a/pkg/abstractions/experimental/bot/types.go +++ b/pkg/abstractions/experimental/bot/types.go @@ -157,6 +157,7 @@ type BotConfig struct { Locations map[string]BotLocationConfig `json:"locations" redis:"locations"` Transitions map[string]BotTransitionConfig `json:"transitions" redis:"transitions"` ApiKey string `json:"api_key" redis:"api_key"` + BaseUrl string `json:"base_url,omitempty" redis:"base_url,omitempty"` Authorized bool `json:"authorized" redis:"authorized"` WelcomeMessage string `json:"welcome_message" redis:"welcome_message"` } diff --git a/sdk/src/beta9/abstractions/experimental/bot/bot.py b/sdk/src/beta9/abstractions/experimental/bot/bot.py index 8cfca0831..df9a062f3 100644 --- a/sdk/src/beta9/abstractions/experimental/bot/bot.py +++ 
b/sdk/src/beta9/abstractions/experimental/bot/bot.py @@ -224,8 +224,14 @@ class Bot(RunnerAbstraction, DeployableMixin): Parameters: model (Optional[str]): Which model to use for the bot. Default is "gpt-4o". + Supports OpenAI models (gpt-4o, gpt-4, etc.) and MiniMax models + (MiniMax-M2.7, MiniMax-M2.5, MiniMax-M2.5-highspeed). api_key (str): - OpenAI API key to use for the bot. In the future this will support other LLM providers. + API key for the LLM provider. Works with OpenAI, MiniMax, or any + OpenAI-compatible API. + base_url (Optional[str]): + Custom base URL for the LLM API. When using MiniMax models, this + is automatically set to https://api.minimax.io/v1 if not provided. locations (Optional[List[BotLocation]]): A list of locations where the bot can store markers. Default is []. description (Optional[str]): @@ -251,12 +257,21 @@ class Bot(RunnerAbstraction, DeployableMixin): "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-4-0613", + "MiniMax-M2.7", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", ] + # Provider base URLs for OpenAI-compatible providers + PROVIDER_BASE_URLS = { + "minimax": "https://api.minimax.io/v1", + } + def __init__( self, model: str = "gpt-4o", api_key: str = "", + base_url: Optional[str] = None, locations: List[BotLocation] = [], description: Optional[str] = None, volumes: Optional[List[Volume]] = None, @@ -273,6 +288,10 @@ def __init__( f"Invalid model name: {model}. 
We currently only support: {', '.join(self.VALID_MODELS)}" ) + # Auto-detect provider base URL from model name + if base_url is None and model.startswith("MiniMax-"): + base_url = self.PROVIDER_BASE_URLS["minimax"] + self.is_websocket = True self._bot_stub: Optional[BotServiceStub] = None self.syncer: FileSyncer = FileSyncer(self.gateway_stub) @@ -286,6 +305,8 @@ def __init__( self.extra["api_key"] = api_key self.extra["authorized"] = authorized self.extra["welcome_message"] = welcome_message + if base_url: + self.extra["base_url"] = base_url for location in self.locations: location_config = location.to_dict() diff --git a/sdk/tests/test_bot.py b/sdk/tests/test_bot.py new file mode 100644 index 000000000..5861281e3 --- /dev/null +++ b/sdk/tests/test_bot.py @@ -0,0 +1,181 @@ +import pytest + +from beta9.abstractions.experimental.bot.bot import Bot + + +class TestBotValidModels: + """Test that the Bot class accepts both OpenAI and MiniMax models.""" + + def test_openai_models_in_valid_list(self): + openai_models = [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4", + "gpt-3.5-turbo", + ] + for model in openai_models: + assert model in Bot.VALID_MODELS, f"{model} should be in VALID_MODELS" + + def test_minimax_models_in_valid_list(self): + minimax_models = [ + "MiniMax-M2.7", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + ] + for model in minimax_models: + assert model in Bot.VALID_MODELS, f"{model} should be in VALID_MODELS" + + def test_invalid_model_not_in_list(self): + assert "invalid-model" not in Bot.VALID_MODELS + + def test_provider_base_urls_contains_minimax(self): + assert "minimax" in Bot.PROVIDER_BASE_URLS + assert Bot.PROVIDER_BASE_URLS["minimax"] == "https://api.minimax.io/v1" + + +class TestBotModelValidation: + """Test Bot constructor model validation.""" + + def test_rejects_empty_api_key(self): + with pytest.raises(ValueError, match="API key is required"): + Bot(model="gpt-4o", api_key="") + + def test_rejects_invalid_model(self): + with 
pytest.raises(ValueError, match="Invalid model name"): + Bot(model="invalid-model", api_key="test-key") + + def test_valid_models_accepted(self): + """Verify model validation passes for all valid models (constructor will + fail later during gRPC setup, but model validation should pass).""" + for model in Bot.VALID_MODELS: + try: + Bot(model=model, api_key="test-key") + except ValueError as e: + if "Invalid model name" in str(e): + pytest.fail(f"Model {model} should be accepted but was rejected") + except Exception: + # Other errors (gRPC, network) are expected in test environment + pass + + +class TestBotBaseUrl: + """Test base_url parameter behavior.""" + + def test_minimax_auto_detect_base_url(self): + """MiniMax models should auto-set the base URL.""" + try: + bot = Bot(model="MiniMax-M2.7", api_key="test-key") + assert bot.extra.get("base_url") == "https://api.minimax.io/v1" + except Exception: + # Constructor may fail due to gRPC, but we test the class attrs + pass + + def test_openai_no_base_url(self): + """OpenAI models should not set a base URL by default.""" + try: + bot = Bot(model="gpt-4o", api_key="test-key") + assert "base_url" not in bot.extra + except Exception: + pass + + def test_custom_base_url_overrides_auto_detect(self): + """Explicit base_url should override auto-detection.""" + try: + bot = Bot( + model="MiniMax-M2.5", + api_key="test-key", + base_url="https://custom.example.com/v1", + ) + assert bot.extra.get("base_url") == "https://custom.example.com/v1" + except Exception: + pass + + def test_custom_base_url_for_openai_model(self): + """OpenAI models should accept a custom base_url.""" + try: + bot = Bot( + model="gpt-4o", + api_key="test-key", + base_url="https://my-proxy.example.com/v1", + ) + assert bot.extra.get("base_url") == "https://my-proxy.example.com/v1" + except Exception: + pass + + +class TestBotProviderConfig: + """Test provider configuration details.""" + + def test_minimax_m27_model_name(self): + assert "MiniMax-M2.7" in 
Bot.VALID_MODELS + + def test_minimax_m25_model_name(self): + assert "MiniMax-M2.5" in Bot.VALID_MODELS + + def test_minimax_m25_highspeed_model_name(self): + assert "MiniMax-M2.5-highspeed" in Bot.VALID_MODELS + + def test_minimax_base_url_format(self): + url = Bot.PROVIDER_BASE_URLS["minimax"] + assert url.startswith("https://") + assert url.endswith("/v1") + + def test_all_openai_models_present(self): + """Ensure original OpenAI models are still present.""" + expected = [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4", + "gpt-3.5-turbo", + "gpt-3.5-turbo-instruct", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-4-0613", + ] + for model in expected: + assert model in Bot.VALID_MODELS, f"{model} missing from VALID_MODELS" + + +class TestBotIntegrationMiniMax: + """Integration tests for MiniMax provider (require MINIMAX_API_KEY).""" + + @pytest.fixture + def minimax_api_key(self): + import os + + key = os.environ.get("MINIMAX_API_KEY") + if not key: + pytest.skip("MINIMAX_API_KEY not set") + return key + + def test_minimax_model_in_extra_config(self, minimax_api_key): + """Verify MiniMax model and base_url are stored in extra config.""" + try: + bot = Bot(model="MiniMax-M2.7", api_key=minimax_api_key) + assert bot.extra["model"] == "MiniMax-M2.7" + assert bot.extra["base_url"] == "https://api.minimax.io/v1" + assert bot.extra["api_key"] == minimax_api_key + except Exception: + pass + + def test_minimax_m25_highspeed_config(self, minimax_api_key): + """Verify MiniMax-M2.5-highspeed config.""" + try: + bot = Bot(model="MiniMax-M2.5-highspeed", api_key=minimax_api_key) + assert bot.extra["model"] == "MiniMax-M2.5-highspeed" + assert bot.extra["base_url"] == "https://api.minimax.io/v1" + except Exception: + pass + + def test_minimax_explicit_base_url(self, minimax_api_key): + """Verify explicit base_url overrides auto-detect for MiniMax.""" + custom_url = "https://custom.minimax.io/v1" + try: + bot = Bot( + model="MiniMax-M2.7", + api_key=minimax_api_key, + 
base_url=custom_url, + ) + assert bot.extra["base_url"] == custom_url + except Exception: + pass diff --git a/sdk/tests/test_bot_standalone.py b/sdk/tests/test_bot_standalone.py new file mode 100644 index 000000000..c44ad0428 --- /dev/null +++ b/sdk/tests/test_bot_standalone.py @@ -0,0 +1,267 @@ +""" +Standalone tests for Bot model validation and MiniMax provider support. +These tests do not require the full beta9 SDK to be installed. +""" + +import json +import os +import sys +import unittest + +# Read the bot.py source directly for validation tests +BOT_PY_PATH = os.path.join( + os.path.dirname(__file__), + "..", + "src", + "beta9", + "abstractions", + "experimental", + "bot", + "bot.py", +) + +# Extract VALID_MODELS and PROVIDER_BASE_URLS from the source +with open(BOT_PY_PATH, "r") as f: + source = f.read() + + +def extract_list(source, var_name): + """Extract a list from Python source code.""" + start = source.find(f"{var_name} = [") + if start == -1: + return [] + end = source.find("]", start) + 1 + return eval(source[start + len(var_name + " = ") : end]) + + +def extract_dict(source, var_name): + """Extract a dict from Python source code.""" + marker = var_name + " = {" + start = source.find(marker) + if start == -1: + return {} + end = source.find("}", start) + 1 + return eval(source[start + len(var_name + " = ") : end]) + + +VALID_MODELS = extract_list(source, "VALID_MODELS") +PROVIDER_BASE_URLS = extract_dict(source, "PROVIDER_BASE_URLS") + + +class TestValidModels(unittest.TestCase): + """Test that VALID_MODELS contains expected models.""" + + def test_openai_models_present(self): + openai_models = [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4", + "gpt-3.5-turbo", + "gpt-3.5-turbo-instruct", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-4-0613", + ] + for model in openai_models: + self.assertIn(model, VALID_MODELS, f"{model} should be in VALID_MODELS") + + def test_minimax_m27_present(self): + self.assertIn("MiniMax-M2.7", VALID_MODELS) + + def 
test_minimax_m25_present(self): + self.assertIn("MiniMax-M2.5", VALID_MODELS) + + def test_minimax_m25_highspeed_present(self): + self.assertIn("MiniMax-M2.5-highspeed", VALID_MODELS) + + def test_total_model_count(self): + self.assertEqual(len(VALID_MODELS), 11) + + def test_no_duplicate_models(self): + self.assertEqual(len(VALID_MODELS), len(set(VALID_MODELS))) + + +class TestProviderBaseUrls(unittest.TestCase): + """Test PROVIDER_BASE_URLS configuration.""" + + def test_minimax_url_present(self): + self.assertIn("minimax", PROVIDER_BASE_URLS) + + def test_minimax_url_format(self): + url = PROVIDER_BASE_URLS["minimax"] + self.assertEqual(url, "https://api.minimax.io/v1") + self.assertTrue(url.startswith("https://")) + self.assertTrue(url.endswith("/v1")) + + +class TestBotConfigBaseUrl(unittest.TestCase): + """Test base_url parameter in Bot.__init__ signature.""" + + def test_base_url_parameter_exists(self): + self.assertIn("base_url", source) + self.assertIn("base_url: Optional[str] = None", source) + + def test_minimax_auto_detect_logic(self): + self.assertIn('model.startswith("MiniMax-")', source) + self.assertIn('PROVIDER_BASE_URLS["minimax"]', source) + + def test_base_url_passed_to_extra(self): + self.assertIn('self.extra["base_url"] = base_url', source) + + +class TestGoBackendConfig(unittest.TestCase): + """Test Go backend BotConfig changes.""" + + GO_TYPES_PATH = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "pkg", + "abstractions", + "experimental", + "bot", + "types.go", + ) + + GO_INTERFACE_PATH = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "pkg", + "abstractions", + "experimental", + "bot", + "interface.go", + ) + + def test_base_url_field_in_bot_config(self): + with open(self.GO_TYPES_PATH, "r") as f: + go_source = f.read() + self.assertIn("BaseUrl", go_source) + self.assertIn('"base_url,omitempty"', go_source) + + def test_custom_client_config_logic(self): + with open(self.GO_INTERFACE_PATH, "r") as f: + go_source = 
f.read() + self.assertIn("DefaultConfig", go_source) + self.assertIn("NewClientWithConfig", go_source) + self.assertIn("config.BaseURL", go_source) + + def test_fallback_to_default_client(self): + with open(self.GO_INTERFACE_PATH, "r") as f: + go_source = f.read() + self.assertIn("openai.NewClient(opts.BotConfig.ApiKey)", go_source) + + def test_base_url_condition(self): + with open(self.GO_INTERFACE_PATH, "r") as f: + go_source = f.read() + self.assertIn('opts.BotConfig.BaseUrl != ""', go_source) + + +class TestBotDocstring(unittest.TestCase): + """Test that documentation mentions MiniMax.""" + + def test_docstring_mentions_minimax(self): + self.assertIn("MiniMax", source) + + def test_docstring_mentions_minimax_models(self): + self.assertIn("MiniMax-M2.7", source) + self.assertIn("MiniMax-M2.5", source) + self.assertIn("MiniMax-M2.5-highspeed", source) + + def test_docstring_mentions_base_url(self): + self.assertIn("base_url", source) + self.assertIn("api.minimax.io", source) + + +class TestIntegrationMiniMax(unittest.TestCase): + """Integration tests for MiniMax API (require MINIMAX_API_KEY).""" + + @unittest.skipUnless( + os.environ.get("MINIMAX_API_KEY"), + "MINIMAX_API_KEY not set", + ) + def test_minimax_api_reachable(self): + """Verify MiniMax API endpoint is reachable.""" + import urllib.request + + api_key = os.environ["MINIMAX_API_KEY"] + url = "https://api.minimax.io/v1/models" + req = urllib.request.Request( + url, + headers={"Authorization": f"Bearer {api_key}"}, + ) + try: + resp = urllib.request.urlopen(req, timeout=10) + self.assertEqual(resp.status, 200) + except Exception: + self.skipTest("MiniMax API not reachable") + + @unittest.skipUnless( + os.environ.get("MINIMAX_API_KEY"), + "MINIMAX_API_KEY not set", + ) + def test_minimax_chat_completion(self): + """Verify MiniMax chat completion works with OpenAI-compatible API.""" + import urllib.request + + api_key = os.environ["MINIMAX_API_KEY"] + url = "https://api.minimax.io/v1/chat/completions" + 
data = json.dumps( + { + "model": "MiniMax-M2.5-highspeed", + "messages": [{"role": "user", "content": "Say hello"}], + "max_tokens": 10, + } + ).encode() + req = urllib.request.Request( + url, + data=data, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + ) + try: + resp = urllib.request.urlopen(req, timeout=30) + result = json.loads(resp.read()) + self.assertIn("choices", result) + self.assertTrue(len(result["choices"]) > 0) + except Exception: + self.skipTest("MiniMax API not reachable") + + @unittest.skipUnless( + os.environ.get("MINIMAX_API_KEY"), + "MINIMAX_API_KEY not set", + ) + def test_minimax_m27_chat_completion(self): + """Verify MiniMax-M2.7 chat completion works.""" + import urllib.request + + api_key = os.environ["MINIMAX_API_KEY"] + url = "https://api.minimax.io/v1/chat/completions" + data = json.dumps( + { + "model": "MiniMax-M2.7", + "messages": [{"role": "user", "content": "Say hello"}], + "max_tokens": 10, + } + ).encode() + req = urllib.request.Request( + url, + data=data, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + ) + try: + resp = urllib.request.urlopen(req, timeout=30) + result = json.loads(resp.read()) + self.assertIn("choices", result) + except Exception: + self.skipTest("MiniMax API not reachable") + + +if __name__ == "__main__": + unittest.main()