diff --git a/package.json b/package.json index 83c0cb639..62f627522 100644 --- a/package.json +++ b/package.json @@ -97,6 +97,7 @@ "openai": "^4.104.0", "react": "18.3.1", "semver": "^7.8.2", + "sharp": "^0.34.5", "shell-quote": "^1.8.4", "spawn-rx": "^5.1.2", "string-width": "^7.2.0", diff --git a/src/services/ai/adapters/chatCompletions.ts b/src/services/ai/adapters/chatCompletions.ts index 3cf3033b7..9f7fc06af 100644 --- a/src/services/ai/adapters/chatCompletions.ts +++ b/src/services/ai/adapters/chatCompletions.ts @@ -7,6 +7,10 @@ import { import { Tool, getToolDescription } from '@tool' import { zodToJsonSchema } from 'zod-to-json-schema' import { setRequestStatus } from '@utils/session/requestStatus' +import { + extractTextAndImageUrls, + toOpenAIImageUrlParts, +} from '@utils/model/visionContent' export class ChatCompletionsAdapter extends OpenAIAdapter { createRequest(params: UnifiedRequestParams): any { @@ -120,33 +124,44 @@ export class ChatCompletionsAdapter extends OpenAIAdapter { return [] } - return messages.map(msg => { + const normalized: any[] = [] + + for (const msg of messages) { if (!msg || typeof msg !== 'object') { - return msg + normalized.push(msg) + continue } if (msg.role === 'tool') { - if (Array.isArray(msg.content)) { - return { - ...msg, - content: - msg.content - .map(c => c?.text || '') - .filter(Boolean) - .join('\n\n') || '(empty content)', - } - } else if (typeof msg.content !== 'string') { - return { - ...msg, - content: - msg.content === null || msg.content === undefined - ? '(empty content)' - : JSON.stringify(msg.content), - } + const { text, imageUrls } = extractTextAndImageUrls(msg.content) + normalized.push({ + ...msg, + content: + text || + (imageUrls.length > 0 + ? '(image output attached in following message)' + : '(empty content)'), + }) + + if (imageUrls.length > 0) { + normalized.push({ + role: 'user', + content: [ + { + type: 'text', + text: `Image output from tool ${msg.tool_call_id || msg.id || 'unknown'}:`, + }, + ...toOpenAIImageUrlParts(imageUrls), + ], + }) } + continue } - return msg - }) + + normalized.push(msg) + } + + return normalized } protected async *processStreamingChunk( diff --git a/src/services/ai/adapters/responsesAPI.ts b/src/services/ai/adapters/responsesAPI.ts index 3212f6582..2c72d47f0 100644 --- a/src/services/ai/adapters/responsesAPI.ts +++ b/src/services/ai/adapters/responsesAPI.ts @@ -9,6 +9,11 @@ import { zodToJsonSchema } from 'zod-to-json-schema' import { processResponsesStream } from './responsesStreaming' import { debug as debugLogger } from '@utils/log/debugLogger' import { logError } from '@utils/log' +import { + extractTextAndImageUrls, + getImageUrlFromPart, + toResponsesImageParts, +} from '@utils/model/visionContent' export class ResponsesAPIAdapter extends OpenAIAdapter { createRequest(params: UnifiedRequestParams): any { @@ -387,26 +392,12 @@ ${reasoningContent} if (role === 'tool') { const callId = message.tool_call_id || message.id if (typeof callId === 'string' && callId) { - let content = message.content || '' - if (Array.isArray(content)) { - const texts = [] - for (const part of content) { - if (typeof part === 'object' && part !== null) { - const t = part.text || part.content - if (typeof t === 'string' && t) { - texts.push(t) - } - } - } - content = texts.join('\n') - } - if (typeof content === 'string') { - inputItems.push({ - type: 'function_call_output', - call_id: callId, - output: content, - }) - } + const output = this.convertToolOutput(message.content) + inputItems.push({ + type: 'function_call_output', + call_id: callId, + output, + }) } continue } @@ -456,11 +447,19 @@ ${reasoningContent} contentItems.push({ type: kind, text: text }) } } else if (ptype === 'image_url') { - const image = part.image_url - const url = - typeof image === 'object' && image !== null ? image.url : image - if (typeof url === 'string' && url) { - contentItems.push({ type: 'input_image', image_url: url }) + const imageUrl = getImageUrlFromPart(part) + if (imageUrl) { + contentItems.push({ type: 'input_image', image_url: imageUrl }) + } + } else if (ptype === 'image') { + const imageUrl = getImageUrlFromPart(part) + if (imageUrl) { + contentItems.push({ type: 'input_image', image_url: imageUrl }) + } + } else if (ptype === 'input_image') { + const imageUrl = getImageUrlFromPart(part) + if (imageUrl) { + contentItems.push({ type: 'input_image', image_url: imageUrl }) } } } @@ -482,6 +481,20 @@ ${reasoningContent} return inputItems } + private convertToolOutput(content: unknown): string | any[] { + const { text, imageUrls } = extractTextAndImageUrls(content) + if (imageUrls.length === 0) { + return text + } + + const output: any[] = [] + if (text) { + output.push({ type: 'input_text', text }) + } + output.push(...toResponsesImageParts(imageUrls)) + return output + } + private buildInstructions(systemPrompt: string[]): string { const systemContent = systemPrompt .filter(content => content.trim()) diff --git a/src/tools/filesystem/FileReadTool/FileReadTool.tsx b/src/tools/filesystem/FileReadTool/FileReadTool.tsx index 30c9a7fb1..785b50a0f 100644 --- a/src/tools/filesystem/FileReadTool/FileReadTool.tsx +++ b/src/tools/filesystem/FileReadTool/FileReadTool.tsx @@ -25,17 +25,27 @@ import { import { DESCRIPTION, PROMPT } from './prompt' import { hasReadPermission } from '@utils/permissions/filesystem' import { secureFileService } from '@utils/fs/secureFile' -import { readFileBun, fileExistsBun, getFileSizeBun } from '@utils/bun/file' +import { readFileBun } from '@utils/bun/file' import { - type AnthropicImageMediaType, - normalizeImageMediaType, -} from '@utils/ai/anthropic' + detectImageMediaType, + isSvgBuffer, + isSvgExtension, + rasterizeSvgToPng, + type SupportedImageMediaType, +} from '@utils/image/media' const MAX_LINES_TO_RENDER = 5 const MAX_LINE_LENGTH = 2000 const MAX_OUTPUT_SIZE = 0.25 * 1024 * 1024 -const IMAGE_EXTENSIONS = new Set(['.png', '.jpg', '.jpeg', '.gif', '.webp']) +const IMAGE_EXTENSIONS = new Set([ + '.png', + '.jpg', + '.jpeg', + '.gif', + '.webp', + '.svg', +]) const MAX_WIDTH = 2000 const MAX_HEIGHT = 2000 @@ -433,7 +443,7 @@ export const FileReadTool = { type: 'image' file: { base64: string - type: AnthropicImageMediaType + type: SupportedImageMediaType originalSize: number } } @@ -449,30 +459,21 @@ const formatFileSizeError = (sizeInBytes: number) => function createImageResponse( buffer: Buffer, - ext: string, + mediaType: SupportedImageMediaType, originalSize: number, ): { type: 'image' file: { base64: string - type: AnthropicImageMediaType + type: SupportedImageMediaType originalSize: number } } { - const normalized = normalizeImageMediaType( - ext === '.jpg' || ext === '.jpeg' - ? 'image/jpeg' - : ext === '.png' - ? 'image/png' - : ext === '.gif' - ? 'image/gif' - : 'image/webp', - ) return { type: 'image', file: { base64: buffer.toString('base64'), - type: normalized, + type: mediaType, originalSize, }, } @@ -485,31 +486,46 @@ async function readImage( type: 'image' file: { base64: string - type: AnthropicImageMediaType + type: SupportedImageMediaType originalSize: number } }> { try { const stats = statSync(filePath) - const sharpModule = (await import('sharp')) as any - const sharp = sharpModule.default || sharpModule - const fileReadResult = secureFileService.safeReadFile(filePath, { encoding: 'buffer' as BufferEncoding, maxFileSize: MAX_IMAGE_SIZE, + checkFileExtension: false, }) if (!fileReadResult.success) { throw new Error(`Failed to read image file: ${fileReadResult.error}`) } - const image = sharp(fileReadResult.content as Buffer) + const inputBuffer = fileReadResult.content as Buffer + + if (isSvgExtension(ext) || isSvgBuffer(inputBuffer)) { + const rasterized = await rasterizeSvgToPng(inputBuffer) + return createImageResponse(rasterized, 'image/png', stats.size) + } + + const detectedMediaType = detectImageMediaType(inputBuffer) + if (!detectedMediaType) { + throw new Error( + 'Unsupported image format. Supported image formats are PNG, JPEG, GIF, WebP, and SVG.', + ) + } + + const sharpModule = (await import('sharp')) as any + const sharp = sharpModule.default || sharpModule + + const image = sharp(inputBuffer) const metadata = await image.metadata() if (!metadata.width || !metadata.height) { if (stats.size > MAX_IMAGE_SIZE) { const compressedBuffer = await image.jpeg({ quality: 80 }).toBuffer() - return createImageResponse(compressedBuffer, '.jpeg', stats.size) + return createImageResponse(compressedBuffer, 'image/jpeg', stats.size) } } @@ -521,20 +537,7 @@ async function readImage( width <= MAX_WIDTH && height <= MAX_HEIGHT ) { - const fileReadResult = secureFileService.safeReadFile(filePath, { - encoding: 'buffer' as BufferEncoding, - maxFileSize: MAX_IMAGE_SIZE, - }) - - if (!fileReadResult.success) { - throw new Error(`Failed to read image file: ${fileReadResult.error}`) - } - - return createImageResponse( - fileReadResult.content as Buffer, - ext, - stats.size, - ) + return createImageResponse(inputBuffer, detectedMediaType, stats.size) } if (width > MAX_WIDTH) { @@ -556,10 +559,14 @@ async function readImage( if (resizedImageBuffer.length > MAX_IMAGE_SIZE) { const compressedBuffer = await image.jpeg({ quality: 80 }).toBuffer() - return createImageResponse(compressedBuffer, '.jpeg', stats.size) + return createImageResponse(compressedBuffer, 'image/jpeg', stats.size) } - return createImageResponse(resizedImageBuffer, ext, stats.size) + return createImageResponse( + resizedImageBuffer, + detectedMediaType, + stats.size, + ) } catch (e) { logError(e) const stats = statSync(filePath) @@ -572,10 +579,14 @@ async function readImage( throw new Error(`Failed to read image file: ${fileReadResult.error}`) } - return createImageResponse( - fileReadResult.content as Buffer, - ext, - stats.size, - ) + const buffer = fileReadResult.content as Buffer + const detectedMediaType = detectImageMediaType(buffer) + if (!detectedMediaType) { + throw new Error( + 'Unsupported image format. Supported image formats are PNG, JPEG, GIF, WebP, and SVG.', + ) + } + + return createImageResponse(buffer, detectedMediaType, stats.size) } } diff --git a/src/ui/components/PromptInput.tsx b/src/ui/components/PromptInput.tsx index 506c7d581..4bc4fc58d 100644 --- a/src/ui/components/PromptInput.tsx +++ b/src/ui/components/PromptInput.tsx @@ -33,6 +33,7 @@ import { CompactModeIndicator } from '@components/ModeIndicator' import { getPromptInputSpecialKeyAction } from '@utils/terminal/promptInputSpecialKey' import { logStartupProfile } from '@utils/config/startupProfile' import { useStatusLine } from '@hooks/useStatusLine' +import type { ClipboardImage } from '@utils/image/media' async function interpretHashCommand(input: string): Promise { try { @@ -498,13 +499,13 @@ function PromptInput({ } } - function onImagePaste(image: string): string { + function onImagePaste(image: ClipboardImage): string { onModeChange('prompt') const placeholder = `[Image #${pastedImageCounter.current}]` pastedImageCounter.current += 1 setPastedImages(prev => [ ...prev, - { placeholder, data: image, mediaType: 'image/png' }, + { placeholder, data: image.data, mediaType: image.mediaType }, ]) return placeholder } diff --git a/src/ui/components/TextInput.tsx b/src/ui/components/TextInput.tsx index 1ad8bb0c9..ef17ed4bf 100644 --- a/src/ui/components/TextInput.tsx +++ b/src/ui/components/TextInput.tsx @@ -9,6 +9,7 @@ import { shouldTreatAsSpecialPaste, shouldAggregatePasteChunk, } from '@utils/terminal/paste' +import type { ClipboardImage } from '@utils/image/media' const BRACKETED_PASTE_ENABLE = '\x1b[?2004h' const BRACKETED_PASTE_DISABLE = '\x1b[?2004l' @@ -73,7 +74,7 @@ export type Props = { readonly columns: number - readonly onImagePaste?: (base64Image: string) => string | void + readonly onImagePaste?: (image: ClipboardImage) => string | void readonly onPaste?: (text: string) => void diff --git a/src/ui/hooks/useTextInput.ts b/src/ui/hooks/useTextInput.ts index beef89bcd..3c7399ba8 100644 --- a/src/ui/hooks/useTextInput.ts +++ b/src/ui/hooks/useTextInput.ts @@ -7,6 +7,7 @@ import { CLIPBOARD_ERROR_MESSAGE, } from '@utils/terminal/imagePaste' import { normalizeLineEndings } from '@utils/terminal/paste' +import type { ClipboardImage } from '@utils/image/media' const IMAGE_PLACEHOLDER = '[Image pasted]' @@ -38,7 +39,7 @@ type UseTextInputProps = { invert: (text: string) => string themeText: (text: string) => string columns: number - onImagePaste?: (base64Image: string) => string | void + onImagePaste?: (image: ClipboardImage) => string | void disableCursorMovementForUpDownKeys?: boolean externalOffset: number onOffsetChange: (offset: number) => void @@ -134,13 +135,10 @@ export function useTextInput({ return cursor } - const base64Image = getImageFromClipboard() - if (base64Image === null) { - if (process.platform !== 'darwin') { - return cursor - } - onMessage?.(true, CLIPBOARD_ERROR_MESSAGE) + const image = getImageFromClipboard() + if (image === null) { maybeClearImagePasteErrorTimeout() + onMessage?.(true, CLIPBOARD_ERROR_MESSAGE) setImagePasteErrorTimeout( setTimeout(() => { onMessage?.(false) @@ -149,7 +147,7 @@ export function useTextInput({ return cursor } - const placeholder = onImagePaste?.(base64Image) + const placeholder = onImagePaste?.(image) return cursor.insert( typeof placeholder === 'string' ? placeholder : IMAGE_PLACEHOLDER, ) diff --git a/src/utils/image/media.ts b/src/utils/image/media.ts new file mode 100644 index 000000000..3dfb6e4ac --- /dev/null +++ b/src/utils/image/media.ts @@ -0,0 +1,153 @@ +export type SupportedImageMediaType = + | 'image/png' + | 'image/jpeg' + | 'image/gif' + | 'image/webp' + +export type ClipboardImage = { + data: string + mediaType: SupportedImageMediaType +} + +export const SUPPORTED_IMAGE_MEDIA_TYPES: readonly SupportedImageMediaType[] = [ + 'image/png', + 'image/jpeg', + 'image/gif', + 'image/webp', +] as const + +export const SVG_MEDIA_TYPE = 'image/svg+xml' + +export function normalizeSupportedImageMediaType( + mediaType: unknown, +): SupportedImageMediaType | null { + if (typeof mediaType !== 'string') { + return null + } + + const normalized = mediaType.trim().toLowerCase() + if (normalized === 'image/jpg') { + return 'image/jpeg' + } + + return SUPPORTED_IMAGE_MEDIA_TYPES.includes( + normalized as SupportedImageMediaType, + ) + ? (normalized as SupportedImageMediaType) + : null +} + +export function detectImageMediaType( + input: Buffer | Uint8Array, +): SupportedImageMediaType | null { + const buffer = Buffer.isBuffer(input) ? input : Buffer.from(input) + + if ( + buffer.length >= 8 && + buffer[0] === 0x89 && + buffer[1] === 0x50 && + buffer[2] === 0x4e && + buffer[3] === 0x47 && + buffer[4] === 0x0d && + buffer[5] === 0x0a && + buffer[6] === 0x1a && + buffer[7] === 0x0a + ) { + return 'image/png' + } + + if ( + buffer.length >= 3 && + buffer[0] === 0xff && + buffer[1] === 0xd8 && + buffer[2] === 0xff + ) { + return 'image/jpeg' + } + + if ( + buffer.length >= 6 && + (buffer.subarray(0, 6).toString('ascii') === 'GIF87a' || + buffer.subarray(0, 6).toString('ascii') === 'GIF89a') + ) { + return 'image/gif' + } + + if ( + buffer.length >= 12 && + buffer.subarray(0, 4).toString('ascii') === 'RIFF' && + buffer.subarray(8, 12).toString('ascii') === 'WEBP' + ) { + return 'image/webp' + } + + return null +} + +export function getImageMediaTypeFromExtension( + ext: string, +): SupportedImageMediaType | null { + switch (ext.toLowerCase()) { + case '.png': + return 'image/png' + case '.jpg': + case '.jpeg': + return 'image/jpeg' + case '.gif': + return 'image/gif' + case '.webp': + return 'image/webp' + default: + return null + } +} + +export function imageBase64ToDataUrl( + data: string, + mediaType: SupportedImageMediaType, +): string { + return `data:${mediaType};base64,${data}` +} + +export function imageBufferToDataUrl( + buffer: Buffer | Uint8Array, + mediaType = detectImageMediaType(buffer), +): string | null { + if (!mediaType) { + return null + } + + const data = Buffer.isBuffer(buffer) + ? buffer.toString('base64') + : Buffer.from(buffer).toString('base64') + return imageBase64ToDataUrl(data, mediaType) +} + +export function isSvgExtension(ext: string): boolean { + return ext.toLowerCase() === '.svg' +} + +export function isSvgBuffer(input: Buffer | Uint8Array): boolean { + const buffer = Buffer.isBuffer(input) ? input : Buffer.from(input) + const prefix = buffer + .subarray(0, Math.min(buffer.length, 1024)) + .toString('utf8') + .replace(/^\uFEFF/, '') + .trimStart() + .toLowerCase() + + return ( + prefix.startsWith(' { + const sharpModule = (await import('sharp')) as any + const sharp = sharpModule.default || sharpModule + return await sharp(Buffer.isBuffer(input) ? input : Buffer.from(input)) + .png() + .toBuffer() +} diff --git a/src/utils/model/openaiMessageConversion.ts b/src/utils/model/openaiMessageConversion.ts index 5dbe8ba97..664406417 100644 --- a/src/utils/model/openaiMessageConversion.ts +++ b/src/utils/model/openaiMessageConversion.ts @@ -1,4 +1,9 @@ import OpenAI from 'openai' +import { + extractTextAndImageUrls, + getImageUrlFromPart, + toOpenAIImageUrlParts, +} from '@utils/model/visionContent' type AnthropicImageBlock = { type: 'image' @@ -42,7 +47,13 @@ export function convertAnthropicMessagesToOpenAIMessages( )[] { const openaiMessages: any[] = [] - const toolResults: Record = {} + const toolResults: Record< + string, + { + toolMessage: OpenAI.ChatCompletionToolMessageParam + imageMessage?: OpenAI.ChatCompletionUserMessageParam + } + > = {} for (const message of messages) { const blocks: AnthropicBlock[] = [] @@ -74,18 +85,11 @@ export function convertAnthropicMessagesToOpenAIMessages( } if (block.type === 'image' && role === 'user') { - const source = (block as AnthropicImageBlock).source - if (source?.type === 'base64') { - userContentParts.push({ - type: 'image_url', - image_url: { - url: `data:${source.media_type};base64,${source.data}`, - }, - }) - } else if (source?.type === 'url') { + const imageUrl = getImageUrlFromPart(block as any) + if (imageUrl) { userContentParts.push({ type: 'image_url', - image_url: { url: source.url }, + image_url: { url: imageUrl }, }) } continue @@ -106,15 +110,33 @@ export function convertAnthropicMessagesToOpenAIMessages( if (block.type === 'tool_result') { const toolUseId = (block as AnthropicToolResultBlock).tool_use_id const rawToolContent = (block as AnthropicToolResultBlock).content + const { text, imageUrls } = extractTextAndImageUrls(rawToolContent) const toolContent = - typeof rawToolContent === 'string' - ? rawToolContent - : JSON.stringify(rawToolContent) - toolResults[toolUseId] = { - role: 'tool', - content: toolContent, - tool_call_id: toolUseId, + text || (imageUrls.length > 0 ? '(image output attached)' : '') + const result: { + toolMessage: OpenAI.ChatCompletionToolMessageParam + imageMessage?: OpenAI.ChatCompletionUserMessageParam + } = { + toolMessage: { + role: 'tool', + content: toolContent, + tool_call_id: toolUseId, + }, + } + + if (imageUrls.length > 0) { + result.imageMessage = { + role: 'user', + content: [ + { + type: 'text', + text: `Image output from tool ${toolUseId}:`, + }, + ...toOpenAIImageUrlParts(imageUrls), + ], + } as any } + toolResults[toolUseId] = result continue } } @@ -157,8 +179,12 @@ export function convertAnthropicMessagesToOpenAIMessages( if ('tool_calls' in message && message.tool_calls) { for (const toolCall of message.tool_calls) { - if (toolResults[toolCall.id]) { - finalMessages.push(toolResults[toolCall.id]) + const result = toolResults[toolCall.id] + if (result) { + finalMessages.push(result.toolMessage) + if (result.imageMessage) { + finalMessages.push(result.imageMessage) + } } } } diff --git a/src/utils/model/visionContent.ts b/src/utils/model/visionContent.ts new file mode 100644 index 000000000..7fc3cd459 --- /dev/null +++ b/src/utils/model/visionContent.ts @@ -0,0 +1,114 @@ +import { + imageBase64ToDataUrl, + normalizeSupportedImageMediaType, +} from '@utils/image/media' + +export type ExtractedVisionContent = { + text: string + imageUrls: string[] +} + +export function extractTextAndImageUrls( + content: unknown, +): ExtractedVisionContent { + if (typeof content === 'string') { + return { text: content, imageUrls: [] } + } + + if (!Array.isArray(content)) { + if (content === null || content === undefined) { + return { text: '', imageUrls: [] } + } + return { text: JSON.stringify(content), imageUrls: [] } + } + + const textParts: string[] = [] + const imageUrls: string[] = [] + + for (const part of content) { + if (!part || typeof part !== 'object') { + continue + } + + const text = getTextFromPart(part) + if (text) { + textParts.push(text) + continue + } + + const imageUrl = getImageUrlFromPart(part) + if (imageUrl) { + imageUrls.push(imageUrl) + } + } + + return { + text: textParts.join('\n\n'), + imageUrls, + } +} + +export function getTextFromPart(part: Record): string | null { + const type = part.type + if (type !== 'text' && type !== 'input_text' && type !== 'output_text') { + return null + } + + const text = part.text ?? part.content + return typeof text === 'string' && text ? text : null +} + +export function getImageUrlFromPart(part: Record): string | null { + if (part.type === 'image_url') { + const image = part.image_url + const url = + image && typeof image === 'object' ? image.url : (image ?? part.url) + return typeof url === 'string' && url ? url : null + } + + if (part.type === 'input_image') { + const image = part.image_url + const url = + image && typeof image === 'object' ? image.url : (image ?? part.url) + return typeof url === 'string' && url ? url : null + } + + if (part.type !== 'image') { + return null + } + + const source = part.source + if (!source || typeof source !== 'object') { + return null + } + + if (source.type === 'url' && typeof source.url === 'string') { + return source.url + } + + if (source.type === 'base64' && typeof source.data === 'string') { + const mediaType = + normalizeSupportedImageMediaType(source.media_type) ?? 'image/png' + return imageBase64ToDataUrl(source.data, mediaType) + } + + return null +} + +export function toOpenAIImageUrlParts( + imageUrls: string[], +): Array<{ type: 'image_url'; image_url: { url: string } }> { + return imageUrls.map(url => ({ + type: 'image_url', + image_url: { url }, + })) +} + +export function toResponsesImageParts( + imageUrls: string[], +): Array<{ type: 'input_image'; image_url: string }> { + return imageUrls.map(url => ({ + type: 'input_image', + image_url: url, + })) +} diff --git a/src/utils/terminal/imagePaste.ts b/src/utils/terminal/imagePaste.ts index bb5e210f6..1e73800c4 100644 --- a/src/utils/terminal/imagePaste.ts +++ b/src/utils/terminal/imagePaste.ts @@ -1,33 +1,210 @@ -import { execSync } from 'child_process' -import { readFileSync } from 'fs' +import { execFileSync } from 'child_process' +import { readFileSync, unlinkSync } from 'fs' +import { join } from 'path' +import { tmpdir } from 'os' +import { + detectImageMediaType, + normalizeSupportedImageMediaType, + type ClipboardImage, + type SupportedImageMediaType, +} from '@utils/image/media' -const SCREENSHOT_PATH = '/tmp/kode_cli_latest_screenshot.png' +const CLIPBOARD_MAX_BUFFER = 20 * 1024 * 1024 export const CLIPBOARD_ERROR_MESSAGE = - 'No image found in clipboard. Use Cmd + Ctrl + Shift + 4 to copy a screenshot to clipboard.' + 'No compatible image found in clipboard. Copy a PNG, JPEG, GIF, or WebP image; on Linux install wl-paste or xclip.' -export function getImageFromClipboard(): string | null { - if (process.platform !== 'darwin') { +export function getImageFromClipboard(): ClipboardImage | null { + switch (process.platform) { + case 'darwin': + return getImageFromMacClipboard() + case 'win32': + return getImageFromWindowsClipboard() + case 'linux': + return getImageFromLinuxClipboard() + default: + return null + } +} + +function getImageFromMacClipboard(): ClipboardImage | null { + const screenshotPath = join( + tmpdir(), + `kode-cli-clipboard-${process.pid}-${Date.now()}.png`, + ) + + try { + execFileSync( + 'osascript', + [ + '-e', + 'set png_data to (the clipboard as «class PNGf»)', + '-e', + `set fp to open for access POSIX file "${escapeAppleScriptString( + screenshotPath, + )}" with write permission`, + '-e', + 'write png_data to fp', + '-e', + 'close access fp', + ], + { stdio: 'ignore', timeout: 3000 }, + ) + + const imageBuffer = readFileSync(screenshotPath) + return imageFromBuffer(imageBuffer) + } catch { + return null + } finally { + try { + unlinkSync(screenshotPath) + } catch {} + } +} + +function getImageFromWindowsClipboard(): ClipboardImage | null { + const script = ` +Add-Type -AssemblyName System.Windows.Forms +Add-Type -AssemblyName System.Drawing + +$files = [System.Windows.Forms.Clipboard]::GetFileDropList() +if ($files -and $files.Count -gt 0) { + $path = [string]$files[0] + if ([System.IO.File]::Exists($path)) { + [Console]::Out.Write([Convert]::ToBase64String([System.IO.File]::ReadAllBytes($path))) + exit 0 + } +} + +$image = [System.Windows.Forms.Clipboard]::GetImage() +if ($null -eq $image) { + exit 2 +} + +$stream = New-Object System.IO.MemoryStream +try { + $image.Save($stream, [System.Drawing.Imaging.ImageFormat]::Png) + [Console]::Out.Write([Convert]::ToBase64String($stream.ToArray())) +} finally { + $stream.Dispose() + $image.Dispose() +} +` + + try { + const output = execFileSync( + 'powershell.exe', + ['-NoProfile', '-NonInteractive', '-STA', '-Command', script], + { + encoding: 'utf8', + maxBuffer: CLIPBOARD_MAX_BUFFER, + stdio: ['ignore', 'pipe', 'ignore'], + timeout: 5000, + }, + ).trim() + + if (!output) { + return null + } + + return imageFromBuffer(Buffer.from(output, 'base64')) + } catch { return null } +} + +function getImageFromLinuxClipboard(): ClipboardImage | null { + return getImageFromWlPaste() ?? getImageFromXclip() +} +function getImageFromWlPaste(): ClipboardImage | null { try { - execSync(`osascript -e 'the clipboard as «class PNGf»'`, { - stdio: 'ignore', + const types = execFileSync('wl-paste', ['--list-types'], { + encoding: 'utf8', + timeout: 3000, + stdio: ['ignore', 'pipe', 'ignore'], }) + .split(/\r?\n/) + .filter(Boolean) + + const picked = pickClipboardMimeType(types) + if (!picked) { + return null + } - execSync( - `osascript -e 'set png_data to (the clipboard as «class PNGf»)' -e 'set fp to open for access POSIX file "${SCREENSHOT_PATH}" with write permission' -e 'write png_data to fp' -e 'close access fp'`, - { stdio: 'ignore' }, + const buffer = execFileSync( + 'wl-paste', + ['--no-newline', '--type', picked.target], + { + maxBuffer: CLIPBOARD_MAX_BUFFER, + timeout: 5000, + stdio: ['ignore', 'pipe', 'ignore'], + }, ) + return imageFromBuffer(buffer) + } catch { + return null + } +} - const imageBuffer = readFileSync(SCREENSHOT_PATH) - const base64Image = imageBuffer.toString('base64') +function getImageFromXclip(): ClipboardImage | null { + try { + const targets = execFileSync( + 'xclip', + ['-selection', 'clipboard', '-t', 'TARGETS', '-o'], + { + encoding: 'utf8', + timeout: 3000, + stdio: ['ignore', 'pipe', 'ignore'], + }, + ) + .split(/\r?\n/) + .filter(Boolean) - execSync(`rm -f "${SCREENSHOT_PATH}"`, { stdio: 'ignore' }) + const picked = pickClipboardMimeType(targets) + if (!picked) { + return null + } - return base64Image + const buffer = execFileSync( + 'xclip', + ['-selection', 'clipboard', '-t', picked.target, '-o'], + { + maxBuffer: CLIPBOARD_MAX_BUFFER, + timeout: 5000, + stdio: ['ignore', 'pipe', 'ignore'], + }, + ) + return imageFromBuffer(buffer) } catch { return null } } + +function imageFromBuffer(buffer: Buffer): ClipboardImage | null { + const mediaType = detectImageMediaType(buffer) + if (!mediaType) { + return null + } + + return { + data: buffer.toString('base64'), + mediaType, + } +} + +function pickClipboardMimeType( + types: string[], +): { target: string; mediaType: SupportedImageMediaType } | null { + for (const target of types) { + const mediaType = normalizeSupportedImageMediaType(target) + if (mediaType) { + return { target, mediaType } + } + } + return null +} + +function escapeAppleScriptString(value: string): string { + return value.replace(/\\/g, '\\\\').replace(/"/g, '\\"') +} diff --git a/tests/unit/chat-completions-e2e.test.ts b/tests/unit/chat-completions-e2e.test.ts index 4cfff27f7..cc4617200 100644 --- a/tests/unit/chat-completions-e2e.test.ts +++ b/tests/unit/chat-completions-e2e.test.ts @@ -119,5 +119,40 @@ describe('Chat Completions API Tests', () => { expect(hasUserMessage).toBe(true) expect(hasAssistantMessage).toBe(true) }) + + test('preserves tool result images in adjacent user vision message', () => { + const adapter = ModelAdapterFactory.createAdapter(testModel) + + const request = adapter.createRequest({ + messages: [ + { + role: 'tool', + tool_call_id: 'tool_123', + content: [ + { type: 'text', text: 'Screenshot captured' }, + { + type: 'image_url', + image_url: { url: 'data:image/gif;base64,Zm9v' }, + }, + ], + }, + ], + systemPrompt: ['You are helpful'], + maxTokens: 100, + }) + + const toolIndex = request.messages.findIndex( + (msg: any) => msg.role === 'tool', + ) + expect(toolIndex).toBeGreaterThanOrEqual(0) + expect(request.messages[toolIndex].content).toBe('Screenshot captured') + + const imageMessage = request.messages[toolIndex + 1] + expect(imageMessage.role).toBe('user') + expect(imageMessage.content).toContainEqual({ + type: 'image_url', + image_url: { url: 'data:image/gif;base64,Zm9v' }, + }) + }) }) }) diff --git a/tests/unit/image-media.test.ts b/tests/unit/image-media.test.ts new file mode 100644 index 000000000..7c371e2ad --- /dev/null +++ b/tests/unit/image-media.test.ts @@ -0,0 +1,47 @@ +import { describe, expect, test } from 'bun:test' +import { + detectImageMediaType, + imageBase64ToDataUrl, + imageBufferToDataUrl, + normalizeSupportedImageMediaType, +} from '@utils/image/media' + +const PNG_BYTES = Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]) +const JPEG_BYTES = Buffer.from([0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10]) +const GIF_BYTES = Buffer.from('GIF89a', 'ascii') +const WEBP_BYTES = Buffer.concat([ + Buffer.from('RIFF', 'ascii'), + Buffer.from([0x00, 0x00, 0x00, 0x00]), + Buffer.from('WEBP', 'ascii'), +]) + +describe('image media helpers', () => { + test('detects supported raster image MIME types from magic bytes', () => { + expect(detectImageMediaType(PNG_BYTES)).toBe('image/png') + expect(detectImageMediaType(JPEG_BYTES)).toBe('image/jpeg') + expect(detectImageMediaType(GIF_BYTES)).toBe('image/gif') + expect(detectImageMediaType(WEBP_BYTES)).toBe('image/webp') + }) + + test('returns null for invalid or unknown bytes', () => { + expect(detectImageMediaType(Buffer.from('not an image'))).toBeNull() + expect(detectImageMediaType(Buffer.alloc(0))).toBeNull() + }) + + test('normalizes MIME aliases and rejects unsupported image types', () => { + expect(normalizeSupportedImageMediaType('image/jpg')).toBe('image/jpeg') + expect(normalizeSupportedImageMediaType('image/svg+xml')).toBeNull() + expect( + normalizeSupportedImageMediaType('application/octet-stream'), + ).toBeNull() + }) + + test('converts image data to data URLs with detected or explicit media type', () => { + expect(imageBufferToDataUrl(JPEG_BYTES)).toBe( + `data:image/jpeg;base64,${JPEG_BYTES.toString('base64')}`, + ) + expect(imageBase64ToDataUrl('Zm9v', 'image/webp')).toBe( + 'data:image/webp;base64,Zm9v', + ) + }) +}) diff --git a/tests/unit/openai-message-conversion.test.ts b/tests/unit/openai-message-conversion.test.ts index f4adccdd5..16aa60b8c 100644 --- a/tests/unit/openai-message-conversion.test.ts +++ b/tests/unit/openai-message-conversion.test.ts @@ -80,4 +80,54 @@ describe('openaiMessageConversion', () => { expect((converted[3] as any)?.role).toBe('assistant') expect((converted[3] as any)?.content).toBe('Done') }) + + test('preserves tool-result images as adjacent user vision messages', () => { + const messages: any[] = [ + { + message: { + role: 'assistant', + content: [ + { + type: 'tool_use', + id: 'tool_1', + name: 'Read', + input: { path: 'screenshot.png' }, + }, + ], + }, + }, + { + message: { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'tool_1', + content: [ + { type: 'text', text: 'Read image' }, + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/jpeg', + data: 'Zm9v', + }, + }, + ], + }, + ], + }, + }, + ] + + const converted = convertAnthropicMessagesToOpenAIMessages(messages) + + expect((converted[1] as any)?.role).toBe('tool') + expect((converted[1] as any)?.content).toBe('Read image') + expect((converted[2] as any)?.role).toBe('user') + expect((converted[2] as any)?.content).toContainEqual({ + type: 'image_url', + image_url: { url: 'data:image/jpeg;base64,Zm9v' }, + }) + }) }) diff --git a/tests/unit/process-user-input-images.test.ts b/tests/unit/process-user-input-images.test.ts new file mode 100644 index 000000000..33bf5cb3a --- /dev/null +++ b/tests/unit/process-user-input-images.test.ts @@ -0,0 +1,45 @@ +import { describe, expect, test } from 'bun:test' +import { processUserInput } from '@utils/messages' + +const mockContext = { + abortController: new AbortController(), + messageId: 'test', + readFileTimestamps: {}, + options: { + commands: [], + tools: [], + verbose: false, + safeMode: false, + forkNumber: 0, + messageLogName: 'test', + maxThinkingTokens: 0, + }, + setForkConvoWithMessagesOnTheNextRender: () => {}, +} as any + +describe('processUserInput image attachments', () => { + test('keeps pasted JPEG media type in user image blocks', async () => { + const messages = await processUserInput( + 'please inspect [Image #1]', + 'prompt', + () => {}, + mockContext, + [ + { + placeholder: '[Image #1]', + data: 'anBlZw==', + mediaType: 'image/jpeg', + }, + ], + ) + + const content = messages[0]?.message.content as any[] + expect(Array.isArray(content)).toBe(true) + const imageBlock = content.find(block => block.type === 'image') + expect(imageBlock?.source).toMatchObject({ + type: 'base64', + media_type: 'image/jpeg', + data: 'anBlZw==', + }) + }) +}) diff --git a/tests/unit/responses-api-e2e.test.ts b/tests/unit/responses-api-e2e.test.ts index 731096414..91354e855 100644 --- a/tests/unit/responses-api-e2e.test.ts +++ b/tests/unit/responses-api-e2e.test.ts @@ -122,6 +122,75 @@ describe('Responses API Tests', () => { ) expect(hasFunctionCallOutput).toBe(true) }) + + test('converts Anthropic user image blocks to input_image content', () => { + const adapter = ModelAdapterFactory.createAdapter(testModel) + + const request = adapter.createRequest({ + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'What is in this image?' }, + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/jpeg', + data: 'Zm9v', + }, + }, + ], + }, + ], + systemPrompt: ['You are helpful'], + maxTokens: 100, + }) + + expect(request.input[0].content).toContainEqual({ + type: 'input_image', + image_url: 'data:image/jpeg;base64,Zm9v', + }) + }) + + test('converts tool result images to function_call_output arrays', () => { + const adapter = ModelAdapterFactory.createAdapter(testModel) + + const request = adapter.createRequest({ + messages: [ + { + role: 'tool', + tool_call_id: 'tool_123', + content: [ + { type: 'text', text: 'Screenshot captured' }, + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/webp', + data: 'Zm9v', + }, + }, + ], + }, + ], + systemPrompt: ['You are helpful'], + maxTokens: 100, + }) + + const output = request.input.find( + (item: any) => item.type === 'function_call_output', + )?.output + expect(Array.isArray(output)).toBe(true) + expect(output).toContainEqual({ + type: 'input_text', + text: 'Screenshot captured', + }) + expect(output).toContainEqual({ + type: 'input_image', + image_url: 'data:image/webp;base64,Zm9v', + }) + }) }) describe('Responses API unique behaviors', () => { diff --git a/tests/unit/tools/file-read-tool-parity.test.ts b/tests/unit/tools/file-read-tool-parity.test.ts index d0dc28d09..8afe85856 100644 --- a/tests/unit/tools/file-read-tool-parity.test.ts +++ b/tests/unit/tools/file-read-tool-parity.test.ts @@ -1,9 +1,38 @@ -import { afterAll, describe, expect, test } from 'bun:test' +import { afterAll, describe, expect, mock, test } from 'bun:test' import { mkdtempSync, rmSync, writeFileSync } from 'fs' import { join } from 'path' import { FileReadTool } from '@tools/FileReadTool/FileReadTool' const tmpRoot = mkdtempSync(join(process.cwd(), '.tmp-test-file-read-tool-')) +const PNG_BYTES = Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]) +const JPEG_BYTES = Buffer.from([ + 0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 0x4a, 0x46, 0x49, 0x46, +]) +const GIF_BYTES = Buffer.from('GIF89a', 'ascii') +const WEBP_BYTES = Buffer.concat([ + Buffer.from('RIFF', 'ascii'), + Buffer.from([0x00, 0x00, 0x00, 0x00]), + Buffer.from('WEBP', 'ascii'), +]) + +mock.module('sharp', () => { + function sharp(input?: Buffer) { + const api = { + metadata: async () => ({ width: 1, height: 1 }), + resize: () => api, + jpeg: () => ({ + toBuffer: async () => JPEG_BYTES, + }), + png: () => ({ + toBuffer: async () => PNG_BYTES, + }), + toBuffer: async () => input ?? PNG_BYTES, + } + return api + } + + return { default: sharp } +}) afterAll(() => { rmSync(tmpRoot, { recursive: true, force: true }) @@ -88,3 +117,47 @@ describe('FileReadTool parity: validateInput gating', () => { expect(result.message).toContain('Empty image files') }) }) + +describe('FileReadTool image handling', () => { + test('preserves detected JPEG media type even when extension differs', async () => { + const filePath = join(tmpRoot, 'mismatch.png') + writeFileSync(filePath, JPEG_BYTES) + + const data = await runRead({ file_path: filePath }) + expect(data?.type).toBe('image') + expect(data.file.type).toBe('image/jpeg') + expect(data.file.base64).toBe(JPEG_BYTES.toString('base64')) + }) + + test('preserves detected GIF media type', async () => { + const filePath = join(tmpRoot, 'image.gif') + writeFileSync(filePath, GIF_BYTES) + + const data = await runRead({ file_path: filePath }) + expect(data?.type).toBe('image') + expect(data.file.type).toBe('image/gif') + }) + + test('preserves detected WebP media type', async () => { + const filePath = join(tmpRoot, 'image.webp') + writeFileSync(filePath, WEBP_BYTES) + + const data = await runRead({ file_path: filePath }) + expect(data?.type).toBe('image') + expect(data.file.type).toBe('image/webp') + }) + + test('rasterizes SVG files and returns PNG image data', async () => { + const filePath = join(tmpRoot, 'vector.svg') + writeFileSync( + filePath, + '', + 'utf8', + ) + + const data = await runRead({ file_path: filePath }) + expect(data?.type).toBe('image') + expect(data.file.type).toBe('image/png') + expect(data.file.base64).toBe(PNG_BYTES.toString('base64')) + }) +})