Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# ellmer (development version)

* `batch_chat()` now supports `chat_groq()` for batch processing via the
Groq batch API. `chat_groq()` now also supports structured data extraction
(@xmarquez, #914).
* ellmer will now distinguish text content from thinking content while streaming, allowing downstream packages like shinychat to provide specific UI for thinking content (@simonpcouch, #909).
* `chat_github()` now uses `chat_openai_compatible()` for improved compatibility, and `models_github()` now supports custom `base_url` configuration (@D-M4rk, #877).
* `chat_ollama()` now contains a slot for `top_k` within the `params` argument (@frankiethull).
Expand Down
9 changes: 5 additions & 4 deletions R/batch-chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
#'
#' @description
#' `batch_chat()` and `batch_chat_structured()` currently only work with
#' [chat_openai()] and [chat_anthropic()]. They use the
#' [OpenAI](https://platform.openai.com/docs/guides/batch) and
#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing)
#' batch APIs which allow you to submit multiple requests simultaneously.
#' [chat_openai()], [chat_anthropic()], and [chat_groq()]. They use the
#' [OpenAI](https://platform.openai.com/docs/guides/batch),
#' [Anthropic](https://docs.claude.com/en/docs/build-with-claude/batch-processing),
#' and [Groq](https://console.groq.com/docs/batch) batch APIs which allow
#' you to submit multiple requests simultaneously.
#' The results can take up to 24 hours to complete, but in return you pay 50%
#' less than usual (but note that ellmer doesn't include this discount in
#' its pricing metadata). If you want to get results back more quickly, or
Expand Down
196 changes: 186 additions & 10 deletions R/provider-groq.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ NULL
#'
#' Built on top of [chat_openai_compatible()].
#'
#' ## Known limitations
#'
#' groq does not currently support structured data extraction.
#'
#' @export
#' @family chatbots
#' @param api_key `r lifecycle::badge("deprecated")` Use `credentials` instead.
Expand Down Expand Up @@ -69,10 +65,16 @@ method(as_json, list(ProviderGroq, Turn)) <- function(provider, x, ...) {
is_tool <- map_lgl(x@contents, is_tool_request)
tool_calls <- as_json(provider, x@contents[is_tool], ...)

# Grok contents is just a string. Hopefully it never sends back more
# than a single text response.
if (any(!is_tool)) {
content <- x@contents[!is_tool][[1]]@text
# Groq assistant content is just a string (not a list of content parts).
non_tool <- x@contents[!is_tool]
if (length(non_tool) > 0) {
first <- non_tool[[1]]
if (S7_inherits(first, ContentJson)) {
content <- first@string %||%
unclass(jsonlite::toJSON(first@data, auto_unbox = TRUE))
} else {
content <- first@text
}
} else {
content <- NULL
}
Expand All @@ -95,12 +97,23 @@ method(as_json, list(ProviderGroq, TypeObject)) <- function(provider, x, ...) {
}
required <- map_lgl(x@properties, function(prop) prop@required)

compact(list(
schema <- compact(list(
type = "object",
description = x@description,
properties = as_json(provider, x@properties, ...),
required = as.list(names2(x@properties)[required])
required = as.list(names2(x@properties)[required]),
additionalProperties = FALSE
))
add_additional_properties_false(schema)
}

method(as_json, list(ProviderGroq, TypeArray)) <- function(provider, x, ...) {
schema <- compact(list(
type = "array",
description = x@description,
items = as_json(provider, x@items, ...)
))
add_additional_properties_false(schema)
}

method(as_json, list(ProviderGroq, ToolDef)) <- function(provider, x, ...) {
Expand All @@ -117,3 +130,166 @@ method(as_json, list(ProviderGroq, ToolDef)) <- function(provider, x, ...) {
groq_key <- function() {
key_get("GROQ_API_KEY")
}

# Structured output helpers ------------------------------------------------

#' Recursively add additionalProperties: false to all objects
#'
#' Groq requires `additionalProperties: false` on all objects for strict JSON
#' validation. This helper ensures all nested objects have this property set.
#' @param node A list representing a JSON schema node
#' @return The modified node with additionalProperties: false on all objects
#' @noRd
add_additional_properties_false <- function(node) {
if (is.list(node) && !is.null(node$type) && identical(node$type, "object")) {
node$additionalProperties <- FALSE
if (!is.null(node$properties) && is.list(node$properties)) {
node$properties <- lapply(
node$properties,
add_additional_properties_false
)
}
}
if (is.list(node) && !is.null(node$items)) {
node$items <- add_additional_properties_false(node$items)
}
node
}

# Batched requests ---------------------------------------------------------

# https://console.groq.com/docs/batch
method(has_batch_support, ProviderGroq) <- function(provider) {
TRUE
}

method(batch_submit, ProviderGroq) <- function(
provider,
conversations,
type = NULL
) {
path <- withr::local_tempfile(fileext = ".jsonl")

requests <- map(seq_along(conversations), function(i) {
body <- chat_body(
provider,
stream = FALSE,
turns = conversations[[i]],
type = type
)
list(
custom_id = paste0("chat-", i),
method = "POST",
url = "/v1/chat/completions",
body = body
)
})
json <- map_chr(requests, to_json)
writeLines(json, path)

uploaded <- groq_upload_file(provider, path)

req <- base_request(provider)
req <- req_url_path_append(req, "/batches")
req <- req_body_json(
req,
list(
input_file_id = uploaded$id,
endpoint = "/v1/chat/completions",
completion_window = "24h"
)
)

resp <- req_perform(req)
resp_body_json(resp)
}

method(batch_poll, ProviderGroq) <- function(provider, batch) {
req <- base_request(provider)
req <- req_url_path_append(req, "/batches/", batch$id)

resp <- req_perform(req)
resp_body_json(resp)
}

method(batch_status, ProviderGroq) <- function(provider, batch) {
terminal_states <- c("completed", "failed", "expired", "cancelled")

total <- batch$request_counts$total %||% 0L
completed <- batch$request_counts$completed %||% 0L
failed <- batch$request_counts$failed %||% 0L

list(
working = !(batch$status %in% terminal_states),
n_processing = max(total - completed - failed, 0L),
n_succeeded = completed,
n_failed = failed
)
}

method(batch_retrieve, ProviderGroq) <- function(provider, batch) {
json <- list()

if (length(batch$output_file_id) == 1 && nzchar(batch$output_file_id)) {
path_output <- withr::local_tempfile()
groq_download_file(provider, batch$output_file_id, path_output)
json <- read_ndjson(path_output, fallback = groq_json_fallback)
}

if (length(batch$error_file_id) == 1 && nzchar(batch$error_file_id)) {
path_error <- withr::local_tempfile()
groq_download_file(provider, batch$error_file_id, path_error)
json <- c(json, read_ndjson(path_error, fallback = groq_json_fallback))
}

ids <- as.numeric(gsub("chat-", "", map_chr(json, "[[", "custom_id")))
results <- lapply(json, "[[", "response")
results[order(ids)]
}

method(batch_result_turn, ProviderGroq) <- function(
provider,
result,
has_type = FALSE
) {
if (result$status_code == 200) {
value_turn(provider, result$body, has_type = has_type)
} else {
NULL
}
}

# Batch helpers ------------------------------------------------------------

#' @noRd
groq_upload_file <- function(provider, path, purpose = "batch") {
req <- base_request(provider)
req <- req_url_path_append(req, "/files")
req <- req_body_multipart(
req,
purpose = purpose,
file = curl::form_file(path)
)
req <- req_progress(req, "up")

resp <- req_perform(req)
resp_body_json(resp)
}

#' @noRd
groq_download_file <- function(provider, id, path) {
req <- base_request(provider)
req <- req_url_path_append(req, "/files/", id, "/content")
req <- req_progress(req, "down")
req_perform(req, path = path)

invisible(path)
}

#' @noRd
groq_json_fallback <- function(line) {
list(
custom_id = extract_custom_id(line),
response = list(status_code = 500)
)
}
9 changes: 5 additions & 4 deletions man/batch_chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions man/chat_groq.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading