diff --git a/mythic-docker/src/rabbitmq/util_agent_message_actions_post_response.go b/mythic-docker/src/rabbitmq/util_agent_message_actions_post_response.go index 0632048d1..4e5785617 100644 --- a/mythic-docker/src/rabbitmq/util_agent_message_actions_post_response.go +++ b/mythic-docker/src/rabbitmq/util_agent_message_actions_post_response.go @@ -322,35 +322,130 @@ func handleAgentMessagePostResponse(incoming *map[string]interface{}, uUIDInfo * ] } */ - agentMessage := agentMessagePostResponseMessage{} - err := mapstructure.Decode(incoming, &agentMessage) cachedTaskData := make(map[string]databaseStructs.Task) cachedFileData := make(map[string]databaseStructs.Filemeta) - if err != nil { - logging.LogError(err, "Failed to decode agent message into struct, ignoring and continuing on") - badMessageString, err2 := json.MarshalIndent(incoming, "", " ") - delete(*incoming, "responses") - if err2 != nil { - go SendAllOperationsMessage(fmt.Sprintf("Failed to process agent message: \n%s\n%s\n", err2.Error(), err.Error()), - uUIDInfo.OperationID, "agent_message_bad_post_response", database.MESSAGE_LEVEL_AGENT_MESSGAGE, true) - } else { - badMessage := fmt.Sprintf("Failed to process agent message:\n%s\n%s\n", err.Error(), string(badMessageString)) - go SendAllOperationsMessage(badMessage, - uUIDInfo.OperationID, "agent_message_bad_post_response", database.MESSAGE_LEVEL_AGENT_MESSGAGE, true) + responses := []map[string]interface{}{} + var err error + + // Extract the raw responses array first + responsesRaw, ok := (*incoming)["responses"] + if !ok { + logging.LogError(nil, "No responses array found in agent message") + return map[string]interface{}{}, nil + } + + responsesSlice, ok := responsesRaw.([]interface{}) + if !ok { + logging.LogError(nil, "Responses field is not an array") + return map[string]interface{}{}, nil + } + + // Decode the outer message structure to capture "Other" fields for reflection back + agentMessage := agentMessagePostResponseMessage{} + outerMessageCopy := make(map[string]interface{}) + for k, v := range *incoming { + if k != "responses" { + outerMessageCopy[k] = v } - // returning nil so that we can continue processing other message pieces if they exist - return map[string]interface{}{}, err } - responses := []map[string]interface{}{} - // iterate over the agent messages - for i, _ := range agentMessage.Responses { + if err := mapstructure.Decode(outerMessageCopy, &agentMessage); err != nil { + logging.LogError(err, "Failed to decode outer message fields") + } + + // Iterate over raw responses and decode each individually + for i, responseRaw := range responsesSlice { + responseMap, ok := responseRaw.(map[string]interface{}) + if !ok { + logging.LogError(nil, "Response is not a map", "index", i) + continue + } + + // Try to decode this individual response + singleResponse := agentMessagePostResponse{} + err := mapstructure.Decode(responseMap, &singleResponse) + + // Extract task_id for acknowledgement even if decode fails + taskID := "" + if taskIDRaw, ok := responseMap["task_id"]; ok { + taskID = fmt.Sprintf("%v", taskIDRaw) + } + if taskID == "" && err == nil { + taskID = singleResponse.TaskID + } + + // If decode failed, log error and send error acknowledgement + if err != nil { + logging.LogError(err, "Failed to decode individual response", "index", i, "task_id", taskID) + if taskID != "" { + // Save the decode error message before it gets shadowed by database query + decodeErrorMsg := err.Error() + + // Build error response preserving original task_id type and custom fields + // Note: status is "success" to acknowledge receipt (agent requirement) + // The "error" field contains the actual error details + errorResponse := map[string]interface{}{ + "status": "success", + "error": fmt.Sprintf("Failed to decode response: %s", decodeErrorMsg), + } + + // Preserve original task_id (don't convert to string, keep original type) + if taskIDRaw, ok := responseMap["task_id"]; ok { + errorResponse["task_id"] = taskIDRaw + } else { + errorResponse["task_id"] = taskID + } + + // Reflect back any custom fields using same function as success path + // This ensures consistent behavior between error and success responses + reflectBackOtherKeys(&errorResponse, &responseMap) + + responses = append(responses, errorResponse) + + // Look up the task in database to update its status to error + // This ensures the GUI shows the error instead of "pending" + errorTask := databaseStructs.Task{AgentTaskID: taskID} + if dbErr := database.DB.Get(&errorTask, `SELECT + task.id, task.status, task.completed, task.status_timestamp_processed, task.operator_id, task.operation_id, + task.stdout, task.stderr, task.display_id, + task.eventstepinstance_id, task.apitokens_id, + callback.host "callback.host", + callback.user "callback.user", + callback.id "callback.id", + callback.display_id "callback.display_id", + callback.agent_callback_id "callback.agent_callback_id", + callback.mythictree_groups "callback.mythictree_groups", + payload.payload_type_id "callback.payload.payload_type_id", + payload.os "callback.payload.os" + FROM task + JOIN callback ON task.callback_id = callback.id + JOIN payload ON callback.registered_payload_id = payload.id + WHERE task.agent_task_id=$1`, errorTask.AgentTaskID); dbErr != nil { + logging.LogError(dbErr, "Failed to find task for decode error update", "task id", taskID) + } else { + // Update task to show error status + errorTask.Status = "error: Failed to decode agent response" + errorTask.Stderr = fmt.Sprintf("Failed to decode response structure: %s", decodeErrorMsg) + errorTask.Completed = true + errorTask.Timestamp = time.Now().UTC() + if !errorTask.StatusTimestampProcessed.Valid { + errorTask.StatusTimestampProcessed.Time = errorTask.Timestamp + errorTask.StatusTimestampProcessed.Valid = true + } + // Add to cache so it gets updated in batch at the end of function + cachedTaskData[errorTask.AgentTaskID] = errorTask + } + } + continue + } + + // Successfully decoded - process this response normally mythicResponse := map[string]interface{}{ - "task_id": agentMessage.Responses[i].TaskID, + "task_id": singleResponse.TaskID, "status": "success", } //logging.LogDebug("Got response data from agent", "response data", agentResponse, "extra keys", agentResponse.Other) // every response should be tied to some task - currentTask := databaseStructs.Task{AgentTaskID: agentMessage.Responses[i].TaskID} + currentTask := databaseStructs.Task{AgentTaskID: singleResponse.TaskID} if _, ok := cachedTaskData[currentTask.AgentTaskID]; ok { currentTask = cachedTaskData[currentTask.AgentTaskID] } else { @@ -380,26 +475,26 @@ func handleAgentMessagePostResponse(incoming *map[string]interface{}, uUIDInfo * } // always process here - if agentMessage.Responses[i].Download != nil { + if singleResponse.Download != nil { fileMeta := databaseStructs.Filemeta{} - if agentMessage.Responses[i].Download.FileID != nil && *agentMessage.Responses[i].Download.FileID != "" { - if _, ok := cachedFileData[*agentMessage.Responses[i].Download.FileID]; ok { - fileMeta = cachedFileData[*agentMessage.Responses[i].Download.FileID] + if singleResponse.Download.FileID != nil && *singleResponse.Download.FileID != "" { + if _, ok := cachedFileData[*singleResponse.Download.FileID]; ok { + fileMeta = cachedFileData[*singleResponse.Download.FileID] } else { - fileMeta = databaseStructs.Filemeta{AgentFileID: *agentMessage.Responses[i].Download.FileID} - err = database.DB.Get(&fileMeta, `SELECT + fileMeta = databaseStructs.Filemeta{AgentFileID: *singleResponse.Download.FileID} + err = database.DB.Get(&fileMeta, `SELECT id, "path", total_chunks, chunks_received, host, is_screenshot, full_remote_path, complete, md5, sha1, filename, chunk_size, operation_id FROM filemeta - WHERE agent_file_id=$1`, *agentMessage.Responses[i].Download.FileID) + WHERE agent_file_id=$1`, *singleResponse.Download.FileID) if err != nil { - logging.LogError(err, "Failed to find fileID in agent download request", "fileid", *agentMessage.Responses[i].Download.FileID) + logging.LogError(err, "Failed to find fileID in agent download request", "fileid", *singleResponse.Download.FileID) continue } fileMeta.Task = &databaseStructs.Task{} fileMeta.Task.OperatorID = currentTask.OperatorID } } - newFileID, err := handleAgentMessagePostResponseDownload(¤tTask, &agentMessage.Responses[i], &fileMeta) + newFileID, err := handleAgentMessagePostResponseDownload(¤tTask, &singleResponse, &fileMeta) if err != nil { mythicResponse["status"] = "error" mythicResponse["error"] = err.Error() @@ -407,14 +502,14 @@ func handleAgentMessagePostResponse(incoming *map[string]interface{}, uUIDInfo * mythicResponse["file_id"] = newFileID cachedFileData[newFileID] = fileMeta } - if agentMessage.Responses[i].Download.ChunkNum != nil { - mythicResponse["chunk_num"] = *agentMessage.Responses[i].Download.ChunkNum + if singleResponse.Download.ChunkNum != nil { + mythicResponse["chunk_num"] = *singleResponse.Download.ChunkNum } } // always process here - if agentMessage.Responses[i].Upload != nil { - if uploadResponse, err := handleAgentMessagePostResponseUpload(currentTask, agentMessage.Responses[i]); err != nil { + if singleResponse.Upload != nil { + if uploadResponse, err := handleAgentMessagePostResponseUpload(currentTask, singleResponse); err != nil { mythicResponse["status"] = "error" mythicResponse["error"] = err.Error() logging.LogError(err, "Failed to handle agent upload") @@ -432,87 +527,87 @@ func handleAgentMessagePostResponse(incoming *map[string]interface{}, uUIDInfo * currentTask.StatusTimestampProcessed.Valid = true } // this section can happen async, but in order - if agentMessage.Responses[i].Completed != nil { - if *agentMessage.Responses[i].Completed { - currentTask.Completed = *agentMessage.Responses[i].Completed + if singleResponse.Completed != nil { + if *singleResponse.Completed { + currentTask.Completed = *singleResponse.Completed } } - if agentMessage.Responses[i].Status != nil && *agentMessage.Responses[i].Status != "" { + if singleResponse.Status != nil && *singleResponse.Status != "" { if currentTask.Status != PT_TASK_FUNCTION_STATUS_COMPLETED { - currentTask.Status = *agentMessage.Responses[i].Status + currentTask.Status = *singleResponse.Status } - } else if agentMessage.Responses[i].Completed != nil && *agentMessage.Responses[i].Completed { + } else if singleResponse.Completed != nil && *singleResponse.Completed { currentTask.Status = PT_TASK_FUNCTION_STATUS_COMPLETED } else if currentTask.Status == PT_TASK_FUNCTION_STATUS_PROCESSING { currentTask.Status = PT_TASK_FUNCTION_STATUS_PROCESSED } - if agentMessage.Responses[i].UserOutput != nil && *agentMessage.Responses[i].UserOutput != "" { + if singleResponse.UserOutput != nil && *singleResponse.UserOutput != "" { // do it in the background - the agent doesn't need the result of this directly //handleAgentMessagePostResponseUserOutput(currentTask, agentResponse, true) asyncAgentMessagePostResponseChannel <- agentAgentMessagePostResponseChannelMessage{ Task: currentTask, - Response: *agentMessage.Responses[i].UserOutput, - SequenceNum: agentMessage.Responses[i].SequenceNumber, + Response: *singleResponse.UserOutput, + SequenceNum: singleResponse.SequenceNumber, } } - if agentMessage.Responses[i].Stdout != nil { - currentTask.Stdout += *agentMessage.Responses[i].Stdout + if singleResponse.Stdout != nil { + currentTask.Stdout += *singleResponse.Stdout } - if agentMessage.Responses[i].Stderr != nil { - currentTask.Stderr = *agentMessage.Responses[i].Stderr + if singleResponse.Stderr != nil { + currentTask.Stderr = *singleResponse.Stderr } - if agentMessage.Responses[i].FileBrowser != nil { + if singleResponse.FileBrowser != nil { // do it in the background - the agent doesn't need the result of this directly - go HandleAgentMessagePostResponseFileBrowser(currentTask, agentMessage.Responses[i].FileBrowser, 0) + go HandleAgentMessagePostResponseFileBrowser(currentTask, singleResponse.FileBrowser, 0) } - if agentMessage.Responses[i].Processes != nil { - go HandleAgentMessagePostResponseProcesses(currentTask, agentMessage.Responses[i].Processes, 0) + if singleResponse.Processes != nil { + go HandleAgentMessagePostResponseProcesses(currentTask, singleResponse.Processes, 0) } - if agentMessage.Responses[i].RemovedFiles != nil { - go handleAgentMessagePostResponseRemovedFiles(currentTask, agentMessage.Responses[i].RemovedFiles) + if singleResponse.RemovedFiles != nil { + go handleAgentMessagePostResponseRemovedFiles(currentTask, singleResponse.RemovedFiles) } - if agentMessage.Responses[i].Credentials != nil { - go handleAgentMessagePostResponseCredentials(currentTask, agentMessage.Responses[i].Credentials) + if singleResponse.Credentials != nil { + go handleAgentMessagePostResponseCredentials(currentTask, singleResponse.Credentials) } - if agentMessage.Responses[i].Keylogs != nil { - go handleAgentMessagePostResponseKeylogs(currentTask, agentMessage.Responses[i].Keylogs) + if singleResponse.Keylogs != nil { + go handleAgentMessagePostResponseKeylogs(currentTask, singleResponse.Keylogs) } - if agentMessage.Responses[i].Tokens != nil && agentMessage.Responses[i].CallbackTokens != nil { + if singleResponse.Tokens != nil && singleResponse.CallbackTokens != nil { // need to make sure we process tokens _then_ process callback tokens - go handleAgentMessagePostResponseCallbackTokensAndTokens(currentTask, agentMessage.Responses[i].Tokens, agentMessage.Responses[i].CallbackTokens) + go handleAgentMessagePostResponseCallbackTokensAndTokens(currentTask, singleResponse.Tokens, singleResponse.CallbackTokens) } else { - if agentMessage.Responses[i].Tokens != nil { - go handleAgentMessagePostResponseTokens(currentTask, agentMessage.Responses[i].Tokens) + if singleResponse.Tokens != nil { + go handleAgentMessagePostResponseTokens(currentTask, singleResponse.Tokens) } - if agentMessage.Responses[i].CallbackTokens != nil { - go handleAgentMessagePostResponseCallbackTokens(currentTask, agentMessage.Responses[i].CallbackTokens) + if singleResponse.CallbackTokens != nil { + go handleAgentMessagePostResponseCallbackTokens(currentTask, singleResponse.CallbackTokens) } } - if agentMessage.Responses[i].ProcessResponse != nil { - go handleAgentMessagePostResponseProcessResponse(currentTask, agentMessage.Responses[i].ProcessResponse) + if singleResponse.ProcessResponse != nil { + go handleAgentMessagePostResponseProcessResponse(currentTask, singleResponse.ProcessResponse) } - if agentMessage.Responses[i].Commands != nil { - go handleAgentMessagePostResponseCommands(currentTask, agentMessage.Responses[i].Commands) + if singleResponse.Commands != nil { + go handleAgentMessagePostResponseCommands(currentTask, singleResponse.Commands) } - if agentMessage.Responses[i].Edges != nil { - go handleAgentMessagePostResponseEdges(uUIDInfo, agentMessage.Responses[i].Edges) + if singleResponse.Edges != nil { + go handleAgentMessagePostResponseEdges(uUIDInfo, singleResponse.Edges) } - if agentMessage.Responses[i].Alerts != nil { - go handleAgentMessagePostResponseAlerts(currentTask.OperationID, uUIDInfo.CallbackID, uUIDInfo.CallbackDisplayID, agentMessage.Responses[i].Alerts) + if singleResponse.Alerts != nil { + go handleAgentMessagePostResponseAlerts(currentTask.OperationID, uUIDInfo.CallbackID, uUIDInfo.CallbackDisplayID, singleResponse.Alerts) } - if agentMessage.Responses[i].Artifacts != nil { + if singleResponse.Artifacts != nil { // report back artifact information so that the agent can update the specific artifacts if needed - artifactResponses := handleAgentMessagePostResponseArtifacts(currentTask, agentMessage.Responses[i].Artifacts) + artifactResponses := handleAgentMessagePostResponseArtifacts(currentTask, singleResponse.Artifacts) mythicResponse["artifacts"] = artifactResponses } - if agentMessage.Responses[i].Callback != nil { - go handleAgentMessagePostResponseCallback(currentTask, agentMessage.Responses[i].Callback) + if singleResponse.Callback != nil { + go handleAgentMessagePostResponseCallback(currentTask, singleResponse.Callback) } - if agentMessage.Responses[i].Events != nil && len(*agentMessage.Responses[i].Events) > 0 { - go handleAgentMessagePostResponseEvent(currentTask, agentMessage.Responses[i].Events) + if singleResponse.Events != nil && len(*singleResponse.Events) > 0 { + go handleAgentMessagePostResponseEvent(currentTask, singleResponse.Events) } // this section always happens - reflectBackOtherKeys(&mythicResponse, &agentMessage.Responses[i].Other) + reflectBackOtherKeys(&mythicResponse, &singleResponse.Other) responses = append(responses, mythicResponse) cachedTaskData[currentTask.AgentTaskID] = currentTask }