diff --git a/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java b/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java index 48af1d8b5..813d8396c 100644 --- a/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java +++ b/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java @@ -54,42 +54,48 @@ public void run() { logger.info("Starting up..."); HttpClient client = HttpClient.newHttpClient(); try { + boolean controllerInitiatedWsEnabled = Boolean.parseBoolean( + AmazonUtil.getUserDataAsMap().getOrDefault(TankConstants.KEY_CONTROLLER_INITIATED_WS_ENABLED, "false")); logger.info("Starting up: ControllerBaseUrl={}", controllerBaseUrl); - HttpRequest request = HttpRequest.newBuilder().uri(URI.create( - controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SETTINGS)) - .header("Authorization", "bearer "+token).build(); - logger.info("Starting up: making call to tank service url to get settings.xml {} {} {}", - controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT); - client.send(request, BodyHandlers.ofFile(Paths.get(TANK_AGENT_DIR, "settings.xml"))); - logger.info("got settings file..."); - // Download Support Files - request = HttpRequest.newBuilder().uri(URI.create( - controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SUPPORT)) - .header("Authorization", "bearer "+token).build(); - logger.info("Making call to tank service url to get support files {} {} {}", - controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT); - int retryCount = 0; - while (true) { - try (ZipInputStream zip = new ZipInputStream( - client.send(request, BodyHandlers.ofInputStream()).body())) { - ZipEntry entry = zip.getNextEntry(); - Path agentDirPath = Paths.get(TANK_AGENT_DIR).toAbsolutePath().normalize(); - while (entry != null) { - String filename = entry.getName(); - logger.info("Got file from controller: {}", filename); - Path targetPath = agentDirPath.resolve(filename).normalize(); - if (!targetPath.startsWith(agentDirPath)) // Protect "Zip Slip" - throw new ZipException("Bad zip entry"); - Files.write(targetPath, zip.readAllBytes()); - entry = zip.getNextEntry(); + if (!controllerInitiatedWsEnabled) { + HttpRequest request = HttpRequest.newBuilder().uri(URI.create( + controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SETTINGS)) + .header("Authorization", "bearer "+token).build(); + logger.info("Starting up: making call to tank service url to get settings.xml {} {} {}", + controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT); + client.send(request, BodyHandlers.ofFile(Paths.get(TANK_AGENT_DIR, "settings.xml"))); + logger.info("got settings file..."); + // Download Support Files + request = HttpRequest.newBuilder().uri(URI.create( + controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SUPPORT)) + .header("Authorization", "bearer "+token).build(); + logger.info("Making call to tank service url to get support files {} {} {}", + controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT); + int retryCount = 0; + while (true) { + try (ZipInputStream zip = new ZipInputStream( + client.send(request, BodyHandlers.ofInputStream()).body())) { + ZipEntry entry = zip.getNextEntry(); + Path agentDirPath = Paths.get(TANK_AGENT_DIR).toAbsolutePath().normalize(); + while (entry != null) { + String filename = entry.getName(); + logger.info("Got file from controller: {}", filename); + Path targetPath = agentDirPath.resolve(filename).normalize(); + if (!targetPath.startsWith(agentDirPath)) // Protect "Zip Slip" + throw new ZipException("Bad zip entry"); + Files.write(targetPath, zip.readAllBytes()); + entry = zip.getNextEntry(); + } + break; + } catch (EOFException | ZipException e) { + logger.error("Error unzipping support files : retryCount={} : {}", retryCount, e.getMessage()); + if (retryCount < FIBONACCI.length) { + Thread.sleep( FIBONACCI[retryCount++] * 1000 ); + } else throw e; } - break; - } catch (EOFException | ZipException e) { - logger.error("Error unzipping support files : retryCount={} : {}", retryCount, e.getMessage()); - if (retryCount < FIBONACCI.length) { - Thread.sleep( FIBONACCI[retryCount++] * 1000 ); - } else throw e; } + } else { + logger.info("Controller-initiated WS mode enabled - skipping settings/support download from controller."); } // now start the harness String controllerArg = " -http=" + controllerBaseUrl; diff --git a/agent/apiharness/pom.xml b/agent/apiharness/pom.xml index 1b16f203f..2e12e2965 100644 --- a/agent/apiharness/pom.xml +++ b/agent/apiharness/pom.xml @@ -73,5 +73,11 @@ org.openjdk.nashorn nashorn-core + + + org.java-websocket + Java-WebSocket + 1.5.6 + diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java index 7211a8467..f8ead2488 100644 --- a/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java +++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java @@ -84,7 +84,7 @@ private void updateInstanceStatus() { sendTps(tpsInfo); } - if (!isLocal) { + if (!isLocal && !APITestHarness.getInstance().isControllerInitiatedWsModeEnabled()) { setInstanceStatus(newStatus.getInstanceId(), newStatus); } APITestHarness.getInstance().checkAgentThreads(); @@ -162,7 +162,9 @@ public synchronized static void setJobStatus(JobStatus jobStatus) { stats.getMaxVirtualUsers(), stats.getCurrentNumberUsers(), status.getStartTime(), endTime); status.setUserDetails(APITestHarness.getInstance().getUserTracker().getSnapshot()); - setInstanceStatus(status.getInstanceId(), status); + if (!APITestHarness.getInstance().isControllerInitiatedWsModeEnabled()) { + setInstanceStatus(status.getInstanceId(), status); + } } catch (Exception e) { LOG.error("Error sending status to controller: {}", e.toString(), e); } diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java index fff55d082..07c597733 100644 --- a/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java +++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java @@ -105,6 +105,7 @@ public class APITestHarness { private TPSMonitor tpsMonitor; private ResultsReporter resultsReporter; private String tankHttpClientClass; + private AgentCommandWebSocketServer controllerInitiatedWsServer; private Date send = new Date(); private static final int interval = 15; // SECONDS @@ -226,6 +227,39 @@ private void initializeFromArgs(String[] args) { } } + public boolean isControllerInitiatedWsModeEnabled() { + try { + String value = AmazonUtil.getUserDataAsMap().get(TankConstants.KEY_CONTROLLER_INITIATED_WS_ENABLED); + if (StringUtils.isNotBlank(value)) { + return Boolean.parseBoolean(value); + } + } catch (Exception ignored) { + } + return tankConfig.getAgentConfig().isControllerInitiatedWsEnabled(); + } + + private boolean isControllerInitiatedWsDisableAgentHttpEnabled() { + try { + String value = AmazonUtil.getUserDataAsMap().get(TankConstants.KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP); + if (StringUtils.isNotBlank(value)) { + return Boolean.parseBoolean(value); + } + } catch (Exception ignored) { + } + return tankConfig.getAgentConfig().isControllerInitiatedWsDisableAgentHttp(); + } + + private String getControllerInitiatedWsScriptPath() { + try { + String value = AmazonUtil.getUserDataAsMap().get(TankConstants.KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH); + if (StringUtils.isNotBlank(value)) { + return value; + } + } catch (Exception ignored) { + } + return tankConfig.getAgentConfig().getControllerInitiatedWsScriptPath(); + } + private String getLocalInstanceId() { isLocal = true; try { @@ -255,7 +289,12 @@ private static void usage() { private void startHttp(String baseUrl, String token) { isLocal = false; HostInfo hostInfo = new HostInfo(); - CommandListener.startHttpServer(tankConfig.getAgentConfig().getAgentPort()); + boolean controllerInitiatedWsMode = isControllerInitiatedWsModeEnabled(); + boolean controllerInitiatedWsDisableAgentHttp = isControllerInitiatedWsDisableAgentHttpEnabled(); + + if (!controllerInitiatedWsMode || !controllerInitiatedWsDisableAgentHttp) { + CommandListener.startHttpServer(tankConfig.getAgentConfig().getAgentPort()); + } baseUrl = (baseUrl == null) ? AmazonUtil.getControllerBaseUrl() : baseUrl; token = (token == null) ? AmazonUtil.getAgentToken() : token; String instanceUrl = null; @@ -291,6 +330,42 @@ private void startHttp(String baseUrl, String token) { agentRunData.setStopBehavior(AmazonUtil.getStopBehavior()); LogUtil.getLogEvent().setJobId(agentRunData.getJobId()); + if (controllerInitiatedWsMode && controllerInitiatedWsDisableAgentHttp) { + try { + if (controllerInitiatedWsServer == null) { + controllerInitiatedWsServer = new AgentCommandWebSocketServer( + tankConfig.getAgentConfig().getAgentPort(), + instanceId, + agentRunData.getJobId(), + capacity); + controllerInitiatedWsServer.start(); + } + + String scriptPath = getControllerInitiatedWsScriptPath(); + if (StringUtils.isNotBlank(scriptPath) && new File(scriptPath).exists()) { + LOG.info(new ObjectMessage(Map.of("Message", "Controller-initiated WS mode loading script from " + scriptPath))); + TestPlanSingleton.getInstance().setTestPlans(scriptPath); + } else if (StringUtils.isNotBlank(testPlans) && AgentUtil.validateTestPlans(testPlans)) { + LOG.info(new ObjectMessage(Map.of("Message", "Controller-initiated WS mode loading script from args " + testPlans))); + TestPlanSingleton.getInstance().setTestPlans(testPlans); + } else { + LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS mode has no valid local script path configured. Awaiting controller command anyway."))); + } + + Thread thread = new Thread(new StartedChecker()); + thread.setName("StartedChecker"); + thread.setDaemon(false); + thread.start(); + return; + } catch (Exception e) { + LOG.error("Error starting controller-initiated WS mode: " + e, e); + System.exit(0); + } + } else if (controllerInitiatedWsMode) { + LOG.warn(new ObjectMessage(Map.of("Message", + "Controller-initiated WS is enabled but HTTP disable flag is false; running legacy HTTP lifecycle path"))); + } + AgentData data = new AgentData(agentRunData.getJobId(), instanceId, instanceUrl, capacity, AmazonUtil.getVMRegion(), AmazonUtil.getZone()); try { @@ -340,6 +415,15 @@ private void startHttp(String baseUrl, String token) { saveDataFile(dfRequest, token); } } + // Start WS control channel if enabled + if (tankConfig.getAgentConfig().isCommandWsEnabled()) { + String wsPath = tankConfig.getAgentConfig().getCommandWsPath(); + LOG.info(new ObjectMessage(Map.of("Message", "Starting WS control channel to " + baseUrl + wsPath))); + AgentCommandWebSocketClient wsClient = new AgentCommandWebSocketClient( + baseUrl, wsPath, token, instanceId, agentRunData.getJobId()); + wsClient.connect(); + } + Thread thread = new Thread(new StartedChecker()); thread.setName("StartedChecker"); thread.setDaemon(false); diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketClient.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketClient.java new file mode 100644 index 000000000..64e15f25d --- /dev/null +++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketClient.java @@ -0,0 +1,255 @@ +package com.intuit.tank.harness; + +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ObjectMessage; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.WebSocket; +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +public class AgentCommandWebSocketClient implements WebSocket.Listener { + + private static final Logger LOG = LogManager.getLogger(AgentCommandWebSocketClient.class); + + private static final int[] BACKOFF_SECONDS = {1, 2, 5, 10, 20, 30}; + + private final String controllerBaseUrl; + private final String wsPath; + private final String token; + private final String instanceId; + private final String jobId; + private final String agentSessionId; + private final HttpClient httpClient; + + private static final int MAX_APPLIED_COMMAND_IDS = 10_000; + + private final AtomicReference webSocketRef = new AtomicReference<>(); + private final AtomicBoolean running = new AtomicBoolean(true); + private final AtomicBoolean reconnecting = new AtomicBoolean(false); + private final Set appliedCommandIds = ConcurrentHashMap.newKeySet(); + private volatile String lastAppliedCommandId = null; + + // Buffer for accumulating partial text frames + private final StringBuilder messageBuffer = new StringBuilder(); + + public AgentCommandWebSocketClient(String controllerBaseUrl, String wsPath, String token, + String instanceId, String jobId) { + this.controllerBaseUrl = controllerBaseUrl; + this.wsPath = wsPath; + this.token = token; + this.instanceId = instanceId; + this.jobId = jobId; + this.agentSessionId = UUID.randomUUID().toString(); + this.httpClient = HttpClient.newHttpClient(); + } + + public void connect() { + Thread connectThread = new Thread(this::connectWithRetry, "WS-Connect-" + instanceId); + connectThread.setDaemon(true); + connectThread.start(); + } + + private void connectWithRetry() { + int attempt = 0; + while (running.get()) { + try { + String wsUrl = buildWsUrl(); + LOG.info(new ObjectMessage(Map.of("Message", "Connecting WS to " + wsUrl))); + + WebSocket ws = httpClient.newWebSocketBuilder() + .header("Authorization", "bearer " + token) + .buildAsync(URI.create(wsUrl), this) + .join(); + + webSocketRef.set(ws); + attempt = 0; // reset on successful connect + + // Send hello + AgentWsEnvelope hello = AgentWsEnvelope.hello(instanceId, jobId, agentSessionId, lastAppliedCommandId); + ws.sendText(hello.toJson(), true); + LOG.info(new ObjectMessage(Map.of("Message", "WS hello sent for agent " + instanceId + " job " + jobId))); + + // Block until connection closes (handled by listener callbacks) + // The listener methods will handle incoming frames + return; + + } catch (Exception e) { + int backoff = BACKOFF_SECONDS[Math.min(attempt, BACKOFF_SECONDS.length - 1)]; + // Add jitter: 0-500ms + int jitter = (int) (Math.random() * 500); + LOG.warn(new ObjectMessage(Map.of("Message", "WS connect failed (attempt " + attempt + "): " + e.getMessage() + + ", retrying in " + backoff + "s"))); + try { + Thread.sleep(backoff * 1000L + jitter); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + return; + } + attempt++; + } + } + } + + private String buildWsUrl() { + String base = controllerBaseUrl; + // Convert http(s) to ws(s) + if (base.startsWith("https://")) { + base = "wss://" + base.substring("https://".length()); + } else if (base.startsWith("http://")) { + base = "ws://" + base.substring("http://".length()); + } + // Remove trailing slash + if (base.endsWith("/")) { + base = base.substring(0, base.length() - 1); + } + return base + wsPath; + } + + @Override + public void onOpen(WebSocket webSocket) { + LOG.info(new ObjectMessage(Map.of("Message", "WS connection opened for agent " + instanceId))); + webSocket.request(1); + } + + @Override + public CompletionStage onText(WebSocket webSocket, CharSequence data, boolean last) { + messageBuffer.append(data); + if (last) { + String fullMessage = messageBuffer.toString(); + messageBuffer.setLength(0); + handleMessage(webSocket, fullMessage); + } + webSocket.request(1); + return null; + } + + @Override + public CompletionStage onPing(WebSocket webSocket, ByteBuffer message) { + webSocket.request(1); + return null; + } + + @Override + public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { + LOG.info(new ObjectMessage(Map.of("Message", "WS closed: code=" + statusCode + " reason=" + reason))); + webSocketRef.set(null); + scheduleReconnect(); + return null; + } + + @Override + public void onError(WebSocket webSocket, Throwable error) { + LOG.error(new ObjectMessage(Map.of("Message", "WS error for agent " + instanceId + ": " + error.getMessage())), error); + webSocketRef.set(null); + scheduleReconnect(); + } + + private void scheduleReconnect() { + if (running.get() && reconnecting.compareAndSet(false, true)) { + Thread reconnect = new Thread(() -> { + try { + connectWithRetry(); + } finally { + reconnecting.set(false); + } + }, "WS-Reconnect-" + instanceId); + reconnect.setDaemon(true); + reconnect.start(); + } + } + + private void handleMessage(WebSocket webSocket, String text) { + try { + AgentWsEnvelope envelope = AgentWsEnvelope.fromJson(text); + if (envelope.getType() == null) { + LOG.warn(new ObjectMessage(Map.of("Message", "Received WS frame with no type"))); + return; + } + + switch (envelope.getType()) { + case command -> handleCommand(webSocket, envelope); + case ping -> handlePing(webSocket, envelope); + case ack -> LOG.info(new ObjectMessage(Map.of("Message", "Received ack: " + envelope.getAckForType()))); + default -> LOG.warn(new ObjectMessage(Map.of("Message", "Unexpected WS frame type: " + envelope.getType()))); + } + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed to parse WS frame: " + e.getMessage()))); + } + } + + private void handleCommand(WebSocket webSocket, AgentWsEnvelope envelope) { + String commandId = envelope.getCommandId(); + String command = envelope.getCommand(); + + // Bound the dedup set to prevent unbounded growth + if (appliedCommandIds.size() > MAX_APPLIED_COMMAND_IDS) { + appliedCommandIds.clear(); + LOG.info(new ObjectMessage(Map.of("Message", "Cleared appliedCommandIds set (exceeded " + MAX_APPLIED_COMMAND_IDS + ")"))); + } + + // Deduplicate + if (commandId != null && !appliedCommandIds.add(commandId)) { + LOG.info(new ObjectMessage(Map.of("Message", "Duplicate WS command " + commandId + " ignored"))); + sendAck(webSocket, commandId, AckStatus.duplicate); + return; + } + + LOG.info(new ObjectMessage(Map.of("Message", "Received WS command: " + command + " (id=" + commandId + ")"))); + + try { + applyCommand(command); + lastAppliedCommandId = commandId; + sendAck(webSocket, commandId, AckStatus.ok); + } catch (Exception e) { + LOG.error(new ObjectMessage(Map.of("Message", "Error applying WS command " + command + ": " + e.getMessage())), e); + sendAck(webSocket, commandId, AckStatus.failed); + } + } + + private void applyCommand(String command) { + CommandListener.applyCommand(command); + } + + private void handlePing(WebSocket webSocket, AgentWsEnvelope envelope) { + try { + AgentWsEnvelope pong = AgentWsEnvelope.pong(instanceId, agentSessionId, envelope.getPingId(), lastAppliedCommandId); + webSocket.sendText(pong.toJson(), true); + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send pong: " + e.getMessage()))); + } + } + + private void sendAck(WebSocket webSocket, String commandId, AckStatus status) { + try { + AgentWsEnvelope ack = AgentWsEnvelope.ack(instanceId, "command", commandId, status); + webSocket.sendText(ack.toJson(), true); + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send ack for " + commandId + ": " + e.getMessage()))); + } + } + + public void close() { + running.set(false); + WebSocket ws = webSocketRef.get(); + if (ws != null) { + try { + AgentWsEnvelope closeEnvelope = AgentWsEnvelope.close(instanceId, "agent_shutdown", "Agent shutting down"); + ws.sendText(closeEnvelope.toJson(), true); + } catch (IOException ignored) {} + ws.sendClose(WebSocket.NORMAL_CLOSURE, "Agent shutting down"); + } + } +} diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java new file mode 100644 index 000000000..752a8a246 --- /dev/null +++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java @@ -0,0 +1,134 @@ +package com.intuit.tank.harness; + +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ObjectMessage; +import org.java_websocket.WebSocket; +import org.java_websocket.handshake.ClientHandshake; +import org.java_websocket.server.WebSocketServer; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +public class AgentCommandWebSocketServer extends WebSocketServer { + + private static final Logger LOG = LogManager.getLogger(AgentCommandWebSocketServer.class); + private static final int MAX_APPLIED_COMMAND_IDS = 10_000; + + private final String instanceId; + private final String jobId; + private final int capacity; + private final String agentSessionId; + + private final Set appliedCommandIds = ConcurrentHashMap.newKeySet(); + private volatile String lastAppliedCommandId; + + public AgentCommandWebSocketServer(int port, String instanceId, String jobId, int capacity) { + super(new InetSocketAddress(port)); + this.instanceId = instanceId; + this.jobId = jobId; + this.capacity = capacity; + this.agentSessionId = UUID.randomUUID().toString(); + } + + @Override + public void onOpen(WebSocket connection, ClientHandshake handshake) { + LOG.info(new ObjectMessage(Map.of("Message", + "Controller WS connected to agent " + instanceId + " from " + connection.getRemoteSocketAddress()))); + try { + AgentWsEnvelope hello = AgentWsEnvelope.hello(instanceId, jobId, agentSessionId, lastAppliedCommandId, capacity); + connection.send(hello.toJson()); + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send WS hello: " + e.getMessage()))); + connection.close(); + } + } + + @Override + public void onMessage(WebSocket connection, String message) { + try { + AgentWsEnvelope envelope = AgentWsEnvelope.fromJson(message); + if (envelope.getType() == null) { + LOG.warn(new ObjectMessage(Map.of("Message", "WS frame missing type"))); + return; + } + + switch (envelope.getType()) { + case command -> handleCommand(connection, envelope); + case ping -> handlePing(connection, envelope); + case close -> { + LOG.info(new ObjectMessage(Map.of("Message", "Controller closed WS: " + envelope.getReason()))); + connection.close(); + } + default -> LOG.warn(new ObjectMessage(Map.of("Message", "Unexpected WS frame type: " + envelope.getType()))); + } + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed parsing WS message: " + e.getMessage()))); + } + } + + private void handleCommand(WebSocket connection, AgentWsEnvelope envelope) { + String commandId = envelope.getCommandId(); + String command = envelope.getCommand(); + + if (appliedCommandIds.size() > MAX_APPLIED_COMMAND_IDS) { + appliedCommandIds.clear(); + } + + if (commandId != null && !appliedCommandIds.add(commandId)) { + sendAck(connection, commandId, AckStatus.duplicate, null); + return; + } + + try { + CommandListener.applyCommand(command); + lastAppliedCommandId = commandId; + sendAck(connection, commandId, AckStatus.ok, null); + } catch (Exception e) { + LOG.warn(new ObjectMessage(Map.of("Message", + "Failed to apply command " + command + " for agent " + instanceId + ": " + e.getMessage()))); + sendAck(connection, commandId, AckStatus.failed, e.getMessage()); + } + } + + private void handlePing(WebSocket connection, AgentWsEnvelope envelope) { + try { + AgentWsEnvelope pong = AgentWsEnvelope.pong(instanceId, agentSessionId, envelope.getPingId(), lastAppliedCommandId); + connection.send(pong.toJson()); + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send pong: " + e.getMessage()))); + } + } + + private void sendAck(WebSocket connection, String commandId, AckStatus status, String error) { + try { + AgentWsEnvelope ack = AgentWsEnvelope.ack(instanceId, "command", commandId, status); + ack.setError(error); + connection.send(ack.toJson()); + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send command ack: " + e.getMessage()))); + } + } + + @Override + public void onClose(WebSocket connection, int code, String reason, boolean remote) { + LOG.info(new ObjectMessage(Map.of("Message", + "Controller WS disconnected from agent " + instanceId + " code=" + code + " reason=" + reason))); + } + + @Override + public void onError(WebSocket connection, Exception ex) { + LOG.error(new ObjectMessage(Map.of("Message", "Agent WS server transport error: " + ex.getMessage())), ex); + } + + @Override + public void onStart() { + LOG.info(new ObjectMessage(Map.of("Message", "Agent WS control server started for " + instanceId))); + } +} diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java index cff991a2b..6de334a51 100644 --- a/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java +++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java @@ -76,29 +76,10 @@ private static void handleRequest(HttpExchange exchange) { try { String response = "Not Found"; String path = exchange.getRequestURI().getPath(); - if (path.equals(AgentCommand.start.getPath()) || path.equals(AgentCommand.run.getPath())) { - response = "Received command " + path + ", Starting Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); - LOG.info(LogUtil.getLogMessage("Received START command - launching test threads for job " + - APITestHarness.getInstance().getAgentRunData().getJobId())); - startTest(); - } else if (path.startsWith(AgentCommand.stop.getPath())) { - response = "Received command " + path + ", Stopping Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); - APITestHarness.getInstance().setCommand(AgentCommand.stop); - } else if (path.startsWith(AgentCommand.kill.getPath())) { - response = "Received command " + path + ", Killing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); - System.exit(0); - } else if (path.startsWith(AgentCommand.pause.getPath())) { - response = "Received command " + path + ", Pausing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); - APITestHarness.getInstance().setCommand(AgentCommand.pause); - } else if (path.startsWith(AgentCommand.pause_ramp.getPath())) { - response = "Received command " + path + ", Pausing Ramp for Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); - APITestHarness.getInstance().setCommand(AgentCommand.pause_ramp); - } else if (path.startsWith(AgentCommand.resume_ramp.getPath())) { - response = "Received command " + path + ", Resume Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); - APITestHarness.getInstance().setCommand(AgentCommand.resume_ramp); - } else if (path.startsWith(AgentCommand.status.getPath())) { - response = APITestHarness.getInstance().getStatus().toString(); - APITestHarness.getInstance().setCommand(AgentCommand.resume_ramp); + try { + response = applyHttpPathCommand(path); + } catch (UnsupportedOperationException ignored) { + response = "Not Found"; } LOG.info(LogUtil.getLogMessage(response)); @@ -118,6 +99,68 @@ private static void handleRequest(HttpExchange exchange) { } } + private static String applyHttpPathCommand(String path) { + if (path.equals(AgentCommand.start.getPath()) || path.equals(AgentCommand.run.getPath())) { + LOG.info(LogUtil.getLogMessage("Received START command - launching test threads for job " + + APITestHarness.getInstance().getAgentRunData().getJobId())); + executeCommand(AgentCommand.start); + return "Received command " + path + ", Starting Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); + } + if (path.startsWith(AgentCommand.stop.getPath())) { + executeCommand(AgentCommand.stop); + return "Received command " + path + ", Stopping Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); + } + if (path.startsWith(AgentCommand.kill.getPath())) { + executeCommand(AgentCommand.kill); + return "Received command " + path + ", Killing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); + } + if (path.startsWith(AgentCommand.pause.getPath())) { + executeCommand(AgentCommand.pause); + return "Received command " + path + ", Pausing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); + } + if (path.startsWith(AgentCommand.pause_ramp.getPath())) { + executeCommand(AgentCommand.pause_ramp); + return "Received command " + path + ", Pausing Ramp for Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); + } + if (path.startsWith(AgentCommand.resume_ramp.getPath())) { + executeCommand(AgentCommand.resume_ramp); + return "Received command " + path + ", Resume Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId(); + } + if (path.startsWith(AgentCommand.status.getPath())) { + executeCommand(AgentCommand.status); + return APITestHarness.getInstance().getStatus().toString(); + } + throw new UnsupportedOperationException("Unknown command path: " + path); + } + + public static void applyCommand(String command) { + if (command == null) { + throw new UnsupportedOperationException("Unknown command: null"); + } + switch (command) { + case "start", "run" -> executeCommand(AgentCommand.start); + case "stop" -> executeCommand(AgentCommand.stop); + case "kill" -> executeCommand(AgentCommand.kill); + case "pause" -> executeCommand(AgentCommand.pause); + case "pause_ramp" -> executeCommand(AgentCommand.pause_ramp); + case "resume_ramp", "resume" -> executeCommand(AgentCommand.resume_ramp); + case "status" -> executeCommand(AgentCommand.status); + default -> throw new UnsupportedOperationException("Unknown command: " + command); + } + } + + private static void executeCommand(AgentCommand command) { + switch (command) { + case start, run -> startTest(); + case stop -> APITestHarness.getInstance().setCommand(AgentCommand.stop); + case kill -> System.exit(0); + case pause -> APITestHarness.getInstance().setCommand(AgentCommand.pause); + case pause_ramp -> APITestHarness.getInstance().setCommand(AgentCommand.pause_ramp); + case resume_ramp, status -> APITestHarness.getInstance().setCommand(AgentCommand.resume_ramp); + default -> throw new UnsupportedOperationException("Unknown command: " + command); + } + } + public static void startTest() { Thread thread = new Thread( () -> APITestHarness.getInstance().runConcurrentTestPlans()); thread.setDaemon(true); diff --git a/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java b/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java index 96f9e513f..cc622eb9d 100644 --- a/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java +++ b/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java @@ -19,6 +19,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; +import java.util.Map; import java.util.stream.IntStream; import static org.junit.jupiter.api.Assertions.*; @@ -84,6 +85,15 @@ public void testInitializeFromMainArgs() throws IOException { } } + @Test + public void testControllerInitiatedWsModeEnabledFromUserData() { + try (MockedStatic amazonUtil = Mockito.mockStatic(AmazonUtil.class)) { + amazonUtil.when(AmazonUtil::getUserDataAsMap).thenReturn( + Map.of("controllerInitiatedWsEnabled", "true")); + assertTrue(APITestHarness.getInstance().isControllerInitiatedWsModeEnabled()); + } + } + @Test public void testUsageMain() { PrintStream originalOut = System.out; diff --git a/agent/apiharness/src/test/java/com/intuit/tank/harness/AgentCommandWebSocketClientTest.java b/agent/apiharness/src/test/java/com/intuit/tank/harness/AgentCommandWebSocketClientTest.java new file mode 100644 index 000000000..f0681421c --- /dev/null +++ b/agent/apiharness/src/test/java/com/intuit/tank/harness/AgentCommandWebSocketClientTest.java @@ -0,0 +1,45 @@ +package com.intuit.tank.harness; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class AgentCommandWebSocketClientTest { + + @Test + void testConstructorSetsFields() { + AgentCommandWebSocketClient client = new AgentCommandWebSocketClient( + "http://controller.example.com/tank", + "/v2/agent/ws/control", + "test-token", + "i-123", + "job-1" + ); + assertNotNull(client); + } + + @Test + void testConstructorWithHttpsUrl() { + AgentCommandWebSocketClient client = new AgentCommandWebSocketClient( + "https://controller.example.com/tank", + "/v2/agent/ws/control", + "test-token", + "i-456", + "job-2" + ); + assertNotNull(client); + } + + @Test + void testCloseBeforeConnect() { + AgentCommandWebSocketClient client = new AgentCommandWebSocketClient( + "http://localhost:8080/tank", + "/v2/agent/ws/control", + "token", + "i-789", + "job-3" + ); + // Should not throw + client.close(); + } +} diff --git a/agent/apiharness/src/test/java/com/intuit/tank/harness/CommandListenerTest.java b/agent/apiharness/src/test/java/com/intuit/tank/harness/CommandListenerTest.java new file mode 100644 index 000000000..aedc6bf4d --- /dev/null +++ b/agent/apiharness/src/test/java/com/intuit/tank/harness/CommandListenerTest.java @@ -0,0 +1,13 @@ +package com.intuit.tank.harness; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class CommandListenerTest { + + @Test + public void testApplyCommandUnknownThrows() { + assertThrows(UnsupportedOperationException.class, () -> CommandListener.applyCommand("not-a-real-command")); + } +} diff --git a/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsCommandSender.java b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsCommandSender.java new file mode 100644 index 000000000..18330d8d1 --- /dev/null +++ b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsCommandSender.java @@ -0,0 +1,19 @@ +package com.intuit.tank.vm.agent.messages; + +/** + * Interface for sending commands to agents via WebSocket. + * Implemented by the controller's WS handler, consumed by JobManager via CDI. + */ +public interface AgentWsCommandSender { + + /** + * Check if an agent has an active WS session. + */ + boolean hasSession(String instanceId); + + /** + * Send a command to an agent via WS and wait for ack. + * @return true if command was acked successfully within timeout + */ + boolean sendCommand(String instanceId, String jobId, String command, long ackTimeoutMillis); +} diff --git a/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java new file mode 100644 index 000000000..4844682dd --- /dev/null +++ b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java @@ -0,0 +1,212 @@ +package com.intuit.tank.vm.agent.messages; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.IOException; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonInclude(JsonInclude.Include.NON_NULL) +public class AgentWsEnvelope { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + public static final int PROTOCOL_VERSION = 1; + + public enum Type { + hello, command, ack, ping, pong, close + } + + public enum AckStatus { + ok, duplicate, failed, unsupported + } + + @JsonProperty("type") + private Type type; + + @JsonProperty("instanceId") + private String instanceId; + + @JsonProperty("sentAtMs") + private long sentAtMs; + + @JsonProperty("protocolVersion") + private int protocolVersion = PROTOCOL_VERSION; + + // hello fields + @JsonProperty("jobId") + private String jobId; + + @JsonProperty("agentSessionId") + private String agentSessionId; + + @JsonProperty("capacity") + private Integer capacity; + + // command fields + @JsonProperty("commandId") + private String commandId; + + @JsonProperty("command") + private String command; + + // ack fields + @JsonProperty("ackForType") + private String ackForType; + + @JsonProperty("ackForId") + private String ackForId; + + @JsonProperty("status") + private AckStatus status; + + @JsonProperty("error") + private String error; + + // ping/pong fields + @JsonProperty("pingId") + private String pingId; + + @JsonProperty("lastAppliedCommandId") + private String lastAppliedCommandId; + + // close fields + @JsonProperty("reasonCode") + private String reasonCode; + + @JsonProperty("reason") + private String reason; + + public AgentWsEnvelope() {} + + public Type getType() { return type; } + public void setType(Type type) { this.type = type; } + + public String getInstanceId() { return instanceId; } + public void setInstanceId(String instanceId) { this.instanceId = instanceId; } + + public long getSentAtMs() { return sentAtMs; } + public void setSentAtMs(long sentAtMs) { this.sentAtMs = sentAtMs; } + + public int getProtocolVersion() { return protocolVersion; } + public void setProtocolVersion(int protocolVersion) { this.protocolVersion = protocolVersion; } + + public String getJobId() { return jobId; } + public void setJobId(String jobId) { this.jobId = jobId; } + + public String getAgentSessionId() { return agentSessionId; } + public void setAgentSessionId(String agentSessionId) { this.agentSessionId = agentSessionId; } + + public Integer getCapacity() { return capacity; } + public void setCapacity(Integer capacity) { this.capacity = capacity; } + + public String getCommandId() { return commandId; } + public void setCommandId(String commandId) { this.commandId = commandId; } + + public String getCommand() { return command; } + public void setCommand(String command) { this.command = command; } + + public String getAckForType() { return ackForType; } + public void setAckForType(String ackForType) { this.ackForType = ackForType; } + + public String getAckForId() { return ackForId; } + public void setAckForId(String ackForId) { this.ackForId = ackForId; } + + public AckStatus getStatus() { return status; } + public void setStatus(AckStatus status) { this.status = status; } + + public String getError() { return error; } + public void setError(String error) { this.error = error; } + + public String getPingId() { return pingId; } + public void setPingId(String pingId) { this.pingId = pingId; } + + public String getLastAppliedCommandId() { return lastAppliedCommandId; } + public void setLastAppliedCommandId(String lastAppliedCommandId) { this.lastAppliedCommandId = lastAppliedCommandId; } + + public String getReasonCode() { return reasonCode; } + public void setReasonCode(String reasonCode) { this.reasonCode = reasonCode; } + + public String getReason() { return reason; } + public void setReason(String reason) { this.reason = reason; } + + public String toJson() throws IOException { + return MAPPER.writeValueAsString(this); + } + + public static AgentWsEnvelope fromJson(String json) throws IOException { + return MAPPER.readValue(json, AgentWsEnvelope.class); + } + + // Factory methods for common frame types + + public static AgentWsEnvelope hello(String instanceId, String jobId, String agentSessionId, String lastAppliedCommandId) { + return hello(instanceId, jobId, agentSessionId, lastAppliedCommandId, null); + } + + public static AgentWsEnvelope hello(String instanceId, String jobId, String agentSessionId, + String lastAppliedCommandId, Integer capacity) { + AgentWsEnvelope env = new AgentWsEnvelope(); + env.setType(Type.hello); + env.setInstanceId(instanceId); + env.setJobId(jobId); + env.setAgentSessionId(agentSessionId); + env.setLastAppliedCommandId(lastAppliedCommandId); + env.setCapacity(capacity); + env.setSentAtMs(System.currentTimeMillis()); + return env; + } + + public static AgentWsEnvelope command(String commandId, String instanceId, String jobId, String command) { + AgentWsEnvelope env = new AgentWsEnvelope(); + env.setType(Type.command); + env.setCommandId(commandId); + env.setInstanceId(instanceId); + env.setJobId(jobId); + env.setCommand(command); + env.setSentAtMs(System.currentTimeMillis()); + return env; + } + + public static AgentWsEnvelope ack(String instanceId, String ackForType, String ackForId, AckStatus status) { + AgentWsEnvelope env = new AgentWsEnvelope(); + env.setType(Type.ack); + env.setInstanceId(instanceId); + env.setAckForType(ackForType); + env.setAckForId(ackForId); + env.setStatus(status); + env.setSentAtMs(System.currentTimeMillis()); + return env; + } + + public static AgentWsEnvelope ping(String pingId) { + AgentWsEnvelope env = new AgentWsEnvelope(); + env.setType(Type.ping); + env.setPingId(pingId); + env.setSentAtMs(System.currentTimeMillis()); + return env; + } + + public static AgentWsEnvelope pong(String instanceId, String agentSessionId, String pingId, String lastAppliedCommandId) { + AgentWsEnvelope env = new AgentWsEnvelope(); + env.setType(Type.pong); + env.setInstanceId(instanceId); + env.setAgentSessionId(agentSessionId); + env.setPingId(pingId); + env.setLastAppliedCommandId(lastAppliedCommandId); + env.setSentAtMs(System.currentTimeMillis()); + return env; + } + + public static AgentWsEnvelope close(String instanceId, String reasonCode, String reason) { + AgentWsEnvelope env = new AgentWsEnvelope(); + env.setType(Type.close); + env.setInstanceId(instanceId); + env.setReasonCode(reasonCode); + env.setReason(reason); + env.setSentAtMs(System.currentTimeMillis()); + return env; + } +} diff --git a/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java b/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java index d730d284e..124749dda 100644 --- a/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java +++ b/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java @@ -58,6 +58,9 @@ public class TankConstants { public static final String KEY_JVM_ARGS = "jvmArgs"; public static final String KEY_AWS_SECRET_KEY_ID = "AWS_SECRET_KEY_ID"; public static final String KEY_AWS_SECRET_KEY = "AWS_SECRET_KEY"; + public static final String KEY_CONTROLLER_INITIATED_WS_ENABLED = "controllerInitiatedWsEnabled"; + public static final String KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP = "controllerInitiatedWsDisableAgentHttp"; + public static final String KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH = "controllerInitiatedWsScriptPath"; public static final String HTTP_CASE_SKIP = "SKIP"; public static final String HTTP_CASE_SKIPGROUP = "SKIPGROUP"; diff --git a/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java b/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java index 54efd4b68..db7072efa 100644 --- a/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java +++ b/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java @@ -64,6 +64,14 @@ public class AgentConfig implements Serializable { private static final String KEY_LOG_VARIABLES = "log-variables"; private static final String KEY_CONNECTION_TIMEOUT = "connection-timeout"; + private static final String KEY_COMMAND_WS_ENABLED = "command-ws-enabled"; + private static final String KEY_COMMAND_WS_HTTP_FALLBACK_ENABLED = "command-ws-http-fallback-enabled"; + private static final String KEY_COMMAND_WS_ACK_TIMEOUT_MILLIS = "command-ws-ack-timeout-millis"; + private static final String KEY_COMMAND_WS_PATH = "command-ws-path"; + private static final String KEY_CONTROLLER_INITIATED_WS_ENABLED = "controller-initiated-ws-enabled"; + private static final String KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP = "controller-initiated-ws-disable-agent-http"; + private static final String KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH = "controller-initiated-ws-script-path"; + private static final String KEY_REQUEST_HEADERS = "request-headers/header"; // private static final String KEY_RESULT_PROVIDERS = // "result-providers/provider"; @@ -336,4 +344,32 @@ public long getStatusReportIntervalMilis(long pollTime) { return config.getLong(KEY_POLL_TIME_MILIS, pollTime); } + public boolean isCommandWsEnabled() { + return config.getBoolean(KEY_COMMAND_WS_ENABLED, false); + } + + public boolean isCommandWsHttpFallbackEnabled() { + return config.getBoolean(KEY_COMMAND_WS_HTTP_FALLBACK_ENABLED, true); + } + + public long getCommandWsAckTimeoutMillis() { + return config.getLong(KEY_COMMAND_WS_ACK_TIMEOUT_MILLIS, 3000L); + } + + public String getCommandWsPath() { + return config.getString(KEY_COMMAND_WS_PATH, "/v2/agent/ws/control"); + } + + public boolean isControllerInitiatedWsEnabled() { + return config.getBoolean(KEY_CONTROLLER_INITIATED_WS_ENABLED, false); + } + + public boolean isControllerInitiatedWsDisableAgentHttp() { + return config.getBoolean(KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP, true); + } + + public String getControllerInitiatedWsScriptPath() { + return config.getString(KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH, "script.xml"); + } + } diff --git a/api/src/main/resources/settings.xml b/api/src/main/resources/settings.xml index b6f1105bc..63a9a7ee8 100644 --- a/api/src/main/resources/settings.xml +++ b/api/src/main/resources/settings.xml @@ -137,6 +137,15 @@ JDK HttpClient + + false + true + 3000 + /v2/agent/ws/control + false + true + script.xml + diff --git a/api/src/test/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelopeTest.java b/api/src/test/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelopeTest.java new file mode 100644 index 000000000..2f7aebaa4 --- /dev/null +++ b/api/src/test/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelopeTest.java @@ -0,0 +1,184 @@ +package com.intuit.tank.vm.agent.messages; + +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.Type; +import org.junit.jupiter.api.Test; + +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.*; + +public class AgentWsEnvelopeTest { + + @Test + public void testHelloFactory() throws IOException { + AgentWsEnvelope env = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", "cmd-99"); + + assertEquals(Type.hello, env.getType()); + assertEquals("i-123", env.getInstanceId()); + assertEquals("job-1", env.getJobId()); + assertEquals("sess-1", env.getAgentSessionId()); + assertEquals("cmd-99", env.getLastAppliedCommandId()); + assertEquals(AgentWsEnvelope.PROTOCOL_VERSION, env.getProtocolVersion()); + assertTrue(env.getSentAtMs() > 0); + } + + @Test + public void testHelloFactoryWithCapacity() throws IOException { + AgentWsEnvelope env = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", "cmd-99", 4000); + + assertEquals(Type.hello, env.getType()); + assertEquals("i-123", env.getInstanceId()); + assertEquals("job-1", env.getJobId()); + assertEquals("sess-1", env.getAgentSessionId()); + assertEquals("cmd-99", env.getLastAppliedCommandId()); + assertEquals(4000, env.getCapacity()); + + String json = env.toJson(); + AgentWsEnvelope parsed = AgentWsEnvelope.fromJson(json); + assertEquals(4000, parsed.getCapacity()); + } + + @Test + public void testCommandFactory() throws IOException { + AgentWsEnvelope env = AgentWsEnvelope.command("cmd-1", "i-123", "job-1", "start"); + + assertEquals(Type.command, env.getType()); + assertEquals("cmd-1", env.getCommandId()); + assertEquals("i-123", env.getInstanceId()); + assertEquals("job-1", env.getJobId()); + assertEquals("start", env.getCommand()); + } + + @Test + public void testAckFactory() { + AgentWsEnvelope env = AgentWsEnvelope.ack("i-123", "command", "cmd-1", AckStatus.ok); + + assertEquals(Type.ack, env.getType()); + assertEquals("i-123", env.getInstanceId()); + assertEquals("command", env.getAckForType()); + assertEquals("cmd-1", env.getAckForId()); + assertEquals(AckStatus.ok, env.getStatus()); + } + + @Test + public void testPingFactory() { + AgentWsEnvelope env = AgentWsEnvelope.ping("ping-1"); + + assertEquals(Type.ping, env.getType()); + assertEquals("ping-1", env.getPingId()); + } + + @Test + public void testPongFactory() { + AgentWsEnvelope env = AgentWsEnvelope.pong("i-123", "sess-1", "ping-1", "cmd-5"); + + assertEquals(Type.pong, env.getType()); + assertEquals("i-123", env.getInstanceId()); + assertEquals("sess-1", env.getAgentSessionId()); + assertEquals("ping-1", env.getPingId()); + assertEquals("cmd-5", env.getLastAppliedCommandId()); + } + + @Test + public void testCloseFactory() { + AgentWsEnvelope env = AgentWsEnvelope.close("i-123", "shutdown", "Agent shutting down"); + + assertEquals(Type.close, env.getType()); + assertEquals("i-123", env.getInstanceId()); + assertEquals("shutdown", env.getReasonCode()); + assertEquals("Agent shutting down", env.getReason()); + } + + @Test + public void testJsonRoundTrip() throws IOException { + AgentWsEnvelope original = AgentWsEnvelope.command("cmd-1", "i-123", "job-1", "stop"); + String json = original.toJson(); + + assertNotNull(json); + assertTrue(json.contains("\"type\":\"command\"")); + assertTrue(json.contains("\"commandId\":\"cmd-1\"")); + assertTrue(json.contains("\"command\":\"stop\"")); + + AgentWsEnvelope parsed = AgentWsEnvelope.fromJson(json); + assertEquals(Type.command, parsed.getType()); + assertEquals("cmd-1", parsed.getCommandId()); + assertEquals("i-123", parsed.getInstanceId()); + assertEquals("job-1", parsed.getJobId()); + assertEquals("stop", parsed.getCommand()); + } + + @Test + public void testJsonNullFieldsOmitted() throws IOException { + AgentWsEnvelope env = AgentWsEnvelope.ping("ping-1"); + String json = env.toJson(); + + // Null fields should not be present + assertFalse(json.contains("\"instanceId\"")); + assertFalse(json.contains("\"jobId\"")); + assertFalse(json.contains("\"commandId\"")); + } + + @Test + public void testFromJsonUnknownFieldsIgnored() throws IOException { + String json = "{\"type\":\"hello\",\"instanceId\":\"i-1\",\"jobId\":\"j-1\",\"agentSessionId\":\"s-1\",\"unknownField\":\"value\",\"sentAtMs\":1000,\"protocolVersion\":1}"; + AgentWsEnvelope env = AgentWsEnvelope.fromJson(json); + + assertEquals(Type.hello, env.getType()); + assertEquals("i-1", env.getInstanceId()); + assertEquals("j-1", env.getJobId()); + } + + @Test + public void testFromJsonInvalidType() { + String json = "{\"type\":\"bogus\",\"instanceId\":\"i-1\"}"; + assertThrows(IOException.class, () -> AgentWsEnvelope.fromJson(json)); + } + + @Test + public void testFromJsonMalformed() { + assertThrows(IOException.class, () -> AgentWsEnvelope.fromJson("not json")); + } + + @Test + public void testFromJsonEmptyObject() throws IOException { + AgentWsEnvelope env = AgentWsEnvelope.fromJson("{}"); + assertNull(env.getType()); + assertNull(env.getInstanceId()); + } + + @Test + public void testAckStatusValues() { + assertEquals(4, AckStatus.values().length); + assertNotNull(AckStatus.valueOf("ok")); + assertNotNull(AckStatus.valueOf("duplicate")); + assertNotNull(AckStatus.valueOf("failed")); + assertNotNull(AckStatus.valueOf("unsupported")); + } + + @Test + public void testTypeValues() { + assertEquals(6, Type.values().length); + assertNotNull(Type.valueOf("hello")); + assertNotNull(Type.valueOf("command")); + assertNotNull(Type.valueOf("ack")); + assertNotNull(Type.valueOf("ping")); + assertNotNull(Type.valueOf("pong")); + assertNotNull(Type.valueOf("close")); + } + + @Test + public void testAckRoundTrip() throws IOException { + AgentWsEnvelope ack = AgentWsEnvelope.ack("i-123", "command", "cmd-1", AckStatus.duplicate); + ack.setError("already applied"); + ack.setAgentSessionId("sess-1"); + + String json = ack.toJson(); + AgentWsEnvelope parsed = AgentWsEnvelope.fromJson(json); + + assertEquals(Type.ack, parsed.getType()); + assertEquals(AckStatus.duplicate, parsed.getStatus()); + assertEquals("already applied", parsed.getError()); + assertEquals("sess-1", parsed.getAgentSessionId()); + } +} diff --git a/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java b/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java index 23065e86a..6e119e915 100644 --- a/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java +++ b/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java @@ -480,4 +480,17 @@ public void testSetResultsTypeMap_1() fixture.setResultsTypeMap(resultsTypeMap); } + + @Test + public void testWsConfigDefaults() throws Exception { + AgentConfig fixture = new AgentConfig(new BasicConfigurationBuilder<>(XMLConfiguration.class).getConfiguration()); + + assertFalse(fixture.isCommandWsEnabled()); + assertTrue(fixture.isCommandWsHttpFallbackEnabled()); + assertEquals(3000L, fixture.getCommandWsAckTimeoutMillis()); + assertEquals("/v2/agent/ws/control", fixture.getCommandWsPath()); + assertFalse(fixture.isControllerInitiatedWsEnabled()); + assertTrue(fixture.isControllerInitiatedWsDisableAgentHttp()); + assertEquals("script.xml", fixture.getControllerInitiatedWsScriptPath()); + } } \ No newline at end of file diff --git a/pom.xml b/pom.xml index 24a06fca5..c5056e1a3 100644 --- a/pom.xml +++ b/pom.xml @@ -769,6 +769,11 @@ spring-webmvc ${version.spring-webmvc} + + org.springframework + spring-websocket + ${version.spring-webmvc} + org.wiremock wiremock diff --git a/rest-mvc/impl/pom.xml b/rest-mvc/impl/pom.xml index e562670d4..a7e5eeff0 100644 --- a/rest-mvc/impl/pom.xml +++ b/rest-mvc/impl/pom.xml @@ -47,6 +47,10 @@ tank-script-processor ${project.version} + + org.springframework + spring-websocket + org.apache.tomcat tomcat-coyote diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketConfig.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketConfig.java new file mode 100644 index 000000000..689c40a60 --- /dev/null +++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketConfig.java @@ -0,0 +1,34 @@ +package com.intuit.tank.rest.mvc; + +import com.intuit.tank.vm.settings.TankConfig; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; + +@Configuration +@EnableWebSocket +public class AgentCommandWebSocketConfig implements WebSocketConfigurer { + + @Bean + public AgentCommandWebSocketHandler agentCommandWebSocketHandler() { + AgentCommandWebSocketHandler handler = new AgentCommandWebSocketHandler(); + AgentWsCommandSenderHolder.set(handler); + return handler; + } + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + String path = "/v2/agent/ws/control"; + try { + path = new TankConfig().getAgentConfig().getCommandWsPath(); + } catch (Exception e) { + // Fall back to default path if config not available during Spring init + } + // Agents connect via JDK HttpClient (no Origin header), not browsers. + // Allow all origins since auth is handled via bearer token in handshake. + registry.addHandler(agentCommandWebSocketHandler(), path) + .setAllowedOriginPatterns("*"); + } +} diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java new file mode 100644 index 000000000..7bdee2ce8 --- /dev/null +++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java @@ -0,0 +1,204 @@ +package com.intuit.tank.rest.mvc; + +import com.intuit.tank.vm.agent.messages.AgentWsCommandSender; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.Type; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ObjectMessage; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.TextWebSocketHandler; + +import java.io.IOException; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +public class AgentCommandWebSocketHandler extends TextWebSocketHandler implements AgentWsCommandSender { + + private static final Logger LOG = LogManager.getLogger(AgentCommandWebSocketHandler.class); + + // instanceId -> active WS session + private final ConcurrentHashMap agentSessions = new ConcurrentHashMap<>(); + + // sessionId -> bound instanceId (one identity per session, immutable after hello) + private final ConcurrentHashMap sessionIdentity = new ConcurrentHashMap<>(); + + // instanceId -> last seen timestamp + private final ConcurrentHashMap agentLastSeen = new ConcurrentHashMap<>(); + + // commandId -> ack future (for pending ack tracking) + private final ConcurrentHashMap> pendingAcks = new ConcurrentHashMap<>(); + + @Override + public void afterConnectionEstablished(WebSocketSession session) { + LOG.info(new ObjectMessage(Map.of("Message", "WS connection established: " + session.getId()))); + } + + @Override + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { + AgentWsEnvelope envelope; + try { + envelope = AgentWsEnvelope.fromJson(message.getPayload()); + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Invalid WS frame, closing session: " + e.getMessage()))); + session.close(CloseStatus.BAD_DATA); + return; + } + + if (envelope.getType() == null) { + LOG.warn(new ObjectMessage(Map.of("Message", "WS frame missing type, closing session"))); + session.close(CloseStatus.BAD_DATA); + return; + } + + // Require hello before any other frame type + if (envelope.getType() != Type.hello && !sessionIdentity.containsKey(session.getId())) { + LOG.warn(new ObjectMessage(Map.of("Message", "WS frame received before hello, closing session"))); + session.close(CloseStatus.POLICY_VIOLATION); + return; + } + + switch (envelope.getType()) { + case hello -> handleHello(session, envelope); + case ack -> handleAck(session, envelope); + case pong -> handlePong(session, envelope); + case close -> handleClose(session, envelope); + default -> { + LOG.warn(new ObjectMessage(Map.of("Message", "Unexpected WS frame type from agent: " + envelope.getType()))); + sendAck(session, envelope.getInstanceId(), "hello", session.getId(), AckStatus.unsupported); + } + } + } + + private void handleHello(WebSocketSession session, AgentWsEnvelope envelope) throws IOException { + String instanceId = envelope.getInstanceId(); + if (instanceId == null || envelope.getJobId() == null || envelope.getAgentSessionId() == null) { + LOG.warn(new ObjectMessage(Map.of("Message", "Hello frame missing required fields, closing"))); + session.close(CloseStatus.BAD_DATA); + return; + } + + // Check if this session already has a bound identity — reject rebind + String existingIdentity = sessionIdentity.get(session.getId()); + if (existingIdentity != null && !existingIdentity.equals(instanceId)) { + LOG.warn(new ObjectMessage(Map.of("Message", "Session " + session.getId() + " already bound to " + existingIdentity + + ", rejecting rebind to " + instanceId))); + session.close(CloseStatus.POLICY_VIOLATION); + return; + } + + // Bind identity to session + sessionIdentity.put(session.getId(), instanceId); + + // Replace old session if agent reconnects from a different connection + WebSocketSession oldSession = agentSessions.put(instanceId, session); + if (oldSession != null && oldSession.isOpen() && !oldSession.getId().equals(session.getId())) { + LOG.info(new ObjectMessage(Map.of("Message", "Replacing old WS session for agent " + instanceId))); + sessionIdentity.remove(oldSession.getId()); + try { oldSession.close(CloseStatus.GOING_AWAY); } catch (IOException ignored) {} + } + + agentLastSeen.put(instanceId, System.currentTimeMillis()); + LOG.info(new ObjectMessage(Map.of("Message", "Agent " + instanceId + " registered via WS for job " + envelope.getJobId() + + " (agentSession=" + envelope.getAgentSessionId() + ")"))); + + sendAck(session, instanceId, "hello", envelope.getAgentSessionId(), AckStatus.ok); + } + + private void handleAck(WebSocketSession session, AgentWsEnvelope envelope) { + String ackForId = envelope.getAckForId(); + if (ackForId != null) { + CompletableFuture future = pendingAcks.remove(ackForId); + if (future != null) { + future.complete(envelope); + } + } + String boundId = sessionIdentity.get(session.getId()); + if (boundId != null) { + agentLastSeen.put(boundId, System.currentTimeMillis()); + } + } + + private void handlePong(WebSocketSession session, AgentWsEnvelope envelope) { + String boundId = sessionIdentity.get(session.getId()); + if (boundId != null) { + agentLastSeen.put(boundId, System.currentTimeMillis()); + } + } + + private void handleClose(WebSocketSession session, AgentWsEnvelope envelope) throws IOException { + String boundId = sessionIdentity.remove(session.getId()); + if (boundId != null) { + LOG.info(new ObjectMessage(Map.of("Message", "Agent " + boundId + " sent close: " + envelope.getReason()))); + agentSessions.remove(boundId, session); + agentLastSeen.remove(boundId); + } + session.close(CloseStatus.NORMAL); + } + + @Override + public void afterConnectionClosed(WebSocketSession session, CloseStatus status) { + String boundId = sessionIdentity.remove(session.getId()); + if (boundId != null) { + agentSessions.remove(boundId, session); + agentLastSeen.remove(boundId); + } + LOG.info(new ObjectMessage(Map.of("Message", "WS session closed: " + session.getId() + " status=" + status))); + } + + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) { + LOG.error(new ObjectMessage(Map.of("Message", "WS transport error for session " + session.getId() + ": " + exception.getMessage())), exception); + String boundId = sessionIdentity.remove(session.getId()); + if (boundId != null) { + agentSessions.remove(boundId, session); + agentLastSeen.remove(boundId); + } + } + + @Override + public boolean hasSession(String instanceId) { + WebSocketSession session = agentSessions.get(instanceId); + return session != null && session.isOpen(); + } + + @Override + public boolean sendCommand(String instanceId, String jobId, String command, long ackTimeoutMillis) { + WebSocketSession session = agentSessions.get(instanceId); + if (session == null || !session.isOpen()) { + return false; + } + + String commandId = UUID.randomUUID().toString(); + CompletableFuture ackFuture = new CompletableFuture<>(); + pendingAcks.put(commandId, ackFuture); + + try { + AgentWsEnvelope cmdEnvelope = AgentWsEnvelope.command(commandId, instanceId, jobId, command); + synchronized (session) { + session.sendMessage(new TextMessage(cmdEnvelope.toJson())); + } + LOG.info(new ObjectMessage(Map.of("Message", "Sent WS command " + command + " (id=" + commandId + ") to agent " + instanceId))); + + AgentWsEnvelope ack = ackFuture.get(ackTimeoutMillis, TimeUnit.MILLISECONDS); + return ack != null && ack.getStatus() == AckStatus.ok; + } catch (Exception e) { + LOG.warn(new ObjectMessage(Map.of("Message", "WS command " + command + " to agent " + instanceId + " failed: " + e.getMessage()))); + pendingAcks.remove(commandId); + return false; + } + } + + private void sendAck(WebSocketSession session, String instanceId, String ackForType, String ackForId, AckStatus status) throws IOException { + AgentWsEnvelope ack = AgentWsEnvelope.ack(instanceId, ackForType, ackForId, status); + synchronized (session) { + session.sendMessage(new TextMessage(ack.toJson())); + } + } +} diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderHolder.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderHolder.java new file mode 100644 index 000000000..2ad5da825 --- /dev/null +++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderHolder.java @@ -0,0 +1,20 @@ +package com.intuit.tank.rest.mvc; + +import com.intuit.tank.vm.agent.messages.AgentWsCommandSender; + +/** + * Static holder so the Spring-managed WS handler can be accessed by CDI producer. + * Set by AgentCommandWebSocketConfig at Spring init time. + */ +public class AgentWsCommandSenderHolder { + + private static volatile AgentWsCommandSender instance; + + public static void set(AgentWsCommandSender sender) { + instance = sender; + } + + public static AgentWsCommandSender get() { + return instance; + } +} diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderProducer.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderProducer.java new file mode 100644 index 000000000..e607b39ab --- /dev/null +++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderProducer.java @@ -0,0 +1,18 @@ +package com.intuit.tank.rest.mvc; + +import com.intuit.tank.vm.agent.messages.AgentWsCommandSender; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; + +/** + * CDI producer that bridges the Spring-managed AgentCommandWebSocketHandler + * into the CDI container so JobManager can inject it via Instance. + */ +@ApplicationScoped +public class AgentWsCommandSenderProducer { + + @Produces + public AgentWsCommandSender getAgentWsCommandSender() { + return AgentWsCommandSenderHolder.get(); + } +} diff --git a/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java new file mode 100644 index 000000000..2d1631098 --- /dev/null +++ b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java @@ -0,0 +1,241 @@ +package com.intuit.tank.rest.mvc; + +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.Type; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; + +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +public class AgentCommandWebSocketHandlerTest { + + private AgentCommandWebSocketHandler handler; + private WebSocketSession session; + + @BeforeEach + void setUp() { + handler = new AgentCommandWebSocketHandler(); + session = mock(WebSocketSession.class); + when(session.getId()).thenReturn("test-session-1"); + when(session.isOpen()).thenReturn(true); + } + + @Test + void testHelloRegistersSession() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + assertTrue(handler.hasSession("i-123")); + + // Verify ack was sent + verify(session).sendMessage(argThat(msg -> { + String payload = ((TextMessage) msg).getPayload(); + return payload.contains("\"type\":\"ack\"") && payload.contains("\"status\":\"ok\""); + })); + } + + @Test + void testHelloReRegisterReplacesSession() throws Exception { + AgentWsEnvelope hello1 = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello1.toJson())); + + WebSocketSession session2 = mock(WebSocketSession.class); + when(session2.getId()).thenReturn("test-session-2"); + when(session2.isOpen()).thenReturn(true); + + AgentWsEnvelope hello2 = AgentWsEnvelope.hello("i-123", "job-1", "sess-2", null); + handler.handleTextMessage(session2, new TextMessage(hello2.toJson())); + + assertTrue(handler.hasSession("i-123")); + // Old session should have been closed + verify(session).close(CloseStatus.GOING_AWAY); + } + + @Test + void testHelloMissingFieldsClosesSession() throws Exception { + // Missing jobId and agentSessionId + AgentWsEnvelope badHello = new AgentWsEnvelope(); + badHello.setType(Type.hello); + badHello.setInstanceId("i-123"); + badHello.setSentAtMs(System.currentTimeMillis()); + + handler.handleTextMessage(session, new TextMessage(badHello.toJson())); + + assertFalse(handler.hasSession("i-123")); + verify(session).close(CloseStatus.BAD_DATA); + } + + @Test + void testInvalidJsonClosesSession() throws Exception { + handler.handleTextMessage(session, new TextMessage("not valid json")); + + verify(session).close(CloseStatus.BAD_DATA); + } + + @Test + void testMissingTypeClosesSession() throws Exception { + handler.handleTextMessage(session, new TextMessage("{\"instanceId\":\"i-123\"}")); + + verify(session).close(CloseStatus.BAD_DATA); + } + + @Test + void testPongUpdatesLastSeen() throws Exception { + // Register first + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + // Send pong + AgentWsEnvelope pong = AgentWsEnvelope.pong("i-123", "sess-1", "ping-1", "cmd-5"); + handler.handleTextMessage(session, new TextMessage(pong.toJson())); + + // Session should still be active + assertTrue(handler.hasSession("i-123")); + } + + @Test + void testCloseUnregistersSession() throws Exception { + // Register + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + assertTrue(handler.hasSession("i-123")); + + // Close + AgentWsEnvelope close = AgentWsEnvelope.close("i-123", "shutdown", "done"); + handler.handleTextMessage(session, new TextMessage(close.toJson())); + + assertFalse(handler.hasSession("i-123")); + } + + @Test + void testAfterConnectionClosedCleansUp() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + assertTrue(handler.hasSession("i-123")); + + handler.afterConnectionClosed(session, CloseStatus.NORMAL); + + assertFalse(handler.hasSession("i-123")); + } + + @Test + void testHasSessionReturnsFalseForUnknown() { + assertFalse(handler.hasSession("i-unknown")); + } + + @Test + void testHasSessionReturnsFalseForClosedSession() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + when(session.isOpen()).thenReturn(false); + assertFalse(handler.hasSession("i-123")); + } + + @Test + void testSendCommandReturnsFalseWhenNoSession() { + boolean result = handler.sendCommand("i-unknown", "job-1", "start", 1000); + assertFalse(result); + } + + @Test + void testSendCommandReturnsFalseWhenSessionClosed() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + when(session.isOpen()).thenReturn(false); + + boolean result = handler.sendCommand("i-123", "job-1", "start", 1000); + assertFalse(result); + } + + @Test + void testSendCommandTimesOutWithNoAck() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + // Send command with very short timeout - no ack will come + boolean result = handler.sendCommand("i-123", "job-1", "start", 50); + assertFalse(result); + } + + @Test + void testSendCommandSucceedsWithAck() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + // Simulate the ack arriving when command is sent + doAnswer(invocation -> { + TextMessage msg = invocation.getArgument(0); + AgentWsEnvelope cmd = AgentWsEnvelope.fromJson(msg.getPayload()); + if (cmd.getType() == Type.command) { + // Simulate agent ack + AgentWsEnvelope ack = AgentWsEnvelope.ack("i-123", "command", cmd.getCommandId(), AckStatus.ok); + handler.handleTextMessage(session, new TextMessage(ack.toJson())); + } + return null; + }).when(session).sendMessage(any(TextMessage.class)); + + boolean result = handler.sendCommand("i-123", "job-1", "start", 5000); + assertTrue(result); + } + + @Test + void testTransportErrorCleansUpSession() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + assertTrue(handler.hasSession("i-123")); + + handler.handleTransportError(session, new RuntimeException("connection lost")); + + assertFalse(handler.hasSession("i-123")); + } + + @Test + void testFrameBeforeHelloRejected() throws Exception { + // Send pong without hello first + AgentWsEnvelope pong = AgentWsEnvelope.pong("i-123", "sess-1", "ping-1", null); + handler.handleTextMessage(session, new TextMessage(pong.toJson())); + + verify(session).close(CloseStatus.POLICY_VIOLATION); + } + + @Test + void testIdentityRebindRejected() throws Exception { + // Register with one instanceId + AgentWsEnvelope hello1 = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello1.toJson())); + assertTrue(handler.hasSession("i-123")); + + // Try to rebind same session to different instanceId + AgentWsEnvelope hello2 = AgentWsEnvelope.hello("i-999", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello2.toJson())); + + // Session should be closed due to policy violation + verify(session).close(CloseStatus.POLICY_VIOLATION); + // Original identity should NOT be replaced + assertFalse(handler.hasSession("i-999")); + } + + @Test + void testSameIdentityReHelloAllowed() throws Exception { + // Register + AgentWsEnvelope hello1 = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello1.toJson())); + + // Re-hello with same instanceId on same session (e.g., agent restart detection) + AgentWsEnvelope hello2 = AgentWsEnvelope.hello("i-123", "job-1", "sess-2", null); + handler.handleTextMessage(session, new TextMessage(hello2.toJson())); + + assertTrue(handler.hasSession("i-123")); + // Should NOT have been closed with policy violation + verify(session, never()).close(CloseStatus.POLICY_VIOLATION); + } +} diff --git a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java new file mode 100644 index 000000000..212c1ca52 --- /dev/null +++ b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java @@ -0,0 +1,215 @@ +package com.intuit.tank.perfManager.workLoads; + +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ObjectMessage; + +import jakarta.enterprise.context.ApplicationScoped; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.WebSocket; +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +@ApplicationScoped +public class ControllerInitiatedAgentWsClient { + + private static final Logger LOG = LogManager.getLogger(ControllerInitiatedAgentWsClient.class); + + private final HttpClient httpClient = HttpClient.newHttpClient(); + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> pendingAcks = new ConcurrentHashMap<>(); + + public Optional connect(String instanceId, String wsUrl, String token, long helloTimeoutMillis) { + try { + SessionContext existing = sessions.get(instanceId); + if (existing != null && existing.isOpen()) { + return Optional.empty(); + } + + CompletableFuture helloFuture = new CompletableFuture<>(); + Listener listener = new Listener(instanceId, helloFuture); + + WebSocket webSocket = httpClient.newWebSocketBuilder() + .header("Authorization", "bearer " + token) + .buildAsync(URI.create(wsUrl), listener) + .join(); + + SessionContext context = new SessionContext(webSocket, helloFuture); + SessionContext previous = sessions.put(instanceId, context); + if (previous != null) { + previous.close(); + } + + AgentWsEnvelope hello = helloFuture.get(helloTimeoutMillis, TimeUnit.MILLISECONDS); + if (hello != null) { + LOG.info(new ObjectMessage(Map.of("Message", + "Controller initiated WS connected to agent " + instanceId + " via " + wsUrl))); + return Optional.of(hello); + } + } catch (Exception e) { + LOG.warn(new ObjectMessage(Map.of("Message", + "Failed to connect WS to agent " + instanceId + ": " + e.getMessage()))); + } + + SessionContext failed = sessions.remove(instanceId); + if (failed != null) { + failed.close(); + } + return Optional.empty(); + } + + public boolean hasSession(String instanceId) { + SessionContext context = sessions.get(instanceId); + return context != null && context.isOpen(); + } + + public boolean sendCommand(String instanceId, String jobId, String command, long ackTimeoutMillis) { + SessionContext context = sessions.get(instanceId); + if (context == null || !context.isOpen()) { + return false; + } + + String commandId = UUID.randomUUID().toString(); + CompletableFuture ackFuture = new CompletableFuture<>(); + pendingAcks.put(commandId, ackFuture); + + try { + AgentWsEnvelope cmd = AgentWsEnvelope.command(commandId, instanceId, jobId, command); + context.send(cmd.toJson()); + + AgentWsEnvelope ack = ackFuture.get(ackTimeoutMillis, TimeUnit.MILLISECONDS); + return ack != null && ack.getStatus() == AckStatus.ok; + } catch (Exception e) { + LOG.warn(new ObjectMessage(Map.of("Message", + "WS command " + command + " to " + instanceId + " failed: " + e.getMessage()))); + return false; + } finally { + pendingAcks.remove(commandId); + } + } + + private void handleText(String instanceId, String text, CompletableFuture helloFuture) { + try { + AgentWsEnvelope envelope = AgentWsEnvelope.fromJson(text); + if (envelope.getType() == null) { + return; + } + switch (envelope.getType()) { + case hello -> helloFuture.complete(envelope); + case ack -> { + String ackForId = envelope.getAckForId(); + if (ackForId != null) { + CompletableFuture future = pendingAcks.remove(ackForId); + if (future != null) { + future.complete(envelope); + } + } + } + default -> { + // no-op for pong/close/ping in PoC client path + } + } + } catch (IOException ignored) { + } + } + + private void onClosed(String instanceId) { + SessionContext context = sessions.remove(instanceId); + if (context != null) { + context.markClosed(); + } + } + + private class Listener implements WebSocket.Listener { + private final String instanceId; + private final CompletableFuture helloFuture; + private final StringBuilder messageBuffer = new StringBuilder(); + + private Listener(String instanceId, CompletableFuture helloFuture) { + this.instanceId = instanceId; + this.helloFuture = helloFuture; + } + + @Override + public void onOpen(WebSocket webSocket) { + webSocket.request(1); + } + + @Override + public CompletionStage onText(WebSocket webSocket, CharSequence data, boolean last) { + messageBuffer.append(data); + if (last) { + String fullMessage = messageBuffer.toString(); + messageBuffer.setLength(0); + handleText(instanceId, fullMessage, helloFuture); + } + webSocket.request(1); + return null; + } + + @Override + public CompletionStage onPing(WebSocket webSocket, ByteBuffer message) { + webSocket.request(1); + return null; + } + + @Override + public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { + onClosed(instanceId); + return null; + } + + @Override + public void onError(WebSocket webSocket, Throwable error) { + LOG.warn(new ObjectMessage(Map.of("Message", + "Controller initiated WS listener error for " + instanceId + ": " + error.getMessage()))); + onClosed(instanceId); + } + } + + private static class SessionContext { + private final WebSocket webSocket; + private final CompletableFuture helloFuture; + private final AtomicBoolean closed = new AtomicBoolean(false); + + private SessionContext(WebSocket webSocket, CompletableFuture helloFuture) { + this.webSocket = webSocket; + this.helloFuture = helloFuture; + } + + private boolean isOpen() { + return !closed.get() && !webSocket.isInputClosed() && !webSocket.isOutputClosed(); + } + + private void send(String text) { + webSocket.sendText(text, true).join(); + } + + private void close() { + if (closed.compareAndSet(false, true)) { + try { + webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Closing previous session"); + } catch (Exception ignored) { + } + helloFuture.completeExceptionally(new IllegalStateException("Closed")); + } + } + + private void markClosed() { + closed.set(true); + helloFuture.completeExceptionally(new IllegalStateException("Closed")); + } + } +} diff --git a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java index 3e57328ea..9f534e855 100644 --- a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java +++ b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java @@ -57,6 +57,7 @@ import com.intuit.tank.vm.agent.messages.DataFileRequest; import com.intuit.tank.vm.agent.messages.StandaloneAgentRequest; import com.intuit.tank.vm.perfManager.StandaloneAgentTracker; +import com.intuit.tank.vm.settings.AgentConfig; import com.intuit.tank.vm.settings.TankConfig; import com.intuit.tank.vm.vmManager.JobRequest; import com.intuit.tank.vm.vmManager.JobVmCalculator; @@ -93,6 +94,12 @@ public class JobManager implements Serializable { @Inject private TankConfig tankConfig; + @Inject + private jakarta.enterprise.inject.Instance wsCommandSenderInstance; + + @Inject + private jakarta.enterprise.inject.Instance controllerInitiatedWsClientInstance; + /** * @param id * @throws Exception @@ -214,17 +221,65 @@ private void startTest(final JobInfo info) { LOG.info(new ObjectMessage(Map.of("Message", "Start agents command received - Sending start commands for job " + jobId + " asynchronously to following agents: " + info.agentData.stream().collect(Collectors.toMap(AgentData::getInstanceId, AgentData::getInstanceUrl))))); } - LOG.info(new ObjectMessage(Map.of("Message", "Sending START commands to " + info.agentData.size() + + LOG.info(new ObjectMessage(Map.of("Message", "Sending START commands to " + info.agentData.size() + " agents for job " + jobId))); + + AgentConfig agentConfig = tankConfig != null ? tankConfig.getAgentConfig() : null; + boolean wsEnabled = agentConfig != null && agentConfig.isCommandWsEnabled(); + boolean controllerInitiatedWsEnabled = agentConfig != null && agentConfig.isControllerInitiatedWsEnabled(); + boolean httpFallback = agentConfig == null || agentConfig.isCommandWsHttpFallbackEnabled(); + long ackTimeout = agentConfig != null ? agentConfig.getCommandWsAckTimeoutMillis() : 3000L; + com.intuit.tank.vm.agent.messages.AgentWsCommandSender wsSender = getWsCommandSender(); + ControllerInitiatedAgentWsClient controllerWsClient = getControllerInitiatedWsClient(); + info.agentData.parallelStream() - .map(agentData -> { + .forEach(agentData -> { + String instanceId = agentData.getInstanceId(); + + if (controllerInitiatedWsEnabled && controllerWsClient != null + && controllerWsClient.hasSession(instanceId)) { + boolean acked = controllerWsClient.sendCommand(instanceId, jobId, AgentCommand.start.name(), ackTimeout); + if (acked) { + LOG.info(new ObjectMessage(Map.of("Message", + "Controller-initiated WS START command to agent " + instanceId + " was SUCCESSFUL for job " + jobId))); + return; + } + LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS START command to agent " + + instanceId + " failed, " + (httpFallback ? "falling back to HTTP" : "no fallback")))); + if (!httpFallback) { + return; + } + } else if (controllerInitiatedWsEnabled && !httpFallback) { + LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS enabled but no session for agent " + + instanceId + " and fallback disabled, skipping"))); + return; + } + + // Try WS first if enabled + if (wsEnabled) { + if (wsSender != null && wsSender.hasSession(instanceId)) { + boolean acked = wsSender.sendCommand(instanceId, jobId, AgentCommand.start.name(), ackTimeout); + if (acked) { + LOG.info(new ObjectMessage(Map.of("Message", "WS START command to agent " + instanceId + " was SUCCESSFUL for job " + jobId))); + return; + } + LOG.warn(new ObjectMessage(Map.of("Message", "WS START command to agent " + instanceId + " failed, " + + (httpFallback ? "falling back to HTTP" : "no fallback")))); + if (!httpFallback) { + return; + } + } else if (!httpFallback) { + LOG.warn(new ObjectMessage(Map.of("Message", "WS enabled but " + + (wsSender == null ? "sender unavailable" : "no session") + + " for agent " + instanceId + " and fallback disabled, skipping"))); + return; + } + } + + // HTTP fallback (or WS not enabled) String url = agentData.getInstanceUrl() + AgentCommand.start.getPath(); LOG.info(new ObjectMessage(Map.of("Message", "Sending command to url " + url))); - return url; - }) - .map(URI::create) - .map(uri -> sendCommand(uri, MAX_RETRIES)) - .forEach(future -> { + CompletableFuture future = sendCommand(URI.create(url), MAX_RETRIES); HttpResponse response = (HttpResponse) future.join(); if (response != null && Set.of(HttpStatus.SC_OK, HttpStatus.SC_ACCEPTED).contains(response.statusCode())) { LOG.info(new ObjectMessage(Map.of( @@ -246,16 +301,110 @@ private CloudVmStatus createFailureStatus(AgentData data) { /** * Convert the List of instanceIds to instanceUrls to instanceUris and pass it to sendCommand(List, retry) to send asynchttp commands. + * Uses WS transport when available and enabled, falls back to HTTP. * @param instanceIds * @param cmd * @return Array of CompletableFuture that will probably never be looked at. */ public List> sendCommand(List instanceIds, AgentCommand cmd) { + AgentConfig agentConfig = tankConfig != null ? tankConfig.getAgentConfig() : null; + boolean wsEnabled = agentConfig != null && agentConfig.isCommandWsEnabled(); + boolean controllerInitiatedWsEnabled = agentConfig != null && agentConfig.isControllerInitiatedWsEnabled(); + boolean httpFallback = agentConfig == null || agentConfig.isCommandWsHttpFallbackEnabled(); + long ackTimeout = agentConfig != null ? agentConfig.getCommandWsAckTimeoutMillis() : 3000L; + + com.intuit.tank.vm.agent.messages.AgentWsCommandSender wsSender = getWsCommandSender(); + ControllerInitiatedAgentWsClient controllerWsClient = getControllerInitiatedWsClient(); + + // Resolve instanceId -> jobId mapping for WS commands + Map instanceJobMap = new HashMap<>(); + if ((wsEnabled && wsSender != null) || controllerInitiatedWsEnabled) { + for (String instanceId : instanceIds) { + for (JobInfo info : jobInfoMapLocalCache.values()) { + for (AgentData data : info.agentData) { + if (instanceId.equals(data.getInstanceId())) { + instanceJobMap.put(instanceId, info.jobRequest.getId()); + } + } + } + } + } + List instanceUrls = getInstanceUrl(instanceIds); + + // Build a parallel list of instanceId -> instanceUrl for fallback + List orderedInstanceIds = new ArrayList<>(instanceIds); + return instanceUrls.parallelStream() - .map(instanceUrl -> instanceUrl + cmd.getPath()) - .map(URI::create) - .map(uri -> sendCommand(uri, 0)) + .map(instanceUrl -> { + // Find matching instanceId for this URL + String matchedInstanceId = null; + for (int i = 0; i < orderedInstanceIds.size(); i++) { + String candidateId = orderedInstanceIds.get(i); + // Match by checking if this URL belongs to this instanceId + for (JobInfo info : jobInfoMapLocalCache.values()) { + for (AgentData data : info.agentData) { + if (candidateId.equals(data.getInstanceId()) && instanceUrl.equals(data.getInstanceUrl())) { + matchedInstanceId = candidateId; + } + } + } + } + + // Try WS first if enabled + if (controllerInitiatedWsEnabled && controllerWsClient != null && matchedInstanceId != null) { + String jobId = instanceJobMap.get(matchedInstanceId); + if (jobId != null && controllerWsClient.hasSession(matchedInstanceId)) { + boolean acked = controllerWsClient.sendCommand(matchedInstanceId, jobId, cmd.name(), ackTimeout); + if (acked) { + LOG.info(new ObjectMessage(Map.of("Message", "Controller-initiated WS command " + cmd + " to agent " + matchedInstanceId + " succeeded"))); + return CompletableFuture.completedFuture(null); + } + LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS command " + cmd + " to agent " + matchedInstanceId + " failed, " + + (httpFallback ? "falling back to HTTP" : "no fallback")))); + if (!httpFallback) { + return CompletableFuture.completedFuture(null); + } + } else if (!httpFallback) { + LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS enabled but no session for agent " + + matchedInstanceId + " and fallback disabled, skipping"))); + return CompletableFuture.completedFuture(null); + } + } else if (controllerInitiatedWsEnabled && !httpFallback) { + LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS enabled but no matched instanceId and fallback disabled, skipping"))); + return CompletableFuture.completedFuture(null); + } + + if (wsEnabled) { + if (wsSender != null && matchedInstanceId != null) { + String jobId = instanceJobMap.get(matchedInstanceId); + if (jobId != null && wsSender.hasSession(matchedInstanceId)) { + boolean acked = wsSender.sendCommand(matchedInstanceId, jobId, cmd.name(), ackTimeout); + if (acked) { + LOG.info(new ObjectMessage(Map.of("Message", "WS command " + cmd + " to agent " + matchedInstanceId + " succeeded"))); + return CompletableFuture.completedFuture(null); + } + LOG.warn(new ObjectMessage(Map.of("Message", "WS command " + cmd + " to agent " + matchedInstanceId + " failed, " + + (httpFallback ? "falling back to HTTP" : "no fallback")))); + if (!httpFallback) { + return CompletableFuture.completedFuture(null); + } + } else if (!httpFallback) { + LOG.warn(new ObjectMessage(Map.of("Message", "WS enabled but no session for agent " + matchedInstanceId + " and fallback disabled, skipping"))); + return CompletableFuture.completedFuture(null); + } + } else if (!httpFallback) { + LOG.warn(new ObjectMessage(Map.of("Message", "WS enabled but " + + (wsSender == null ? "sender unavailable" : "no matched instanceId") + + " and fallback disabled, skipping"))); + return CompletableFuture.completedFuture(null); + } + } + + // HTTP fallback (or WS not enabled) + URI uri = URI.create(instanceUrl + cmd.getPath()); + return sendCommand(uri, 0); + }) .collect(Collectors.toList()); } @@ -364,6 +513,20 @@ private int getNumberOfLines(Integer dataFileId) { return ret; } + private com.intuit.tank.vm.agent.messages.AgentWsCommandSender getWsCommandSender() { + if (wsCommandSenderInstance != null && wsCommandSenderInstance.isResolvable()) { + return wsCommandSenderInstance.get(); + } + return null; + } + + private ControllerInitiatedAgentWsClient getControllerInitiatedWsClient() { + if (controllerInitiatedWsClientInstance != null && controllerInitiatedWsClientInstance.isResolvable()) { + return controllerInitiatedWsClientInstance.get(); + } + return null; + } + public void startAgents(String jobId){ LOG.info(new ObjectMessage(Map.of("Message","Sending start agents command to start test for job " + jobId))); if(!jobInfoMapLocalCache.get(jobId).isStarted()){ diff --git a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java index f30f33460..edcfabb80 100644 --- a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java +++ b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java @@ -17,6 +17,9 @@ import jakarta.inject.Inject; import com.amazonaws.xray.AWSXRay; +import com.intuit.tank.perfManager.workLoads.ControllerInitiatedAgentWsClient; +import com.intuit.tank.perfManager.workLoads.JobManager; +import com.intuit.tank.vm.settings.TankConfig; import com.intuit.tank.vm.vmManager.*; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -42,6 +45,15 @@ public class VmMessageProcessorImpl implements VmMessageProcessor { @Inject private VMTracker vmTracker; + @Inject + private ControllerInitiatedAgentWsClient controllerInitiatedAgentWsClient; + + @Inject + private JobManager jobManager; + + @Inject + private TankConfig tankConfig; + /** * @param messageObject */ @@ -52,7 +64,8 @@ public void handleVMRequest(VMRequest messageObject) { CreateInstance instance = new CreateInstance((VMInstanceRequest) messageObject, vmTracker); Objects.requireNonNull(AWSXRay.getGlobalRecorder().getTraceEntity()).run(instance); } else if (messageObject instanceof VMJobRequest) { - JobRequest instance = new JobRequest((VMJobRequest) messageObject, vmTracker); + JobRequest instance = new JobRequest((VMJobRequest) messageObject, vmTracker, + controllerInitiatedAgentWsClient, jobManager, tankConfig); Objects.requireNonNull(AWSXRay.getGlobalRecorder().getTraceEntity()).run(instance); } else if (messageObject instanceof VMKillRequest) { logger.debug("vmManager received VMKillRequest"); diff --git a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java index e8a189d6f..2eb04665d 100644 --- a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java +++ b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java @@ -15,11 +15,13 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import com.intuit.tank.vm.api.enumerated.IncrementStrategy; import com.intuit.tank.logging.ControllerLoggingConfig; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.commons.lang3.StringUtils; import com.intuit.tank.vm.vmManager.VMTracker; import com.intuit.tank.vm.vmManager.models.CloudVmStatus; @@ -29,6 +31,12 @@ import com.intuit.tank.vm.api.enumerated.JobStatus; import com.intuit.tank.vm.api.enumerated.VMImageType; import com.intuit.tank.vm.api.enumerated.VMProvider; +import com.intuit.tank.vm.api.enumerated.VMRegion; +import com.intuit.tank.vm.agent.messages.AgentData; +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; +import com.intuit.tank.vm.settings.TankConfig; +import com.intuit.tank.perfManager.workLoads.ControllerInitiatedAgentWsClient; +import com.intuit.tank.perfManager.workLoads.JobManager; import com.intuit.tank.vm.vmManager.JobVmCalculator; import com.intuit.tank.vm.vmManager.VMInformation; import com.intuit.tank.vm.vmManager.VMInstanceRequest; @@ -43,10 +51,24 @@ public class JobRequest implements Runnable { private VMJobRequest request = null; private VMTracker vmTracker; + private ControllerInitiatedAgentWsClient controllerInitiatedAgentWsClient; + private JobManager jobManager; + private TankConfig tankConfig; public JobRequest(VMJobRequest request, VMTracker tracker) { + this(request, tracker, null, null, new TankConfig()); + } + + public JobRequest(VMJobRequest request, + VMTracker tracker, + ControllerInitiatedAgentWsClient controllerInitiatedAgentWsClient, + JobManager jobManager, + TankConfig tankConfig) { this.request = request; this.vmTracker = tracker; + this.controllerInitiatedAgentWsClient = controllerInitiatedAgentWsClient; + this.jobManager = jobManager; + this.tankConfig = tankConfig; } @Override @@ -95,6 +117,7 @@ private void persistInstances(VMInstanceRequest instanceRequest, List hello = Optional.empty(); + int attempts = 0; + while (hello.isEmpty() && attempts < 8) { + attempts++; + hello = controllerInitiatedAgentWsClient.connect( + info.getInstanceId(), + wsUrl, + tankConfig.getAgentConfig().getAgentToken(), + 10000L); + if (hello.isPresent()) { + break; + } + try { + Thread.sleep(5000L); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + break; + } + } + + if (hello.isEmpty()) { + logger.warn("Controller-initiated WS hello not received for {}", info.getInstanceId()); + return; + } + + int capacity = hello.get().getCapacity() != null ? hello.get().getCapacity() : instanceRequest.getNumUsersPerAgent(); + String instanceUrl = buildHttpUrl(instanceRequest, info); + AgentData agentData = new AgentData(instanceRequest.getJobId(), info.getInstanceId(), instanceUrl, + capacity, instanceRequest.getRegion(), "unknown"); + jobManager.registerAgentForJob(agentData); + } + + private String buildWsUrl(VMInstanceRequest instanceRequest, VMInformation info) { + String host = selectHostForRegion(instanceRequest.getRegion(), info); + if (StringUtils.isBlank(host)) { + return null; + } + String wsPath = tankConfig.getAgentConfig().getCommandWsPath(); + if (!wsPath.startsWith("/")) { + wsPath = "/" + wsPath; + } + return "ws://" + host + ":" + tankConfig.getAgentConfig().getAgentPort() + wsPath; + } + + private String buildHttpUrl(VMInstanceRequest instanceRequest, VMInformation info) { + String host = selectHostForRegion(instanceRequest.getRegion(), info); + if (StringUtils.isBlank(host)) { + host = "localhost"; + } + return "http://" + host + ":" + tankConfig.getAgentConfig().getAgentPort(); + } + + private String selectHostForRegion(VMRegion region, VMInformation info) { + if (region == VMRegion.US_EAST || region == VMRegion.US_EAST_2) { + return firstNonBlank(info.getPrivateIp(), info.getPrivateDNS(), info.getPublicIp(), info.getPublicDNS()); + } + return firstNonBlank(info.getPublicDNS(), info.getPublicIp(), info.getPrivateIp(), info.getPrivateDNS()); + } + + private String firstNonBlank(String... values) { + for (String value : values) { + if (StringUtils.isNotBlank(value)) { + return value; + } + } + return null; + } + /** * @param req * @param info diff --git a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java index ad226044e..198140a96 100644 --- a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java +++ b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java @@ -184,6 +184,12 @@ public List create(VMRequest request) { instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_URL, config.getControllerBase()); instanceRequest.addUserData(TankConstants.KEY_AGENT_TOKEN, config.getAgentConfig().getAgentToken()); instanceRequest.addUserData(TankConstants.KEY_NUM_USERS_PER_AGENT, Integer.toString(instanceRequest.getNumUsersPerAgent())); + instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_INITIATED_WS_ENABLED, + Boolean.toString(config.getAgentConfig().isControllerInitiatedWsEnabled())); + instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP, + Boolean.toString(config.getAgentConfig().isControllerInitiatedWsDisableAgentHttp())); + instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH, + config.getAgentConfig().getControllerInitiatedWsScriptPath()); if (instanceRequest.isUseEips()) { instanceRequest.addUserData(TankConstants.KEY_USING_BIND_EIP, Boolean.TRUE.toString()); diff --git a/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java new file mode 100644 index 000000000..35fb1604f --- /dev/null +++ b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java @@ -0,0 +1,20 @@ +package com.intuit.tank.perfManager.workLoads; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class ControllerInitiatedAgentWsClientTest { + + @Test + public void testHasSessionFalseByDefault() { + ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient(); + assertFalse(client.hasSession("i-missing")); + } + + @Test + public void testSendCommandWithoutSessionReturnsFalse() { + ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient(); + assertFalse(client.sendCommand("i-missing", "job-1", "start", 1000L)); + } +} diff --git a/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java index 6dd4db42c..3e80430bd 100644 --- a/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java +++ b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java @@ -14,19 +14,32 @@ */ import java.util.Collections; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.FutureTask; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import jakarta.enterprise.inject.Instance; import org.junit.jupiter.api.*; import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; import com.intuit.tank.vm.agent.messages.AgentData; import com.intuit.tank.vm.agent.messages.AgentTestStartData; +import com.intuit.tank.vm.agent.messages.AgentWsCommandSender; import com.intuit.tank.vm.api.enumerated.VMRegion; import com.intuit.tank.vm.api.enumerated.AgentCommand; +import com.intuit.tank.vm.api.enumerated.IncrementStrategy; +import com.intuit.tank.vm.settings.AgentConfig; +import com.intuit.tank.vm.settings.TankConfig; +import com.intuit.tank.vm.vmManager.JobRequest; /** * The class JobManagerTest contains tests for the class {@link JobManager}. @@ -101,4 +114,133 @@ public void testGetInstanceUrl() { assertTrue(result.isEmpty()); } + + /** + * When WS sender is not injected (null), sendCommand should use HTTP path only. + * Uses empty string instanceId (same pattern as existing tests) to avoid + * hitting findAgent which requires tankConfig. + */ + @Test + public void testSendCommandWithNoWsSender() { + JobManager fixture = new JobManager(); + // wsCommandSenderInstance is null (not injected) — should fall through to HTTP + // Empty string is filtered out by StringUtils.isNotEmpty in getInstanceUrl + List instanceIds = Collections.singletonList(""); + AgentCommand cmd = AgentCommand.stop; + + // Should not throw, should return empty (no resolvable URL) + List> result = fixture.sendCommand(instanceIds, cmd); + assertEquals(0, result.size()); + } + + /** + * When WS sender is not injected, empty instanceId list should return empty results. + */ + @Test + public void testSendCommandEmptyListWithNoWsSender() { + JobManager fixture = new JobManager(); + List> result = fixture.sendCommand(Collections.emptyList(), AgentCommand.start); + assertEquals(0, result.size()); + } + + @Test + public void testSendCommandUsesControllerInitiatedWsWhenEnabled() throws Exception { + JobManager fixture = new JobManager(); + String jobId = "job-1"; + String instanceId = "i-12345"; + String instanceUrl = "http://127.0.0.1:8090"; + seedJobInfoCache(fixture, jobId, instanceId, instanceUrl); + + AgentConfig agentConfig = mock(AgentConfig.class); + when(agentConfig.isCommandWsEnabled()).thenReturn(false); + when(agentConfig.isControllerInitiatedWsEnabled()).thenReturn(true); + when(agentConfig.isCommandWsHttpFallbackEnabled()).thenReturn(false); + when(agentConfig.getCommandWsAckTimeoutMillis()).thenReturn(1000L); + + TankConfig tankConfig = mock(TankConfig.class); + when(tankConfig.getAgentConfig()).thenReturn(agentConfig); + setField(fixture, "tankConfig", tankConfig); + + ControllerInitiatedAgentWsClient wsClient = mock(ControllerInitiatedAgentWsClient.class); + when(wsClient.hasSession(instanceId)).thenReturn(true); + when(wsClient.sendCommand(instanceId, jobId, AgentCommand.stop.name(), 1000L)).thenReturn(true); + + @SuppressWarnings("unchecked") + Instance wsClientInstance = mock(Instance.class); + when(wsClientInstance.isResolvable()).thenReturn(true); + when(wsClientInstance.get()).thenReturn(wsClient); + setField(fixture, "controllerInitiatedWsClientInstance", wsClientInstance); + + List> result = fixture.sendCommand(Collections.singletonList(instanceId), AgentCommand.stop); + + assertEquals(1, result.size()); + assertTrue(result.get(0).isDone()); + assertNull(result.get(0).join()); + verify(wsClient).sendCommand(eq(instanceId), eq(jobId), eq(AgentCommand.stop.name()), eq(1000L)); + } + + @Test + public void testSendCommandSkipsWhenNoControllerWsSessionAndNoFallback() throws Exception { + JobManager fixture = new JobManager(); + String jobId = "job-2"; + String instanceId = "i-54321"; + String instanceUrl = "http://127.0.0.1:8090"; + seedJobInfoCache(fixture, jobId, instanceId, instanceUrl); + + AgentConfig agentConfig = mock(AgentConfig.class); + when(agentConfig.isCommandWsEnabled()).thenReturn(false); + when(agentConfig.isControllerInitiatedWsEnabled()).thenReturn(true); + when(agentConfig.isCommandWsHttpFallbackEnabled()).thenReturn(false); + when(agentConfig.getCommandWsAckTimeoutMillis()).thenReturn(1000L); + + TankConfig tankConfig = mock(TankConfig.class); + when(tankConfig.getAgentConfig()).thenReturn(agentConfig); + setField(fixture, "tankConfig", tankConfig); + + ControllerInitiatedAgentWsClient wsClient = mock(ControllerInitiatedAgentWsClient.class); + when(wsClient.hasSession(instanceId)).thenReturn(false); + + @SuppressWarnings("unchecked") + Instance wsClientInstance = mock(Instance.class); + when(wsClientInstance.isResolvable()).thenReturn(true); + when(wsClientInstance.get()).thenReturn(wsClient); + setField(fixture, "controllerInitiatedWsClientInstance", wsClientInstance); + + List> result = fixture.sendCommand(Collections.singletonList(instanceId), AgentCommand.stop); + + assertEquals(1, result.size()); + assertTrue(result.get(0).isDone()); + assertNull(result.get(0).join()); + verify(wsClient, never()).sendCommand(anyString(), anyString(), anyString(), anyLong()); + } + + private void seedJobInfoCache(JobManager fixture, String jobId, String instanceId, String instanceUrl) throws Exception { + JobRequest jobRequest = mock(JobRequest.class); + when(jobRequest.getIncrementStrategy()).thenReturn(IncrementStrategy.increasing); + when(jobRequest.getRegions()).thenReturn(Collections.emptySet()); + when(jobRequest.getScriptsXmlUrl()).thenReturn("script.xml"); + when(jobRequest.getId()).thenReturn(jobId); + when(jobRequest.getNumUsersPerAgent()).thenReturn(1); + + Class jobInfoClass = Class.forName("com.intuit.tank.perfManager.workLoads.JobManager$JobInfo"); + Constructor ctor = jobInfoClass.getDeclaredConstructor(JobRequest.class); + ctor.setAccessible(true); + Object jobInfo = ctor.newInstance(jobRequest); + + Field agentDataField = jobInfoClass.getDeclaredField("agentData"); + agentDataField.setAccessible(true); + @SuppressWarnings("unchecked") + Set agentData = (Set) agentDataField.get(jobInfo); + agentData.add(new AgentData(jobId, instanceId, instanceUrl, 1, VMRegion.US_EAST_2, "zone-a")); + + Map cache = new HashMap<>(); + cache.put(jobId, jobInfo); + setField(fixture, "jobInfoMapLocalCache", cache); + } + + private void setField(JobManager fixture, String fieldName, Object value) throws Exception { + Field field = JobManager.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(fixture, value); + } } \ No newline at end of file diff --git a/tank_vmManager/src/test/java/com/intuit/tank/vmManager/environment/JobRequestControllerInitiatedWsTest.java b/tank_vmManager/src/test/java/com/intuit/tank/vmManager/environment/JobRequestControllerInitiatedWsTest.java new file mode 100644 index 000000000..5dcf33bcc --- /dev/null +++ b/tank_vmManager/src/test/java/com/intuit/tank/vmManager/environment/JobRequestControllerInitiatedWsTest.java @@ -0,0 +1,54 @@ +package com.intuit.tank.vmManager.environment; + +import com.intuit.tank.vm.api.enumerated.IncrementStrategy; +import com.intuit.tank.vm.api.enumerated.VMRegion; +import com.intuit.tank.vm.vmManager.VMInformation; +import com.intuit.tank.vm.vmManager.VMJobRequest; +import com.intuit.tank.vmManager.VMTrackerImpl; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class JobRequestControllerInitiatedWsTest { + + @Test + public void testSelectHostForEastPrefersPrivateAddress() throws Exception { + JobRequest jobRequest = createFixture(); + VMInformation vmInfo = buildVmInfo(); + + Method method = JobRequest.class.getDeclaredMethod("selectHostForRegion", VMRegion.class, VMInformation.class); + method.setAccessible(true); + String host = (String) method.invoke(jobRequest, VMRegion.US_EAST_2, vmInfo); + + assertEquals("10.0.0.5", host); + } + + @Test + public void testSelectHostForWestPrefersPublicAddress() throws Exception { + JobRequest jobRequest = createFixture(); + VMInformation vmInfo = buildVmInfo(); + + Method method = JobRequest.class.getDeclaredMethod("selectHostForRegion", VMRegion.class, VMInformation.class); + method.setAccessible(true); + String host = (String) method.invoke(jobRequest, VMRegion.US_WEST_2, vmInfo); + + assertEquals("ec2-public.example.com", host); + } + + private JobRequest createFixture() { + VMJobRequest vmJobRequest = new VMJobRequest("job-1", "none", "STANDARD", 1, + VMRegion.US_EAST_2, IncrementStrategy.increasing, "END_OF_SCRIPT", "m.xlarge", 4000); + return new JobRequest(vmJobRequest, new VMTrackerImpl()); + } + + private VMInformation buildVmInfo() { + VMInformation info = new VMInformation(); + info.setPrivateIp("10.0.0.5"); + info.setPrivateDNS("ip-10-0-0-5.ec2.internal"); + info.setPublicIp("54.12.34.56"); + info.setPublicDNS("ec2-public.example.com"); + return info; + } +}