Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 172 additions & 77 deletions mythic-docker/src/rabbitmq/util_agent_message_actions_post_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,35 +322,130 @@ func handleAgentMessagePostResponse(incoming *map[string]interface{}, uUIDInfo *
]
}
*/
agentMessage := agentMessagePostResponseMessage{}
err := mapstructure.Decode(incoming, &agentMessage)
cachedTaskData := make(map[string]databaseStructs.Task)
cachedFileData := make(map[string]databaseStructs.Filemeta)
if err != nil {
logging.LogError(err, "Failed to decode agent message into struct, ignoring and continuing on")
badMessageString, err2 := json.MarshalIndent(incoming, "", " ")
delete(*incoming, "responses")
if err2 != nil {
go SendAllOperationsMessage(fmt.Sprintf("Failed to process agent message: \n%s\n%s\n", err2.Error(), err.Error()),
uUIDInfo.OperationID, "agent_message_bad_post_response", database.MESSAGE_LEVEL_AGENT_MESSGAGE, true)
} else {
badMessage := fmt.Sprintf("Failed to process agent message:\n%s\n%s\n", err.Error(), string(badMessageString))
go SendAllOperationsMessage(badMessage,
uUIDInfo.OperationID, "agent_message_bad_post_response", database.MESSAGE_LEVEL_AGENT_MESSGAGE, true)
responses := []map[string]interface{}{}
var err error

// Extract the raw responses array first
responsesRaw, ok := (*incoming)["responses"]
if !ok {
logging.LogError(nil, "No responses array found in agent message")
return map[string]interface{}{}, nil
}

responsesSlice, ok := responsesRaw.([]interface{})
if !ok {
logging.LogError(nil, "Responses field is not an array")
return map[string]interface{}{}, nil
}

// Decode the outer message structure to capture "Other" fields for reflection back
agentMessage := agentMessagePostResponseMessage{}
outerMessageCopy := make(map[string]interface{})
for k, v := range *incoming {
if k != "responses" {
outerMessageCopy[k] = v
}
// returning nil so that we can continue processing other message pieces if they exist
return map[string]interface{}{}, err
}
responses := []map[string]interface{}{}
// iterate over the agent messages
for i, _ := range agentMessage.Responses {
if err := mapstructure.Decode(outerMessageCopy, &agentMessage); err != nil {
logging.LogError(err, "Failed to decode outer message fields")
}

// Iterate over raw responses and decode each individually
for i, responseRaw := range responsesSlice {
responseMap, ok := responseRaw.(map[string]interface{})
if !ok {
logging.LogError(nil, "Response is not a map", "index", i)
continue
}

// Try to decode this individual response
singleResponse := agentMessagePostResponse{}
err := mapstructure.Decode(responseMap, &singleResponse)

// Extract task_id for acknowledgement even if decode fails
taskID := ""
if taskIDRaw, ok := responseMap["task_id"]; ok {
taskID = fmt.Sprintf("%v", taskIDRaw)
}
if taskID == "" && err == nil {
taskID = singleResponse.TaskID
}

// If decode failed, log error and send error acknowledgement
if err != nil {
logging.LogError(err, "Failed to decode individual response", "index", i, "task_id", taskID)
if taskID != "" {
// Save the decode error message before it gets shadowed by database query
decodeErrorMsg := err.Error()

// Build error response preserving original task_id type and custom fields
// Note: status is "success" to acknowledge receipt (agent requirement)
// The "error" field contains the actual error details
errorResponse := map[string]interface{}{
"status": "success",
"error": fmt.Sprintf("Failed to decode response: %s", decodeErrorMsg),
}

// Preserve original task_id (don't convert to string, keep original type)
if taskIDRaw, ok := responseMap["task_id"]; ok {
errorResponse["task_id"] = taskIDRaw
} else {
errorResponse["task_id"] = taskID
}

// Reflect back any custom fields using same function as success path
// This ensures consistent behavior between error and success responses
reflectBackOtherKeys(&errorResponse, &responseMap)

responses = append(responses, errorResponse)

// Look up the task in database to update its status to error
// This ensures the GUI shows the error instead of "pending"
errorTask := databaseStructs.Task{AgentTaskID: taskID}
if dbErr := database.DB.Get(&errorTask, `SELECT
task.id, task.status, task.completed, task.status_timestamp_processed, task.operator_id, task.operation_id,
task.stdout, task.stderr, task.display_id,
task.eventstepinstance_id, task.apitokens_id,
callback.host "callback.host",
callback.user "callback.user",
callback.id "callback.id",
callback.display_id "callback.display_id",
callback.agent_callback_id "callback.agent_callback_id",
callback.mythictree_groups "callback.mythictree_groups",
payload.payload_type_id "callback.payload.payload_type_id",
payload.os "callback.payload.os"
FROM task
JOIN callback ON task.callback_id = callback.id
JOIN payload ON callback.registered_payload_id = payload.id
WHERE task.agent_task_id=$1`, errorTask.AgentTaskID); dbErr != nil {
logging.LogError(dbErr, "Failed to find task for decode error update", "task id", taskID)
} else {
// Update task to show error status
errorTask.Status = "error: Failed to decode agent response"
errorTask.Stderr = fmt.Sprintf("Failed to decode response structure: %s", decodeErrorMsg)
errorTask.Completed = true
errorTask.Timestamp = time.Now().UTC()
if !errorTask.StatusTimestampProcessed.Valid {
errorTask.StatusTimestampProcessed.Time = errorTask.Timestamp
errorTask.StatusTimestampProcessed.Valid = true
}
// Add to cache so it gets updated in batch at the end of function
cachedTaskData[errorTask.AgentTaskID] = errorTask
}
}
continue
}

// Successfully decoded - process this response normally
mythicResponse := map[string]interface{}{
"task_id": agentMessage.Responses[i].TaskID,
"task_id": singleResponse.TaskID,
"status": "success",
}
//logging.LogDebug("Got response data from agent", "response data", agentResponse, "extra keys", agentResponse.Other)
// every response should be tied to some task
currentTask := databaseStructs.Task{AgentTaskID: agentMessage.Responses[i].TaskID}
currentTask := databaseStructs.Task{AgentTaskID: singleResponse.TaskID}
if _, ok := cachedTaskData[currentTask.AgentTaskID]; ok {
currentTask = cachedTaskData[currentTask.AgentTaskID]
} else {
Expand Down Expand Up @@ -380,41 +475,41 @@ func handleAgentMessagePostResponse(incoming *map[string]interface{}, uUIDInfo *
}

// always process here
if agentMessage.Responses[i].Download != nil {
if singleResponse.Download != nil {
fileMeta := databaseStructs.Filemeta{}
if agentMessage.Responses[i].Download.FileID != nil && *agentMessage.Responses[i].Download.FileID != "" {
if _, ok := cachedFileData[*agentMessage.Responses[i].Download.FileID]; ok {
fileMeta = cachedFileData[*agentMessage.Responses[i].Download.FileID]
if singleResponse.Download.FileID != nil && *singleResponse.Download.FileID != "" {
if _, ok := cachedFileData[*singleResponse.Download.FileID]; ok {
fileMeta = cachedFileData[*singleResponse.Download.FileID]
} else {
fileMeta = databaseStructs.Filemeta{AgentFileID: *agentMessage.Responses[i].Download.FileID}
err = database.DB.Get(&fileMeta, `SELECT
fileMeta = databaseStructs.Filemeta{AgentFileID: *singleResponse.Download.FileID}
err = database.DB.Get(&fileMeta, `SELECT
id, "path", total_chunks, chunks_received, host, is_screenshot, full_remote_path, complete, md5, sha1, filename, chunk_size, operation_id
FROM filemeta
WHERE agent_file_id=$1`, *agentMessage.Responses[i].Download.FileID)
WHERE agent_file_id=$1`, *singleResponse.Download.FileID)
if err != nil {
logging.LogError(err, "Failed to find fileID in agent download request", "fileid", *agentMessage.Responses[i].Download.FileID)
logging.LogError(err, "Failed to find fileID in agent download request", "fileid", *singleResponse.Download.FileID)
continue
}
fileMeta.Task = &databaseStructs.Task{}
fileMeta.Task.OperatorID = currentTask.OperatorID
}
}
newFileID, err := handleAgentMessagePostResponseDownload(&currentTask, &agentMessage.Responses[i], &fileMeta)
newFileID, err := handleAgentMessagePostResponseDownload(&currentTask, &singleResponse, &fileMeta)
if err != nil {
mythicResponse["status"] = "error"
mythicResponse["error"] = err.Error()
} else {
mythicResponse["file_id"] = newFileID
cachedFileData[newFileID] = fileMeta
}
if agentMessage.Responses[i].Download.ChunkNum != nil {
mythicResponse["chunk_num"] = *agentMessage.Responses[i].Download.ChunkNum
if singleResponse.Download.ChunkNum != nil {
mythicResponse["chunk_num"] = *singleResponse.Download.ChunkNum
}
}

// always process here
if agentMessage.Responses[i].Upload != nil {
if uploadResponse, err := handleAgentMessagePostResponseUpload(currentTask, agentMessage.Responses[i]); err != nil {
if singleResponse.Upload != nil {
if uploadResponse, err := handleAgentMessagePostResponseUpload(currentTask, singleResponse); err != nil {
mythicResponse["status"] = "error"
mythicResponse["error"] = err.Error()
logging.LogError(err, "Failed to handle agent upload")
Expand All @@ -432,87 +527,87 @@ func handleAgentMessagePostResponse(incoming *map[string]interface{}, uUIDInfo *
currentTask.StatusTimestampProcessed.Valid = true
}
// this section can happen async, but in order
if agentMessage.Responses[i].Completed != nil {
if *agentMessage.Responses[i].Completed {
currentTask.Completed = *agentMessage.Responses[i].Completed
if singleResponse.Completed != nil {
if *singleResponse.Completed {
currentTask.Completed = *singleResponse.Completed
}
}
if agentMessage.Responses[i].Status != nil && *agentMessage.Responses[i].Status != "" {
if singleResponse.Status != nil && *singleResponse.Status != "" {
if currentTask.Status != PT_TASK_FUNCTION_STATUS_COMPLETED {
currentTask.Status = *agentMessage.Responses[i].Status
currentTask.Status = *singleResponse.Status
}
} else if agentMessage.Responses[i].Completed != nil && *agentMessage.Responses[i].Completed {
} else if singleResponse.Completed != nil && *singleResponse.Completed {
currentTask.Status = PT_TASK_FUNCTION_STATUS_COMPLETED
} else if currentTask.Status == PT_TASK_FUNCTION_STATUS_PROCESSING {
currentTask.Status = PT_TASK_FUNCTION_STATUS_PROCESSED
}
if agentMessage.Responses[i].UserOutput != nil && *agentMessage.Responses[i].UserOutput != "" {
if singleResponse.UserOutput != nil && *singleResponse.UserOutput != "" {
// do it in the background - the agent doesn't need the result of this directly
//handleAgentMessagePostResponseUserOutput(currentTask, agentResponse, true)
asyncAgentMessagePostResponseChannel <- agentAgentMessagePostResponseChannelMessage{
Task: currentTask,
Response: *agentMessage.Responses[i].UserOutput,
SequenceNum: agentMessage.Responses[i].SequenceNumber,
Response: *singleResponse.UserOutput,
SequenceNum: singleResponse.SequenceNumber,
}
}
if agentMessage.Responses[i].Stdout != nil {
currentTask.Stdout += *agentMessage.Responses[i].Stdout
if singleResponse.Stdout != nil {
currentTask.Stdout += *singleResponse.Stdout
}
if agentMessage.Responses[i].Stderr != nil {
currentTask.Stderr = *agentMessage.Responses[i].Stderr
if singleResponse.Stderr != nil {
currentTask.Stderr = *singleResponse.Stderr
}
if agentMessage.Responses[i].FileBrowser != nil {
if singleResponse.FileBrowser != nil {
// do it in the background - the agent doesn't need the result of this directly
go HandleAgentMessagePostResponseFileBrowser(currentTask, agentMessage.Responses[i].FileBrowser, 0)
go HandleAgentMessagePostResponseFileBrowser(currentTask, singleResponse.FileBrowser, 0)
}
if agentMessage.Responses[i].Processes != nil {
go HandleAgentMessagePostResponseProcesses(currentTask, agentMessage.Responses[i].Processes, 0)
if singleResponse.Processes != nil {
go HandleAgentMessagePostResponseProcesses(currentTask, singleResponse.Processes, 0)
}
if agentMessage.Responses[i].RemovedFiles != nil {
go handleAgentMessagePostResponseRemovedFiles(currentTask, agentMessage.Responses[i].RemovedFiles)
if singleResponse.RemovedFiles != nil {
go handleAgentMessagePostResponseRemovedFiles(currentTask, singleResponse.RemovedFiles)
}
if agentMessage.Responses[i].Credentials != nil {
go handleAgentMessagePostResponseCredentials(currentTask, agentMessage.Responses[i].Credentials)
if singleResponse.Credentials != nil {
go handleAgentMessagePostResponseCredentials(currentTask, singleResponse.Credentials)
}
if agentMessage.Responses[i].Keylogs != nil {
go handleAgentMessagePostResponseKeylogs(currentTask, agentMessage.Responses[i].Keylogs)
if singleResponse.Keylogs != nil {
go handleAgentMessagePostResponseKeylogs(currentTask, singleResponse.Keylogs)
}
if agentMessage.Responses[i].Tokens != nil && agentMessage.Responses[i].CallbackTokens != nil {
if singleResponse.Tokens != nil && singleResponse.CallbackTokens != nil {
// need to make sure we process tokens _then_ process callback tokens
go handleAgentMessagePostResponseCallbackTokensAndTokens(currentTask, agentMessage.Responses[i].Tokens, agentMessage.Responses[i].CallbackTokens)
go handleAgentMessagePostResponseCallbackTokensAndTokens(currentTask, singleResponse.Tokens, singleResponse.CallbackTokens)
} else {
if agentMessage.Responses[i].Tokens != nil {
go handleAgentMessagePostResponseTokens(currentTask, agentMessage.Responses[i].Tokens)
if singleResponse.Tokens != nil {
go handleAgentMessagePostResponseTokens(currentTask, singleResponse.Tokens)
}
if agentMessage.Responses[i].CallbackTokens != nil {
go handleAgentMessagePostResponseCallbackTokens(currentTask, agentMessage.Responses[i].CallbackTokens)
if singleResponse.CallbackTokens != nil {
go handleAgentMessagePostResponseCallbackTokens(currentTask, singleResponse.CallbackTokens)
}
}
if agentMessage.Responses[i].ProcessResponse != nil {
go handleAgentMessagePostResponseProcessResponse(currentTask, agentMessage.Responses[i].ProcessResponse)
if singleResponse.ProcessResponse != nil {
go handleAgentMessagePostResponseProcessResponse(currentTask, singleResponse.ProcessResponse)
}
if agentMessage.Responses[i].Commands != nil {
go handleAgentMessagePostResponseCommands(currentTask, agentMessage.Responses[i].Commands)
if singleResponse.Commands != nil {
go handleAgentMessagePostResponseCommands(currentTask, singleResponse.Commands)
}
if agentMessage.Responses[i].Edges != nil {
go handleAgentMessagePostResponseEdges(uUIDInfo, agentMessage.Responses[i].Edges)
if singleResponse.Edges != nil {
go handleAgentMessagePostResponseEdges(uUIDInfo, singleResponse.Edges)
}
if agentMessage.Responses[i].Alerts != nil {
go handleAgentMessagePostResponseAlerts(currentTask.OperationID, uUIDInfo.CallbackID, uUIDInfo.CallbackDisplayID, agentMessage.Responses[i].Alerts)
if singleResponse.Alerts != nil {
go handleAgentMessagePostResponseAlerts(currentTask.OperationID, uUIDInfo.CallbackID, uUIDInfo.CallbackDisplayID, singleResponse.Alerts)
}
if agentMessage.Responses[i].Artifacts != nil {
if singleResponse.Artifacts != nil {
// report back artifact information so that the agent can update the specific artifacts if needed
artifactResponses := handleAgentMessagePostResponseArtifacts(currentTask, agentMessage.Responses[i].Artifacts)
artifactResponses := handleAgentMessagePostResponseArtifacts(currentTask, singleResponse.Artifacts)
mythicResponse["artifacts"] = artifactResponses
}
if agentMessage.Responses[i].Callback != nil {
go handleAgentMessagePostResponseCallback(currentTask, agentMessage.Responses[i].Callback)
if singleResponse.Callback != nil {
go handleAgentMessagePostResponseCallback(currentTask, singleResponse.Callback)
}
if agentMessage.Responses[i].Events != nil && len(*agentMessage.Responses[i].Events) > 0 {
go handleAgentMessagePostResponseEvent(currentTask, agentMessage.Responses[i].Events)
if singleResponse.Events != nil && len(*singleResponse.Events) > 0 {
go handleAgentMessagePostResponseEvent(currentTask, singleResponse.Events)
}
// this section always happens
reflectBackOtherKeys(&mythicResponse, &agentMessage.Responses[i].Other)
reflectBackOtherKeys(&mythicResponse, &singleResponse.Other)
responses = append(responses, mythicResponse)
cachedTaskData[currentTask.AgentTaskID] = currentTask
}
Expand Down