diff --git a/NEWS.md b/NEWS.md index f6e5e582a..36f5134ca 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # ellmer (development version) +* `batch_chat()` now supports `chat_groq()` for batch processing via the + Groq batch API. `chat_groq()` now also supports structured data extraction + (@xmarquez, #914). * ellmer will now distinguish text content from thinking content while streaming, allowing downstream packages like shinychat to provide specific UI for thinking content (@simonpcouch, #909). * `chat_github()` now uses `chat_openai_compatible()` for improved compatibility, and `models_github()` now supports custom `base_url` configuration (@D-M4rk, #877). * `chat_ollama()` now contains a slot for `top_k` within the `params` argument (@frankiethull). diff --git a/R/batch-chat.R b/R/batch-chat.R index 4a4b5084c..affebd6e1 100644 --- a/R/batch-chat.R +++ b/R/batch-chat.R @@ -2,10 +2,11 @@ #' #' @description #' `batch_chat()` and `batch_chat_structured()` currently only work with -#' [chat_openai()] and [chat_anthropic()]. They use the -#' [OpenAI](https://platform.openai.com/docs/guides/batch) and -#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing) -#' batch APIs which allow you to submit multiple requests simultaneously. +#' [chat_openai()], [chat_anthropic()], and [chat_groq()]. They use the +#' [OpenAI](https://platform.openai.com/docs/guides/batch), +#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing), +#' and [Groq](https://console.groq.com/docs/batch) batch APIs which allow +#' you to submit multiple requests simultaneously. #' The results can take up to 24 hours to complete, but in return you pay 50% #' less than usual (but note that ellmer doesn't include this discount in #' its pricing metadata). If you want to get results back more quickly, or diff --git a/R/provider-groq.R b/R/provider-groq.R index c8065414c..0567789df 100644 --- a/R/provider-groq.R +++ b/R/provider-groq.R @@ -8,10 +8,6 @@ NULL #' #' Built on top of [chat_openai_compatible()]. #' -#' ## Known limitations -#' -#' groq does not currently support structured data extraction. -#' #' @export #' @family chatbots #' @param api_key `r lifecycle::badge("deprecated")` Use `credentials` instead. @@ -69,10 +65,16 @@ method(as_json, list(ProviderGroq, Turn)) <- function(provider, x, ...) { is_tool <- map_lgl(x@contents, is_tool_request) tool_calls <- as_json(provider, x@contents[is_tool], ...) - # Grok contents is just a string. Hopefully it never sends back more - # than a single text response. - if (any(!is_tool)) { - content <- x@contents[!is_tool][[1]]@text + # Groq assistant content is just a string (not a list of content parts). + non_tool <- x@contents[!is_tool] + if (length(non_tool) > 0) { + first <- non_tool[[1]] + if (S7_inherits(first, ContentJson)) { + content <- first@string %||% + unclass(jsonlite::toJSON(first@data, auto_unbox = TRUE)) + } else { + content <- first@text + } } else { content <- NULL } @@ -95,12 +97,23 @@ method(as_json, list(ProviderGroq, TypeObject)) <- function(provider, x, ...) { } required <- map_lgl(x@properties, function(prop) prop@required) - compact(list( + schema <- compact(list( type = "object", description = x@description, properties = as_json(provider, x@properties, ...), - required = as.list(names2(x@properties)[required]) + required = as.list(names2(x@properties)[required]), + additionalProperties = FALSE + )) + add_additional_properties_false(schema) +} + +method(as_json, list(ProviderGroq, TypeArray)) <- function(provider, x, ...) { + schema <- compact(list( + type = "array", + description = x@description, + items = as_json(provider, x@items, ...) )) + add_additional_properties_false(schema) } method(as_json, list(ProviderGroq, ToolDef)) <- function(provider, x, ...) { @@ -117,3 +130,166 @@ method(as_json, list(ProviderGroq, ToolDef)) <- function(provider, x, ...) { groq_key <- function() { key_get("GROQ_API_KEY") } + +# Structured output helpers ------------------------------------------------ + +#' Recursively add additionalProperties: false to all objects +#' +#' Groq requires `additionalProperties: false` on all objects for strict JSON +#' validation. This helper ensures all nested objects have this property set. +#' @param node A list representing a JSON schema node +#' @return The modified node with additionalProperties: false on all objects +#' @noRd +add_additional_properties_false <- function(node) { + if (is.list(node) && !is.null(node$type) && identical(node$type, "object")) { + node$additionalProperties <- FALSE + if (!is.null(node$properties) && is.list(node$properties)) { + node$properties <- lapply( + node$properties, + add_additional_properties_false + ) + } + } + if (is.list(node) && !is.null(node$items)) { + node$items <- add_additional_properties_false(node$items) + } + node +} + +# Batched requests --------------------------------------------------------- + +# https://console.groq.com/docs/batch +method(has_batch_support, ProviderGroq) <- function(provider) { + TRUE +} + +method(batch_submit, ProviderGroq) <- function( + provider, + conversations, + type = NULL +) { + path <- withr::local_tempfile(fileext = ".jsonl") + + requests <- map(seq_along(conversations), function(i) { + body <- chat_body( + provider, + stream = FALSE, + turns = conversations[[i]], + type = type + ) + list( + custom_id = paste0("chat-", i), + method = "POST", + url = "/v1/chat/completions", + body = body + ) + }) + json <- map_chr(requests, to_json) + writeLines(json, path) + + uploaded <- groq_upload_file(provider, path) + + req <- base_request(provider) + req <- req_url_path_append(req, "/batches") + req <- req_body_json( + req, + list( + input_file_id = uploaded$id, + endpoint = "/v1/chat/completions", + completion_window = "24h" + ) + ) + + resp <- req_perform(req) + resp_body_json(resp) +} + +method(batch_poll, ProviderGroq) <- function(provider, batch) { + req <- base_request(provider) + req <- req_url_path_append(req, "/batches/", batch$id) + + resp <- req_perform(req) + resp_body_json(resp) +} + +method(batch_status, ProviderGroq) <- function(provider, batch) { + terminal_states <- c("completed", "failed", "expired", "cancelled") + + total <- batch$request_counts$total %||% 0L + completed <- batch$request_counts$completed %||% 0L + failed <- batch$request_counts$failed %||% 0L + + list( + working = !(batch$status %in% terminal_states), + n_processing = max(total - completed - failed, 0L), + n_succeeded = completed, + n_failed = failed + ) +} + +method(batch_retrieve, ProviderGroq) <- function(provider, batch) { + json <- list() + + if (length(batch$output_file_id) == 1 && nzchar(batch$output_file_id)) { + path_output <- withr::local_tempfile() + groq_download_file(provider, batch$output_file_id, path_output) + json <- read_ndjson(path_output, fallback = groq_json_fallback) + } + + if (length(batch$error_file_id) == 1 && nzchar(batch$error_file_id)) { + path_error <- withr::local_tempfile() + groq_download_file(provider, batch$error_file_id, path_error) + json <- c(json, read_ndjson(path_error, fallback = groq_json_fallback)) + } + + ids <- as.numeric(gsub("chat-", "", map_chr(json, "[[", "custom_id"))) + results <- lapply(json, "[[", "response") + results[order(ids)] +} + +method(batch_result_turn, ProviderGroq) <- function( + provider, + result, + has_type = FALSE +) { + if (result$status_code == 200) { + value_turn(provider, result$body, has_type = has_type) + } else { + NULL + } +} + +# Batch helpers ------------------------------------------------------------ + +#' @noRd +groq_upload_file <- function(provider, path, purpose = "batch") { + req <- base_request(provider) + req <- req_url_path_append(req, "/files") + req <- req_body_multipart( + req, + purpose = purpose, + file = curl::form_file(path) + ) + req <- req_progress(req, "up") + + resp <- req_perform(req) + resp_body_json(resp) +} + +#' @noRd +groq_download_file <- function(provider, id, path) { + req <- base_request(provider) + req <- req_url_path_append(req, "/files/", id, "/content") + req <- req_progress(req, "down") + req_perform(req, path = path) + + invisible(path) +} + +#' @noRd +groq_json_fallback <- function(line) { + list( + custom_id = extract_custom_id(line), + response = list(status_code = 500) + ) +} diff --git a/man/batch_chat.Rd b/man/batch_chat.Rd index e6945267a..8bbf6b0f0 100644 --- a/man/batch_chat.Rd +++ b/man/batch_chat.Rd @@ -75,10 +75,11 @@ is not complete. } \description{ \code{batch_chat()} and \code{batch_chat_structured()} currently only work with -\code{\link[=chat_openai]{chat_openai()}} and \code{\link[=chat_anthropic]{chat_anthropic()}}. They use the -\href{https://platform.openai.com/docs/guides/batch}{OpenAI} and -\href{https://docs.claude.com/en/docs/build-with-claude/batch-processing}{Anthropic} -batch APIs which allow you to submit multiple requests simultaneously. +\code{\link[=chat_openai]{chat_openai()}}, \code{\link[=chat_anthropic]{chat_anthropic()}}, and \code{\link[=chat_groq]{chat_groq()}}. They use the +\href{https://platform.openai.com/docs/guides/batch}{OpenAI}, +\href{https://docs.claude.com/en/docs/build-with-claude/batch-processing}{Anthropic}, +and \href{https://console.groq.com/docs/batch}{Groq} batch APIs which allow +you to submit multiple requests simultaneously. The results can take up to 24 hours to complete, but in return you pay 50\% less than usual (but note that ellmer doesn't include this discount in its pricing metadata). If you want to get results back more quickly, or diff --git a/man/chat_groq.Rd b/man/chat_groq.Rd index 7f00b2513..72a6ba94e 100644 --- a/man/chat_groq.Rd +++ b/man/chat_groq.Rd @@ -58,10 +58,6 @@ A \link{Chat} object. Sign up at \url{https://groq.com}. Built on top of \code{\link[=chat_openai_compatible]{chat_openai_compatible()}}. -\subsection{Known limitations}{ - -groq does not currently support structured data extraction. -} } \examples{ \dontrun{ diff --git a/tests/testthat/batch/state-capitals-groq.json b/tests/testthat/batch/state-capitals-groq.json new file mode 100644 index 000000000..a321df332 --- /dev/null +++ b/tests/testthat/batch/state-capitals-groq.json @@ -0,0 +1,187 @@ +{ + "version": 1, + "stage": "done", + "batch": { + "id": "batch_01kheqkczbf7cr7mbc8k6x5crb", + "object": "batch", + "endpoint": "/v1/chat/completions", + "input_file_id": "file_01kheqkckset8r4m6gjbk0j4jg", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file_01kheqkfkwf2gber5w7w53xhmn", + "error_file_id": {}, + "created_at": 1771094782, + "in_progress_at": 1771094785, + "finalizing_at": 1771094785, + "completed_at": 1771094785, + "failed_at": {}, + "cancelling_at": {}, + "cancelled_at": {}, + "expires_at": 1771181182, + "expired_at": {}, + "request_counts": { + "completed": 4, + "failed": 0, + "total": 4 + }, + "metadata": {}, + "errors": {}, + "project_id": "project_01jtkfs8evfhctdnt5dc12rh0j" + }, + "results": [ + { + "status_code": 200, + "request_id": "req_01kheqkfecf7p9d9vxd906jch0", + "body": { + "id": "chatcmpl-49373628-eadd-44e5-b8a6-0df55121679e", + "object": "chat.completion", + "created": 1771094785, + "model": "llama-3.1-8b-instant", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Des Moines" + }, + "logprobs": {}, + "finish_reason": "stop" + } + ], + "usage": { + "queue_time": 0.0624, + "prompt_tokens": 48, + "prompt_time": 0.0038, + "completion_tokens": 3, + "completion_time": 0.0044, + "total_tokens": 51, + "total_time": 0.0082 + }, + "usage_breakdown": {}, + "system_fingerprint": "fp_8f8420ecd7", + "x_groq": { + "id": "req_01kheqkfecf7p9d9vxd906jch0", + "seed": 1014 + }, + "service_tier": "batch" + } + }, + { + "status_code": 200, + "request_id": "req_01kheqkfe9fp1b4mye6a1k3b05", + "body": { + "id": "chatcmpl-8b7db3eb-f7fa-4138-97b5-2ce05a77e578", + "object": "chat.completion", + "created": 1771094785, + "model": "llama-3.1-8b-instant", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Albany" + }, + "logprobs": {}, + "finish_reason": "stop" + } + ], + "usage": { + "queue_time": 0.0619, + "prompt_tokens": 49, + "prompt_time": 0.0039, + "completion_tokens": 4, + "completion_time": 0.0061, + "total_tokens": 53, + "total_time": 0.01 + }, + "usage_breakdown": {}, + "system_fingerprint": "fp_8f8420ecd7", + "x_groq": { + "id": "req_01kheqkfe9fp1b4mye6a1k3b05", + "seed": 1014 + }, + "service_tier": "batch" + } + }, + { + "status_code": 200, + "request_id": "req_01kheqkff8f7q99dzdre1mbxtm", + "body": { + "id": "chatcmpl-e21e4735-0781-4fd1-bc07-1ccc8787ab0d", + "object": "chat.completion", + "created": 1771094785, + "model": "llama-3.1-8b-instant", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Sacramento" + }, + "logprobs": {}, + "finish_reason": "stop" + } + ], + "usage": { + "queue_time": 0.0599, + "prompt_tokens": 48, + "prompt_time": 0.0026, + "completion_tokens": 3, + "completion_time": 0.0044, + "total_tokens": 51, + "total_time": 0.007 + }, + "usage_breakdown": {}, + "system_fingerprint": "fp_d317489708", + "x_groq": { + "id": "req_01kheqkff8f7q99dzdre1mbxtm", + "seed": 1014 + }, + "service_tier": "batch" + } + }, + { + "status_code": 200, + "request_id": "req_01kheqkff8f7psfwbp0qqwvnj2", + "body": { + "id": "chatcmpl-4ca7a4d9-5186-4b6a-a277-361a52f0b5da", + "object": "chat.completion", + "created": 1771094785, + "model": "llama-3.1-8b-instant", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Austin" + }, + "logprobs": {}, + "finish_reason": "stop" + } + ], + "usage": { + "queue_time": 0.0619, + "prompt_tokens": 48, + "prompt_time": 0.0055, + "completion_tokens": 2, + "completion_time": 0.0036, + "total_tokens": 50, + "total_time": 0.0091 + }, + "usage_breakdown": {}, + "system_fingerprint": "fp_d317489708", + "x_groq": { + "id": "req_01kheqkff8f7psfwbp0qqwvnj2", + "seed": 1014 + }, + "service_tier": "batch" + } + } + ], + "started_at": 1771094782, + "hash": { + "provider": "15e4e540bae33450eee06e09bf585198", + "prompts": "b8eafe281e3cc5113058d9722be3e295", + "user_turns": "d6990a1b8a9f5db0e97de86c2669de44" + } +} diff --git a/tests/testthat/test-provider-groq-batch.R b/tests/testthat/test-provider-groq-batch.R new file mode 100644 index 000000000..d53196717 --- /dev/null +++ b/tests/testthat/test-provider-groq-batch.R @@ -0,0 +1,291 @@ +# Helper to create a dummy provider without needing real credentials +dummy_groq_provider <- function() { + ProviderGroq( + name = "Groq", + base_url = "https://api.groq.com/openai/v1", + model = "llama-3.1-8b-instant", + params = params(), + extra_args = list(), + extra_headers = character(), + credentials = NULL + ) +} + +# Groq batch helper functions ----------------------------------------------- + +test_that("groq_json_fallback extracts custom_id from malformed line", { + result <- groq_json_fallback('{"custom_id": "chat-3", broken json...') + + expect_equal(result$custom_id, "chat-3") + expect_equal(result$response$status_code, 500) +}) + +test_that("groq_json_fallback returns NA for unparseable line", { + result <- groq_json_fallback("completely broken") + + expect_true(is.na(result$custom_id)) + expect_equal(result$response$status_code, 500) +}) + +# Turn serialization ------------------------------------------------------- + +test_that("as_json(Turn) handles ContentJson in assistant turns", { + provider <- dummy_groq_provider() + + # ContentJson with data + turn_json <- AssistantTurn(list(ContentJson(data = list(answer = "4")))) + result <- as_json(provider, turn_json) + expect_equal(result[[1]]$role, "assistant") + expect_equal(result[[1]]$content, '{"answer":"4"}') + + # ContentJson with string + turn_str <- AssistantTurn(list(ContentJson(data = NULL, string = '{"x":1}'))) + result2 <- as_json(provider, turn_str) + expect_equal(result2[[1]]$content, '{"x":1}') + + # ContentText still works + turn_text <- AssistantTurn(list(ContentText("hello"))) + result3 <- as_json(provider, turn_text) + expect_equal(result3[[1]]$content, "hello") +}) + +# Schema generation -------------------------------------------------------- + +test_that("as_json(TypeObject) adds additionalProperties: false", { + provider <- dummy_groq_provider() + + type_obj <- type_object( + name = type_string(), + age = type_integer() + ) + schema <- as_json(provider, type_obj) + + expect_equal(schema$type, "object") + expect_equal(schema$additionalProperties, FALSE) + expect_true("name" %in% names(schema$properties)) + expect_true("age" %in% names(schema$properties)) +}) + +test_that("nested objects get recursive additionalProperties: false", { + provider <- dummy_groq_provider() + + type_nested <- type_object( + person = type_object( + name = type_string(), + age = type_integer() + ), + address = type_object( + city = type_string(), + country = type_string() + ) + ) + schema <- as_json(provider, type_nested) + + expect_equal(schema$additionalProperties, FALSE) + expect_equal(schema$properties$person$additionalProperties, FALSE) + expect_equal(schema$properties$address$additionalProperties, FALSE) +}) + +test_that("array schema items get additionalProperties: false", { + provider <- dummy_groq_provider() + + type_arr <- type_array( + type_object(name = type_string()) + ) + schema <- as_json(provider, type_arr) + + expect_equal(schema$type, "array") + expect_equal(schema$items$type, "object") + expect_equal(schema$items$additionalProperties, FALSE) +}) + +# Batch support ----------------------------------------------------------- + +test_that("ProviderGroq has batch support", { + provider <- dummy_groq_provider() + expect_true(has_batch_support(provider)) +}) + +test_that("batch_status parses completed state", { + provider <- dummy_groq_provider() + batch <- list( + status = "completed", + request_counts = list(total = 5L, completed = 5L, failed = 0L) + ) + status <- batch_status(provider, batch) + + expect_false(status$working) + expect_equal(status$n_processing, 0L) + expect_equal(status$n_succeeded, 5L) + expect_equal(status$n_failed, 0L) +}) + +test_that("batch_status parses in_progress state", { + provider <- dummy_groq_provider() + batch <- list( + status = "in_progress", + request_counts = list(total = 10L, completed = 3L, failed = 1L) + ) + status <- batch_status(provider, batch) + + expect_true(status$working) + expect_equal(status$n_processing, 6L) + expect_equal(status$n_succeeded, 3L) + expect_equal(status$n_failed, 1L) +}) + +test_that("batch_status clamps n_processing to zero", { + provider <- dummy_groq_provider() + batch <- list( + status = "completed", + request_counts = list(total = 5L, completed = 6L, failed = 0L) + ) + status <- batch_status(provider, batch) + + expect_equal(status$n_processing, 0L) +}) + +# Fixture-based tests ---------------------------------------------------- + +test_that("batch chat works with Groq fixture", { + withr::local_envvar(GROQ_API_KEY = "dummy-key-for-fixture-test") + chat <- chat_groq( + system_prompt = "Answer with just the city name", + model = "llama-3.1-8b-instant", + params = params(temperature = 0, seed = 1014) + ) + + prompts <- list( + "What's the capital of Iowa?", + "What's the capital of New York?", + "What's the capital of California?", + "What's the capital of Texas?" + ) + + out <- batch_chat_text( + chat, + prompts, + path = test_path("batch/state-capitals-groq.json"), + ignore_hash = TRUE + ) + expect_equal(out, c("Des Moines", "Albany", "Sacramento", "Austin")) +}) + +# Integration tests ------------------------------------------------------- + +test_that("Groq batch_chat submits and can be resumed", { + skip_if( + Sys.getenv("GROQ_API_KEY") == "", + "No Groq credentials set" + ) + + chat <- chat_groq( + system_prompt = "Reply concisely", + model = "llama-3.1-8b-instant" + ) + + prompts <- list("Reply with exactly: ok") + results_file <- withr::local_tempfile(fileext = ".json") + + chats <- tryCatch( + batch_chat( + chat, + prompts = prompts, + path = results_file, + wait = FALSE + ), + error = function(e) { + msg <- conditionMessage(e) + if (grepl("unexpected number of responses", msg, fixed = TRUE)) { + NULL + } else { + stop(e) + } + } + ) + + if (is.null(chats)) { + completed <- FALSE + for (i in seq_len(60)) { + Sys.sleep(10) + completed <- isTRUE(batch_chat_completed(chat, prompts, results_file)) + if (completed) break + } + + if (!completed) { + skip("Groq batch did not complete within test timeout.") + } + + chats <- batch_chat( + chat, + prompts = prompts, + path = results_file, + wait = TRUE + ) + } + + expect_equal(length(chats), 1) + expect_true(inherits(chats[[1]], "Chat")) +}) + +test_that("Groq batch_chat_structured works", { + skip_if( + Sys.getenv("GROQ_API_KEY") == "", + "No Groq credentials set" + ) + + chat <- chat_groq( + system_prompt = "Reply concisely", + model = "llama-3.1-8b-instant" + ) + + type_answer <- type_object( + answer = type_string() + ) + + prompts <- list("What is 2+2? Reply with just the number.") + results_file <- withr::local_tempfile(fileext = ".json") + + result <- tryCatch( + batch_chat_structured( + chat, + prompts = prompts, + path = results_file, + type = type_answer, + wait = FALSE + ), + error = function(e) { + msg <- conditionMessage(e) + if (grepl("unexpected number of responses", msg, fixed = TRUE)) { + NULL + } else { + stop(e) + } + } + ) + + if (is.null(result)) { + completed <- FALSE + for (i in seq_len(60)) { + Sys.sleep(10) + completed <- isTRUE(batch_chat_completed(chat, prompts, results_file)) + if (completed) break + } + + if (!completed) { + skip("Groq batch did not complete within test timeout.") + } + + result <- batch_chat_structured( + chat, + prompts = prompts, + path = results_file, + type = type_answer, + wait = TRUE + ) + } + + expect_true(is.data.frame(result)) + expect_equal(nrow(result), 1) + expect_true("answer" %in% names(result)) +}) diff --git a/tests/testthat/test-provider-groq.R b/tests/testthat/test-provider-groq.R index e2f5d05c6..015cde975 100644 --- a/tests/testthat/test-provider-groq.R +++ b/tests/testthat/test-provider-groq.R @@ -8,3 +8,8 @@ test_that("supports tool calling", { chat_fun <- chat_groq test_tools_simple(chat_fun) }) + +test_that("can extract data", { + chat_fun <- function(...) chat_groq(model = "openai/gpt-oss-20b", ...) + test_data_extraction(chat_fun) +})