diff --git a/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java b/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java
index 48af1d8b5..813d8396c 100644
--- a/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java
+++ b/agent/agent_startup/src/main/java/com/intuit/tank/agent/AgentStartup.java
@@ -54,42 +54,48 @@ public void run() {
logger.info("Starting up...");
HttpClient client = HttpClient.newHttpClient();
try {
+ boolean controllerInitiatedWsEnabled = Boolean.parseBoolean(
+ AmazonUtil.getUserDataAsMap().getOrDefault(TankConstants.KEY_CONTROLLER_INITIATED_WS_ENABLED, "false"));
logger.info("Starting up: ControllerBaseUrl={}", controllerBaseUrl);
- HttpRequest request = HttpRequest.newBuilder().uri(URI.create(
- controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SETTINGS))
- .header("Authorization", "bearer "+token).build();
- logger.info("Starting up: making call to tank service url to get settings.xml {} {} {}",
- controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT);
- client.send(request, BodyHandlers.ofFile(Paths.get(TANK_AGENT_DIR, "settings.xml")));
- logger.info("got settings file...");
- // Download Support Files
- request = HttpRequest.newBuilder().uri(URI.create(
- controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SUPPORT))
- .header("Authorization", "bearer "+token).build();
- logger.info("Making call to tank service url to get support files {} {} {}",
- controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT);
- int retryCount = 0;
- while (true) {
- try (ZipInputStream zip = new ZipInputStream(
- client.send(request, BodyHandlers.ofInputStream()).body())) {
- ZipEntry entry = zip.getNextEntry();
- Path agentDirPath = Paths.get(TANK_AGENT_DIR).toAbsolutePath().normalize();
- while (entry != null) {
- String filename = entry.getName();
- logger.info("Got file from controller: {}", filename);
- Path targetPath = agentDirPath.resolve(filename).normalize();
- if (!targetPath.startsWith(agentDirPath)) // Protect "Zip Slip"
- throw new ZipException("Bad zip entry");
- Files.write(targetPath, zip.readAllBytes());
- entry = zip.getNextEntry();
+ if (!controllerInitiatedWsEnabled) {
+ HttpRequest request = HttpRequest.newBuilder().uri(URI.create(
+ controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SETTINGS))
+ .header("Authorization", "bearer "+token).build();
+ logger.info("Starting up: making call to tank service url to get settings.xml {} {} {}",
+ controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT);
+ client.send(request, BodyHandlers.ofFile(Paths.get(TANK_AGENT_DIR, "settings.xml")));
+ logger.info("got settings file...");
+ // Download Support Files
+ request = HttpRequest.newBuilder().uri(URI.create(
+ controllerBaseUrl + SERVICE_RELATIVE_PATH + METHOD_SUPPORT))
+ .header("Authorization", "bearer "+token).build();
+ logger.info("Making call to tank service url to get support files {} {} {}",
+ controllerBaseUrl, SERVICE_RELATIVE_PATH, METHOD_SUPPORT);
+ int retryCount = 0;
+ while (true) {
+ try (ZipInputStream zip = new ZipInputStream(
+ client.send(request, BodyHandlers.ofInputStream()).body())) {
+ ZipEntry entry = zip.getNextEntry();
+ Path agentDirPath = Paths.get(TANK_AGENT_DIR).toAbsolutePath().normalize();
+ while (entry != null) {
+ String filename = entry.getName();
+ logger.info("Got file from controller: {}", filename);
+ Path targetPath = agentDirPath.resolve(filename).normalize();
+ if (!targetPath.startsWith(agentDirPath)) // Protect "Zip Slip"
+ throw new ZipException("Bad zip entry");
+ Files.write(targetPath, zip.readAllBytes());
+ entry = zip.getNextEntry();
+ }
+ break;
+ } catch (EOFException | ZipException e) {
+ logger.error("Error unzipping support files : retryCount={} : {}", retryCount, e.getMessage());
+ if (retryCount < FIBONACCI.length) {
+ Thread.sleep( FIBONACCI[retryCount++] * 1000 );
+ } else throw e;
}
- break;
- } catch (EOFException | ZipException e) {
- logger.error("Error unzipping support files : retryCount={} : {}", retryCount, e.getMessage());
- if (retryCount < FIBONACCI.length) {
- Thread.sleep( FIBONACCI[retryCount++] * 1000 );
- } else throw e;
}
+ } else {
+ logger.info("Controller-initiated WS mode enabled - skipping settings/support download from controller.");
}
// now start the harness
String controllerArg = " -http=" + controllerBaseUrl;
diff --git a/agent/apiharness/pom.xml b/agent/apiharness/pom.xml
index 1b16f203f..2e12e2965 100644
--- a/agent/apiharness/pom.xml
+++ b/agent/apiharness/pom.xml
@@ -73,5 +73,11 @@
org.openjdk.nashorn
nashorn-core
+
+
+ org.java-websocket
+ Java-WebSocket
+ 1.5.6
+
diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java
index 7211a8467..f8ead2488 100644
--- a/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java
+++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/APIMonitor.java
@@ -84,7 +84,7 @@ private void updateInstanceStatus() {
sendTps(tpsInfo);
}
- if (!isLocal) {
+ if (!isLocal && !APITestHarness.getInstance().isControllerInitiatedWsModeEnabled()) {
setInstanceStatus(newStatus.getInstanceId(), newStatus);
}
APITestHarness.getInstance().checkAgentThreads();
@@ -162,7 +162,9 @@ public synchronized static void setJobStatus(JobStatus jobStatus) {
stats.getMaxVirtualUsers(),
stats.getCurrentNumberUsers(), status.getStartTime(), endTime);
status.setUserDetails(APITestHarness.getInstance().getUserTracker().getSnapshot());
- setInstanceStatus(status.getInstanceId(), status);
+ if (!APITestHarness.getInstance().isControllerInitiatedWsModeEnabled()) {
+ setInstanceStatus(status.getInstanceId(), status);
+ }
} catch (Exception e) {
LOG.error("Error sending status to controller: {}", e.toString(), e);
}
diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java
index fff55d082..07c597733 100644
--- a/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java
+++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/APITestHarness.java
@@ -105,6 +105,7 @@ public class APITestHarness {
private TPSMonitor tpsMonitor;
private ResultsReporter resultsReporter;
private String tankHttpClientClass;
+ private AgentCommandWebSocketServer controllerInitiatedWsServer;
private Date send = new Date();
private static final int interval = 15; // SECONDS
@@ -226,6 +227,39 @@ private void initializeFromArgs(String[] args) {
}
}
+ public boolean isControllerInitiatedWsModeEnabled() {
+ try {
+ String value = AmazonUtil.getUserDataAsMap().get(TankConstants.KEY_CONTROLLER_INITIATED_WS_ENABLED);
+ if (StringUtils.isNotBlank(value)) {
+ return Boolean.parseBoolean(value);
+ }
+ } catch (Exception ignored) {
+ }
+ return tankConfig.getAgentConfig().isControllerInitiatedWsEnabled();
+ }
+
+ private boolean isControllerInitiatedWsDisableAgentHttpEnabled() {
+ try {
+ String value = AmazonUtil.getUserDataAsMap().get(TankConstants.KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP);
+ if (StringUtils.isNotBlank(value)) {
+ return Boolean.parseBoolean(value);
+ }
+ } catch (Exception ignored) {
+ }
+ return tankConfig.getAgentConfig().isControllerInitiatedWsDisableAgentHttp();
+ }
+
+ private String getControllerInitiatedWsScriptPath() {
+ try {
+ String value = AmazonUtil.getUserDataAsMap().get(TankConstants.KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH);
+ if (StringUtils.isNotBlank(value)) {
+ return value;
+ }
+ } catch (Exception ignored) {
+ }
+ return tankConfig.getAgentConfig().getControllerInitiatedWsScriptPath();
+ }
+
private String getLocalInstanceId() {
isLocal = true;
try {
@@ -255,7 +289,12 @@ private static void usage() {
private void startHttp(String baseUrl, String token) {
isLocal = false;
HostInfo hostInfo = new HostInfo();
- CommandListener.startHttpServer(tankConfig.getAgentConfig().getAgentPort());
+ boolean controllerInitiatedWsMode = isControllerInitiatedWsModeEnabled();
+ boolean controllerInitiatedWsDisableAgentHttp = isControllerInitiatedWsDisableAgentHttpEnabled();
+
+ if (!controllerInitiatedWsMode || !controllerInitiatedWsDisableAgentHttp) {
+ CommandListener.startHttpServer(tankConfig.getAgentConfig().getAgentPort());
+ }
baseUrl = (baseUrl == null) ? AmazonUtil.getControllerBaseUrl() : baseUrl;
token = (token == null) ? AmazonUtil.getAgentToken() : token;
String instanceUrl = null;
@@ -291,6 +330,42 @@ private void startHttp(String baseUrl, String token) {
agentRunData.setStopBehavior(AmazonUtil.getStopBehavior());
LogUtil.getLogEvent().setJobId(agentRunData.getJobId());
+ if (controllerInitiatedWsMode && controllerInitiatedWsDisableAgentHttp) {
+ try {
+ if (controllerInitiatedWsServer == null) {
+ controllerInitiatedWsServer = new AgentCommandWebSocketServer(
+ tankConfig.getAgentConfig().getAgentPort(),
+ instanceId,
+ agentRunData.getJobId(),
+ capacity);
+ controllerInitiatedWsServer.start();
+ }
+
+ String scriptPath = getControllerInitiatedWsScriptPath();
+ if (StringUtils.isNotBlank(scriptPath) && new File(scriptPath).exists()) {
+ LOG.info(new ObjectMessage(Map.of("Message", "Controller-initiated WS mode loading script from " + scriptPath)));
+ TestPlanSingleton.getInstance().setTestPlans(scriptPath);
+ } else if (StringUtils.isNotBlank(testPlans) && AgentUtil.validateTestPlans(testPlans)) {
+ LOG.info(new ObjectMessage(Map.of("Message", "Controller-initiated WS mode loading script from args " + testPlans)));
+ TestPlanSingleton.getInstance().setTestPlans(testPlans);
+ } else {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS mode has no valid local script path configured. Awaiting controller command anyway.")));
+ }
+
+ Thread thread = new Thread(new StartedChecker());
+ thread.setName("StartedChecker");
+ thread.setDaemon(false);
+ thread.start();
+ return;
+ } catch (Exception e) {
+ LOG.error("Error starting controller-initiated WS mode: " + e, e);
+ System.exit(0);
+ }
+ } else if (controllerInitiatedWsMode) {
+ LOG.warn(new ObjectMessage(Map.of("Message",
+ "Controller-initiated WS is enabled but HTTP disable flag is false; running legacy HTTP lifecycle path")));
+ }
+
AgentData data = new AgentData(agentRunData.getJobId(), instanceId, instanceUrl, capacity,
AmazonUtil.getVMRegion(), AmazonUtil.getZone());
try {
@@ -340,6 +415,15 @@ private void startHttp(String baseUrl, String token) {
saveDataFile(dfRequest, token);
}
}
+ // Start WS control channel if enabled
+ if (tankConfig.getAgentConfig().isCommandWsEnabled()) {
+ String wsPath = tankConfig.getAgentConfig().getCommandWsPath();
+ LOG.info(new ObjectMessage(Map.of("Message", "Starting WS control channel to " + baseUrl + wsPath)));
+ AgentCommandWebSocketClient wsClient = new AgentCommandWebSocketClient(
+ baseUrl, wsPath, token, instanceId, agentRunData.getJobId());
+ wsClient.connect();
+ }
+
Thread thread = new Thread(new StartedChecker());
thread.setName("StartedChecker");
thread.setDaemon(false);
diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketClient.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketClient.java
new file mode 100644
index 000000000..64e15f25d
--- /dev/null
+++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketClient.java
@@ -0,0 +1,255 @@
+package com.intuit.tank.harness;
+
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ObjectMessage;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.WebSocket;
+import java.nio.ByteBuffer;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+public class AgentCommandWebSocketClient implements WebSocket.Listener {
+
+ private static final Logger LOG = LogManager.getLogger(AgentCommandWebSocketClient.class);
+
+ private static final int[] BACKOFF_SECONDS = {1, 2, 5, 10, 20, 30};
+
+ private final String controllerBaseUrl;
+ private final String wsPath;
+ private final String token;
+ private final String instanceId;
+ private final String jobId;
+ private final String agentSessionId;
+ private final HttpClient httpClient;
+
+ private static final int MAX_APPLIED_COMMAND_IDS = 10_000;
+
+ private final AtomicReference webSocketRef = new AtomicReference<>();
+ private final AtomicBoolean running = new AtomicBoolean(true);
+ private final AtomicBoolean reconnecting = new AtomicBoolean(false);
+ private final Set appliedCommandIds = ConcurrentHashMap.newKeySet();
+ private volatile String lastAppliedCommandId = null;
+
+ // Buffer for accumulating partial text frames
+ private final StringBuilder messageBuffer = new StringBuilder();
+
+ public AgentCommandWebSocketClient(String controllerBaseUrl, String wsPath, String token,
+ String instanceId, String jobId) {
+ this.controllerBaseUrl = controllerBaseUrl;
+ this.wsPath = wsPath;
+ this.token = token;
+ this.instanceId = instanceId;
+ this.jobId = jobId;
+ this.agentSessionId = UUID.randomUUID().toString();
+ this.httpClient = HttpClient.newHttpClient();
+ }
+
+ public void connect() {
+ Thread connectThread = new Thread(this::connectWithRetry, "WS-Connect-" + instanceId);
+ connectThread.setDaemon(true);
+ connectThread.start();
+ }
+
+ private void connectWithRetry() {
+ int attempt = 0;
+ while (running.get()) {
+ try {
+ String wsUrl = buildWsUrl();
+ LOG.info(new ObjectMessage(Map.of("Message", "Connecting WS to " + wsUrl)));
+
+ WebSocket ws = httpClient.newWebSocketBuilder()
+ .header("Authorization", "bearer " + token)
+ .buildAsync(URI.create(wsUrl), this)
+ .join();
+
+ webSocketRef.set(ws);
+ attempt = 0; // reset on successful connect
+
+ // Send hello
+ AgentWsEnvelope hello = AgentWsEnvelope.hello(instanceId, jobId, agentSessionId, lastAppliedCommandId);
+ ws.sendText(hello.toJson(), true);
+ LOG.info(new ObjectMessage(Map.of("Message", "WS hello sent for agent " + instanceId + " job " + jobId)));
+
+ // Block until connection closes (handled by listener callbacks)
+ // The listener methods will handle incoming frames
+ return;
+
+ } catch (Exception e) {
+ int backoff = BACKOFF_SECONDS[Math.min(attempt, BACKOFF_SECONDS.length - 1)];
+ // Add jitter: 0-500ms
+ int jitter = (int) (Math.random() * 500);
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS connect failed (attempt " + attempt + "): " + e.getMessage()
+ + ", retrying in " + backoff + "s")));
+ try {
+ Thread.sleep(backoff * 1000L + jitter);
+ } catch (InterruptedException ignored) {
+ Thread.currentThread().interrupt();
+ return;
+ }
+ attempt++;
+ }
+ }
+ }
+
+ private String buildWsUrl() {
+ String base = controllerBaseUrl;
+ // Convert http(s) to ws(s)
+ if (base.startsWith("https://")) {
+ base = "wss://" + base.substring("https://".length());
+ } else if (base.startsWith("http://")) {
+ base = "ws://" + base.substring("http://".length());
+ }
+ // Remove trailing slash
+ if (base.endsWith("/")) {
+ base = base.substring(0, base.length() - 1);
+ }
+ return base + wsPath;
+ }
+
+ @Override
+ public void onOpen(WebSocket webSocket) {
+ LOG.info(new ObjectMessage(Map.of("Message", "WS connection opened for agent " + instanceId)));
+ webSocket.request(1);
+ }
+
+ @Override
+ public CompletionStage> onText(WebSocket webSocket, CharSequence data, boolean last) {
+ messageBuffer.append(data);
+ if (last) {
+ String fullMessage = messageBuffer.toString();
+ messageBuffer.setLength(0);
+ handleMessage(webSocket, fullMessage);
+ }
+ webSocket.request(1);
+ return null;
+ }
+
+ @Override
+ public CompletionStage> onPing(WebSocket webSocket, ByteBuffer message) {
+ webSocket.request(1);
+ return null;
+ }
+
+ @Override
+ public CompletionStage> onClose(WebSocket webSocket, int statusCode, String reason) {
+ LOG.info(new ObjectMessage(Map.of("Message", "WS closed: code=" + statusCode + " reason=" + reason)));
+ webSocketRef.set(null);
+ scheduleReconnect();
+ return null;
+ }
+
+ @Override
+ public void onError(WebSocket webSocket, Throwable error) {
+ LOG.error(new ObjectMessage(Map.of("Message", "WS error for agent " + instanceId + ": " + error.getMessage())), error);
+ webSocketRef.set(null);
+ scheduleReconnect();
+ }
+
+ private void scheduleReconnect() {
+ if (running.get() && reconnecting.compareAndSet(false, true)) {
+ Thread reconnect = new Thread(() -> {
+ try {
+ connectWithRetry();
+ } finally {
+ reconnecting.set(false);
+ }
+ }, "WS-Reconnect-" + instanceId);
+ reconnect.setDaemon(true);
+ reconnect.start();
+ }
+ }
+
+ private void handleMessage(WebSocket webSocket, String text) {
+ try {
+ AgentWsEnvelope envelope = AgentWsEnvelope.fromJson(text);
+ if (envelope.getType() == null) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Received WS frame with no type")));
+ return;
+ }
+
+ switch (envelope.getType()) {
+ case command -> handleCommand(webSocket, envelope);
+ case ping -> handlePing(webSocket, envelope);
+ case ack -> LOG.info(new ObjectMessage(Map.of("Message", "Received ack: " + envelope.getAckForType())));
+ default -> LOG.warn(new ObjectMessage(Map.of("Message", "Unexpected WS frame type: " + envelope.getType())));
+ }
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Failed to parse WS frame: " + e.getMessage())));
+ }
+ }
+
+ private void handleCommand(WebSocket webSocket, AgentWsEnvelope envelope) {
+ String commandId = envelope.getCommandId();
+ String command = envelope.getCommand();
+
+ // Bound the dedup set to prevent unbounded growth
+ if (appliedCommandIds.size() > MAX_APPLIED_COMMAND_IDS) {
+ appliedCommandIds.clear();
+ LOG.info(new ObjectMessage(Map.of("Message", "Cleared appliedCommandIds set (exceeded " + MAX_APPLIED_COMMAND_IDS + ")")));
+ }
+
+ // Deduplicate
+ if (commandId != null && !appliedCommandIds.add(commandId)) {
+ LOG.info(new ObjectMessage(Map.of("Message", "Duplicate WS command " + commandId + " ignored")));
+ sendAck(webSocket, commandId, AckStatus.duplicate);
+ return;
+ }
+
+ LOG.info(new ObjectMessage(Map.of("Message", "Received WS command: " + command + " (id=" + commandId + ")")));
+
+ try {
+ applyCommand(command);
+ lastAppliedCommandId = commandId;
+ sendAck(webSocket, commandId, AckStatus.ok);
+ } catch (Exception e) {
+ LOG.error(new ObjectMessage(Map.of("Message", "Error applying WS command " + command + ": " + e.getMessage())), e);
+ sendAck(webSocket, commandId, AckStatus.failed);
+ }
+ }
+
+ private void applyCommand(String command) {
+ CommandListener.applyCommand(command);
+ }
+
+ private void handlePing(WebSocket webSocket, AgentWsEnvelope envelope) {
+ try {
+ AgentWsEnvelope pong = AgentWsEnvelope.pong(instanceId, agentSessionId, envelope.getPingId(), lastAppliedCommandId);
+ webSocket.sendText(pong.toJson(), true);
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send pong: " + e.getMessage())));
+ }
+ }
+
+ private void sendAck(WebSocket webSocket, String commandId, AckStatus status) {
+ try {
+ AgentWsEnvelope ack = AgentWsEnvelope.ack(instanceId, "command", commandId, status);
+ webSocket.sendText(ack.toJson(), true);
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send ack for " + commandId + ": " + e.getMessage())));
+ }
+ }
+
+ public void close() {
+ running.set(false);
+ WebSocket ws = webSocketRef.get();
+ if (ws != null) {
+ try {
+ AgentWsEnvelope closeEnvelope = AgentWsEnvelope.close(instanceId, "agent_shutdown", "Agent shutting down");
+ ws.sendText(closeEnvelope.toJson(), true);
+ } catch (IOException ignored) {}
+ ws.sendClose(WebSocket.NORMAL_CLOSURE, "Agent shutting down");
+ }
+ }
+}
diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java
new file mode 100644
index 000000000..752a8a246
--- /dev/null
+++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java
@@ -0,0 +1,134 @@
+package com.intuit.tank.harness;
+
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ObjectMessage;
+import org.java_websocket.WebSocket;
+import org.java_websocket.handshake.ClientHandshake;
+import org.java_websocket.server.WebSocketServer;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class AgentCommandWebSocketServer extends WebSocketServer {
+
+ private static final Logger LOG = LogManager.getLogger(AgentCommandWebSocketServer.class);
+ private static final int MAX_APPLIED_COMMAND_IDS = 10_000;
+
+ private final String instanceId;
+ private final String jobId;
+ private final int capacity;
+ private final String agentSessionId;
+
+ private final Set appliedCommandIds = ConcurrentHashMap.newKeySet();
+ private volatile String lastAppliedCommandId;
+
+ public AgentCommandWebSocketServer(int port, String instanceId, String jobId, int capacity) {
+ super(new InetSocketAddress(port));
+ this.instanceId = instanceId;
+ this.jobId = jobId;
+ this.capacity = capacity;
+ this.agentSessionId = UUID.randomUUID().toString();
+ }
+
+ @Override
+ public void onOpen(WebSocket connection, ClientHandshake handshake) {
+ LOG.info(new ObjectMessage(Map.of("Message",
+ "Controller WS connected to agent " + instanceId + " from " + connection.getRemoteSocketAddress())));
+ try {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello(instanceId, jobId, agentSessionId, lastAppliedCommandId, capacity);
+ connection.send(hello.toJson());
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send WS hello: " + e.getMessage())));
+ connection.close();
+ }
+ }
+
+ @Override
+ public void onMessage(WebSocket connection, String message) {
+ try {
+ AgentWsEnvelope envelope = AgentWsEnvelope.fromJson(message);
+ if (envelope.getType() == null) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS frame missing type")));
+ return;
+ }
+
+ switch (envelope.getType()) {
+ case command -> handleCommand(connection, envelope);
+ case ping -> handlePing(connection, envelope);
+ case close -> {
+ LOG.info(new ObjectMessage(Map.of("Message", "Controller closed WS: " + envelope.getReason())));
+ connection.close();
+ }
+ default -> LOG.warn(new ObjectMessage(Map.of("Message", "Unexpected WS frame type: " + envelope.getType())));
+ }
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Failed parsing WS message: " + e.getMessage())));
+ }
+ }
+
+ private void handleCommand(WebSocket connection, AgentWsEnvelope envelope) {
+ String commandId = envelope.getCommandId();
+ String command = envelope.getCommand();
+
+ if (appliedCommandIds.size() > MAX_APPLIED_COMMAND_IDS) {
+ appliedCommandIds.clear();
+ }
+
+ if (commandId != null && !appliedCommandIds.add(commandId)) {
+ sendAck(connection, commandId, AckStatus.duplicate, null);
+ return;
+ }
+
+ try {
+ CommandListener.applyCommand(command);
+ lastAppliedCommandId = commandId;
+ sendAck(connection, commandId, AckStatus.ok, null);
+ } catch (Exception e) {
+ LOG.warn(new ObjectMessage(Map.of("Message",
+ "Failed to apply command " + command + " for agent " + instanceId + ": " + e.getMessage())));
+ sendAck(connection, commandId, AckStatus.failed, e.getMessage());
+ }
+ }
+
+ private void handlePing(WebSocket connection, AgentWsEnvelope envelope) {
+ try {
+ AgentWsEnvelope pong = AgentWsEnvelope.pong(instanceId, agentSessionId, envelope.getPingId(), lastAppliedCommandId);
+ connection.send(pong.toJson());
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send pong: " + e.getMessage())));
+ }
+ }
+
+ private void sendAck(WebSocket connection, String commandId, AckStatus status, String error) {
+ try {
+ AgentWsEnvelope ack = AgentWsEnvelope.ack(instanceId, "command", commandId, status);
+ ack.setError(error);
+ connection.send(ack.toJson());
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Failed to send command ack: " + e.getMessage())));
+ }
+ }
+
+ @Override
+ public void onClose(WebSocket connection, int code, String reason, boolean remote) {
+ LOG.info(new ObjectMessage(Map.of("Message",
+ "Controller WS disconnected from agent " + instanceId + " code=" + code + " reason=" + reason)));
+ }
+
+ @Override
+ public void onError(WebSocket connection, Exception ex) {
+ LOG.error(new ObjectMessage(Map.of("Message", "Agent WS server transport error: " + ex.getMessage())), ex);
+ }
+
+ @Override
+ public void onStart() {
+ LOG.info(new ObjectMessage(Map.of("Message", "Agent WS control server started for " + instanceId)));
+ }
+}
diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java
index cff991a2b..6de334a51 100644
--- a/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java
+++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/CommandListener.java
@@ -76,29 +76,10 @@ private static void handleRequest(HttpExchange exchange) {
try {
String response = "Not Found";
String path = exchange.getRequestURI().getPath();
- if (path.equals(AgentCommand.start.getPath()) || path.equals(AgentCommand.run.getPath())) {
- response = "Received command " + path + ", Starting Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
- LOG.info(LogUtil.getLogMessage("Received START command - launching test threads for job " +
- APITestHarness.getInstance().getAgentRunData().getJobId()));
- startTest();
- } else if (path.startsWith(AgentCommand.stop.getPath())) {
- response = "Received command " + path + ", Stopping Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
- APITestHarness.getInstance().setCommand(AgentCommand.stop);
- } else if (path.startsWith(AgentCommand.kill.getPath())) {
- response = "Received command " + path + ", Killing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
- System.exit(0);
- } else if (path.startsWith(AgentCommand.pause.getPath())) {
- response = "Received command " + path + ", Pausing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
- APITestHarness.getInstance().setCommand(AgentCommand.pause);
- } else if (path.startsWith(AgentCommand.pause_ramp.getPath())) {
- response = "Received command " + path + ", Pausing Ramp for Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
- APITestHarness.getInstance().setCommand(AgentCommand.pause_ramp);
- } else if (path.startsWith(AgentCommand.resume_ramp.getPath())) {
- response = "Received command " + path + ", Resume Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
- APITestHarness.getInstance().setCommand(AgentCommand.resume_ramp);
- } else if (path.startsWith(AgentCommand.status.getPath())) {
- response = APITestHarness.getInstance().getStatus().toString();
- APITestHarness.getInstance().setCommand(AgentCommand.resume_ramp);
+ try {
+ response = applyHttpPathCommand(path);
+ } catch (UnsupportedOperationException ignored) {
+ response = "Not Found";
}
LOG.info(LogUtil.getLogMessage(response));
@@ -118,6 +99,68 @@ private static void handleRequest(HttpExchange exchange) {
}
}
+ private static String applyHttpPathCommand(String path) {
+ if (path.equals(AgentCommand.start.getPath()) || path.equals(AgentCommand.run.getPath())) {
+ LOG.info(LogUtil.getLogMessage("Received START command - launching test threads for job " +
+ APITestHarness.getInstance().getAgentRunData().getJobId()));
+ executeCommand(AgentCommand.start);
+ return "Received command " + path + ", Starting Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
+ }
+ if (path.startsWith(AgentCommand.stop.getPath())) {
+ executeCommand(AgentCommand.stop);
+ return "Received command " + path + ", Stopping Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
+ }
+ if (path.startsWith(AgentCommand.kill.getPath())) {
+ executeCommand(AgentCommand.kill);
+ return "Received command " + path + ", Killing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
+ }
+ if (path.startsWith(AgentCommand.pause.getPath())) {
+ executeCommand(AgentCommand.pause);
+ return "Received command " + path + ", Pausing Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
+ }
+ if (path.startsWith(AgentCommand.pause_ramp.getPath())) {
+ executeCommand(AgentCommand.pause_ramp);
+ return "Received command " + path + ", Pausing Ramp for Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
+ }
+ if (path.startsWith(AgentCommand.resume_ramp.getPath())) {
+ executeCommand(AgentCommand.resume_ramp);
+ return "Received command " + path + ", Resume Test JobId=" + APITestHarness.getInstance().getAgentRunData().getJobId();
+ }
+ if (path.startsWith(AgentCommand.status.getPath())) {
+ executeCommand(AgentCommand.status);
+ return APITestHarness.getInstance().getStatus().toString();
+ }
+ throw new UnsupportedOperationException("Unknown command path: " + path);
+ }
+
+ public static void applyCommand(String command) {
+ if (command == null) {
+ throw new UnsupportedOperationException("Unknown command: null");
+ }
+ switch (command) {
+ case "start", "run" -> executeCommand(AgentCommand.start);
+ case "stop" -> executeCommand(AgentCommand.stop);
+ case "kill" -> executeCommand(AgentCommand.kill);
+ case "pause" -> executeCommand(AgentCommand.pause);
+ case "pause_ramp" -> executeCommand(AgentCommand.pause_ramp);
+ case "resume_ramp", "resume" -> executeCommand(AgentCommand.resume_ramp);
+ case "status" -> executeCommand(AgentCommand.status);
+ default -> throw new UnsupportedOperationException("Unknown command: " + command);
+ }
+ }
+
+ private static void executeCommand(AgentCommand command) {
+ switch (command) {
+ case start, run -> startTest();
+ case stop -> APITestHarness.getInstance().setCommand(AgentCommand.stop);
+ case kill -> System.exit(0);
+ case pause -> APITestHarness.getInstance().setCommand(AgentCommand.pause);
+ case pause_ramp -> APITestHarness.getInstance().setCommand(AgentCommand.pause_ramp);
+ case resume_ramp, status -> APITestHarness.getInstance().setCommand(AgentCommand.resume_ramp);
+ default -> throw new UnsupportedOperationException("Unknown command: " + command);
+ }
+ }
+
public static void startTest() {
Thread thread = new Thread( () -> APITestHarness.getInstance().runConcurrentTestPlans());
thread.setDaemon(true);
diff --git a/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java b/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java
index 96f9e513f..cc622eb9d 100644
--- a/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java
+++ b/agent/apiharness/src/test/java/com/intuit/tank/harness/APITestHarnessTest.java
@@ -19,6 +19,7 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.Map;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.*;
@@ -84,6 +85,15 @@ public void testInitializeFromMainArgs() throws IOException {
}
}
+ @Test
+ public void testControllerInitiatedWsModeEnabledFromUserData() {
+ try (MockedStatic amazonUtil = Mockito.mockStatic(AmazonUtil.class)) {
+ amazonUtil.when(AmazonUtil::getUserDataAsMap).thenReturn(
+ Map.of("controllerInitiatedWsEnabled", "true"));
+ assertTrue(APITestHarness.getInstance().isControllerInitiatedWsModeEnabled());
+ }
+ }
+
@Test
public void testUsageMain() {
PrintStream originalOut = System.out;
diff --git a/agent/apiharness/src/test/java/com/intuit/tank/harness/AgentCommandWebSocketClientTest.java b/agent/apiharness/src/test/java/com/intuit/tank/harness/AgentCommandWebSocketClientTest.java
new file mode 100644
index 000000000..f0681421c
--- /dev/null
+++ b/agent/apiharness/src/test/java/com/intuit/tank/harness/AgentCommandWebSocketClientTest.java
@@ -0,0 +1,45 @@
+package com.intuit.tank.harness;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+public class AgentCommandWebSocketClientTest {
+
+ @Test
+ void testConstructorSetsFields() {
+ AgentCommandWebSocketClient client = new AgentCommandWebSocketClient(
+ "http://controller.example.com/tank",
+ "/v2/agent/ws/control",
+ "test-token",
+ "i-123",
+ "job-1"
+ );
+ assertNotNull(client);
+ }
+
+ @Test
+ void testConstructorWithHttpsUrl() {
+ AgentCommandWebSocketClient client = new AgentCommandWebSocketClient(
+ "https://controller.example.com/tank",
+ "/v2/agent/ws/control",
+ "test-token",
+ "i-456",
+ "job-2"
+ );
+ assertNotNull(client);
+ }
+
+ @Test
+ void testCloseBeforeConnect() {
+ AgentCommandWebSocketClient client = new AgentCommandWebSocketClient(
+ "http://localhost:8080/tank",
+ "/v2/agent/ws/control",
+ "token",
+ "i-789",
+ "job-3"
+ );
+ // Should not throw
+ client.close();
+ }
+}
diff --git a/agent/apiharness/src/test/java/com/intuit/tank/harness/CommandListenerTest.java b/agent/apiharness/src/test/java/com/intuit/tank/harness/CommandListenerTest.java
new file mode 100644
index 000000000..aedc6bf4d
--- /dev/null
+++ b/agent/apiharness/src/test/java/com/intuit/tank/harness/CommandListenerTest.java
@@ -0,0 +1,13 @@
+package com.intuit.tank.harness;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class CommandListenerTest {
+
+ @Test
+ public void testApplyCommandUnknownThrows() {
+ assertThrows(UnsupportedOperationException.class, () -> CommandListener.applyCommand("not-a-real-command"));
+ }
+}
diff --git a/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsCommandSender.java b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsCommandSender.java
new file mode 100644
index 000000000..18330d8d1
--- /dev/null
+++ b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsCommandSender.java
@@ -0,0 +1,19 @@
+package com.intuit.tank.vm.agent.messages;
+
+/**
+ * Interface for sending commands to agents via WebSocket.
+ * Implemented by the controller's WS handler, consumed by JobManager via CDI.
+ */
+public interface AgentWsCommandSender {
+
+ /**
+ * Check if an agent has an active WS session.
+ */
+ boolean hasSession(String instanceId);
+
+ /**
+ * Send a command to an agent via WS and wait for ack.
+ * @return true if command was acked successfully within timeout
+ */
+ boolean sendCommand(String instanceId, String jobId, String command, long ackTimeoutMillis);
+}
diff --git a/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java
new file mode 100644
index 000000000..4844682dd
--- /dev/null
+++ b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java
@@ -0,0 +1,212 @@
+package com.intuit.tank.vm.agent.messages;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.IOException;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class AgentWsEnvelope {
+
+ private static final ObjectMapper MAPPER = new ObjectMapper();
+
+ public static final int PROTOCOL_VERSION = 1;
+
+ public enum Type {
+ hello, command, ack, ping, pong, close
+ }
+
+ public enum AckStatus {
+ ok, duplicate, failed, unsupported
+ }
+
+ @JsonProperty("type")
+ private Type type;
+
+ @JsonProperty("instanceId")
+ private String instanceId;
+
+ @JsonProperty("sentAtMs")
+ private long sentAtMs;
+
+ @JsonProperty("protocolVersion")
+ private int protocolVersion = PROTOCOL_VERSION;
+
+ // hello fields
+ @JsonProperty("jobId")
+ private String jobId;
+
+ @JsonProperty("agentSessionId")
+ private String agentSessionId;
+
+ @JsonProperty("capacity")
+ private Integer capacity;
+
+ // command fields
+ @JsonProperty("commandId")
+ private String commandId;
+
+ @JsonProperty("command")
+ private String command;
+
+ // ack fields
+ @JsonProperty("ackForType")
+ private String ackForType;
+
+ @JsonProperty("ackForId")
+ private String ackForId;
+
+ @JsonProperty("status")
+ private AckStatus status;
+
+ @JsonProperty("error")
+ private String error;
+
+ // ping/pong fields
+ @JsonProperty("pingId")
+ private String pingId;
+
+ @JsonProperty("lastAppliedCommandId")
+ private String lastAppliedCommandId;
+
+ // close fields
+ @JsonProperty("reasonCode")
+ private String reasonCode;
+
+ @JsonProperty("reason")
+ private String reason;
+
+ public AgentWsEnvelope() {}
+
+ public Type getType() { return type; }
+ public void setType(Type type) { this.type = type; }
+
+ public String getInstanceId() { return instanceId; }
+ public void setInstanceId(String instanceId) { this.instanceId = instanceId; }
+
+ public long getSentAtMs() { return sentAtMs; }
+ public void setSentAtMs(long sentAtMs) { this.sentAtMs = sentAtMs; }
+
+ public int getProtocolVersion() { return protocolVersion; }
+ public void setProtocolVersion(int protocolVersion) { this.protocolVersion = protocolVersion; }
+
+ public String getJobId() { return jobId; }
+ public void setJobId(String jobId) { this.jobId = jobId; }
+
+ public String getAgentSessionId() { return agentSessionId; }
+ public void setAgentSessionId(String agentSessionId) { this.agentSessionId = agentSessionId; }
+
+ public Integer getCapacity() { return capacity; }
+ public void setCapacity(Integer capacity) { this.capacity = capacity; }
+
+ public String getCommandId() { return commandId; }
+ public void setCommandId(String commandId) { this.commandId = commandId; }
+
+ public String getCommand() { return command; }
+ public void setCommand(String command) { this.command = command; }
+
+ public String getAckForType() { return ackForType; }
+ public void setAckForType(String ackForType) { this.ackForType = ackForType; }
+
+ public String getAckForId() { return ackForId; }
+ public void setAckForId(String ackForId) { this.ackForId = ackForId; }
+
+ public AckStatus getStatus() { return status; }
+ public void setStatus(AckStatus status) { this.status = status; }
+
+ public String getError() { return error; }
+ public void setError(String error) { this.error = error; }
+
+ public String getPingId() { return pingId; }
+ public void setPingId(String pingId) { this.pingId = pingId; }
+
+ public String getLastAppliedCommandId() { return lastAppliedCommandId; }
+ public void setLastAppliedCommandId(String lastAppliedCommandId) { this.lastAppliedCommandId = lastAppliedCommandId; }
+
+ public String getReasonCode() { return reasonCode; }
+ public void setReasonCode(String reasonCode) { this.reasonCode = reasonCode; }
+
+ public String getReason() { return reason; }
+ public void setReason(String reason) { this.reason = reason; }
+
+ public String toJson() throws IOException {
+ return MAPPER.writeValueAsString(this);
+ }
+
+ public static AgentWsEnvelope fromJson(String json) throws IOException {
+ return MAPPER.readValue(json, AgentWsEnvelope.class);
+ }
+
+ // Factory methods for common frame types
+
+ public static AgentWsEnvelope hello(String instanceId, String jobId, String agentSessionId, String lastAppliedCommandId) {
+ return hello(instanceId, jobId, agentSessionId, lastAppliedCommandId, null);
+ }
+
+ public static AgentWsEnvelope hello(String instanceId, String jobId, String agentSessionId,
+ String lastAppliedCommandId, Integer capacity) {
+ AgentWsEnvelope env = new AgentWsEnvelope();
+ env.setType(Type.hello);
+ env.setInstanceId(instanceId);
+ env.setJobId(jobId);
+ env.setAgentSessionId(agentSessionId);
+ env.setLastAppliedCommandId(lastAppliedCommandId);
+ env.setCapacity(capacity);
+ env.setSentAtMs(System.currentTimeMillis());
+ return env;
+ }
+
+ public static AgentWsEnvelope command(String commandId, String instanceId, String jobId, String command) {
+ AgentWsEnvelope env = new AgentWsEnvelope();
+ env.setType(Type.command);
+ env.setCommandId(commandId);
+ env.setInstanceId(instanceId);
+ env.setJobId(jobId);
+ env.setCommand(command);
+ env.setSentAtMs(System.currentTimeMillis());
+ return env;
+ }
+
+ public static AgentWsEnvelope ack(String instanceId, String ackForType, String ackForId, AckStatus status) {
+ AgentWsEnvelope env = new AgentWsEnvelope();
+ env.setType(Type.ack);
+ env.setInstanceId(instanceId);
+ env.setAckForType(ackForType);
+ env.setAckForId(ackForId);
+ env.setStatus(status);
+ env.setSentAtMs(System.currentTimeMillis());
+ return env;
+ }
+
+ public static AgentWsEnvelope ping(String pingId) {
+ AgentWsEnvelope env = new AgentWsEnvelope();
+ env.setType(Type.ping);
+ env.setPingId(pingId);
+ env.setSentAtMs(System.currentTimeMillis());
+ return env;
+ }
+
+ public static AgentWsEnvelope pong(String instanceId, String agentSessionId, String pingId, String lastAppliedCommandId) {
+ AgentWsEnvelope env = new AgentWsEnvelope();
+ env.setType(Type.pong);
+ env.setInstanceId(instanceId);
+ env.setAgentSessionId(agentSessionId);
+ env.setPingId(pingId);
+ env.setLastAppliedCommandId(lastAppliedCommandId);
+ env.setSentAtMs(System.currentTimeMillis());
+ return env;
+ }
+
+ public static AgentWsEnvelope close(String instanceId, String reasonCode, String reason) {
+ AgentWsEnvelope env = new AgentWsEnvelope();
+ env.setType(Type.close);
+ env.setInstanceId(instanceId);
+ env.setReasonCode(reasonCode);
+ env.setReason(reason);
+ env.setSentAtMs(System.currentTimeMillis());
+ return env;
+ }
+}
diff --git a/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java b/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java
index d730d284e..124749dda 100644
--- a/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java
+++ b/api/src/main/java/com/intuit/tank/vm/common/TankConstants.java
@@ -58,6 +58,9 @@ public class TankConstants {
public static final String KEY_JVM_ARGS = "jvmArgs";
public static final String KEY_AWS_SECRET_KEY_ID = "AWS_SECRET_KEY_ID";
public static final String KEY_AWS_SECRET_KEY = "AWS_SECRET_KEY";
+ public static final String KEY_CONTROLLER_INITIATED_WS_ENABLED = "controllerInitiatedWsEnabled";
+ public static final String KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP = "controllerInitiatedWsDisableAgentHttp";
+ public static final String KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH = "controllerInitiatedWsScriptPath";
public static final String HTTP_CASE_SKIP = "SKIP";
public static final String HTTP_CASE_SKIPGROUP = "SKIPGROUP";
diff --git a/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java b/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java
index 54efd4b68..db7072efa 100644
--- a/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java
+++ b/api/src/main/java/com/intuit/tank/vm/settings/AgentConfig.java
@@ -64,6 +64,14 @@ public class AgentConfig implements Serializable {
private static final String KEY_LOG_VARIABLES = "log-variables";
private static final String KEY_CONNECTION_TIMEOUT = "connection-timeout";
+ private static final String KEY_COMMAND_WS_ENABLED = "command-ws-enabled";
+ private static final String KEY_COMMAND_WS_HTTP_FALLBACK_ENABLED = "command-ws-http-fallback-enabled";
+ private static final String KEY_COMMAND_WS_ACK_TIMEOUT_MILLIS = "command-ws-ack-timeout-millis";
+ private static final String KEY_COMMAND_WS_PATH = "command-ws-path";
+ private static final String KEY_CONTROLLER_INITIATED_WS_ENABLED = "controller-initiated-ws-enabled";
+ private static final String KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP = "controller-initiated-ws-disable-agent-http";
+ private static final String KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH = "controller-initiated-ws-script-path";
+
private static final String KEY_REQUEST_HEADERS = "request-headers/header";
// private static final String KEY_RESULT_PROVIDERS =
// "result-providers/provider";
@@ -336,4 +344,32 @@ public long getStatusReportIntervalMilis(long pollTime) {
return config.getLong(KEY_POLL_TIME_MILIS, pollTime);
}
+ public boolean isCommandWsEnabled() {
+ return config.getBoolean(KEY_COMMAND_WS_ENABLED, false);
+ }
+
+ public boolean isCommandWsHttpFallbackEnabled() {
+ return config.getBoolean(KEY_COMMAND_WS_HTTP_FALLBACK_ENABLED, true);
+ }
+
+ public long getCommandWsAckTimeoutMillis() {
+ return config.getLong(KEY_COMMAND_WS_ACK_TIMEOUT_MILLIS, 3000L);
+ }
+
+ public String getCommandWsPath() {
+ return config.getString(KEY_COMMAND_WS_PATH, "/v2/agent/ws/control");
+ }
+
+ public boolean isControllerInitiatedWsEnabled() {
+ return config.getBoolean(KEY_CONTROLLER_INITIATED_WS_ENABLED, false);
+ }
+
+ public boolean isControllerInitiatedWsDisableAgentHttp() {
+ return config.getBoolean(KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP, true);
+ }
+
+ public String getControllerInitiatedWsScriptPath() {
+ return config.getString(KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH, "script.xml");
+ }
+
}
diff --git a/api/src/main/resources/settings.xml b/api/src/main/resources/settings.xml
index b6f1105bc..63a9a7ee8 100644
--- a/api/src/main/resources/settings.xml
+++ b/api/src/main/resources/settings.xml
@@ -137,6 +137,15 @@
JDK HttpClient
+
+ false
+ true
+ 3000
+ /v2/agent/ws/control
+ false
+ true
+ script.xml
+
diff --git a/api/src/test/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelopeTest.java b/api/src/test/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelopeTest.java
new file mode 100644
index 000000000..2f7aebaa4
--- /dev/null
+++ b/api/src/test/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelopeTest.java
@@ -0,0 +1,184 @@
+package com.intuit.tank.vm.agent.messages;
+
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.Type;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+public class AgentWsEnvelopeTest {
+
+ @Test
+ public void testHelloFactory() throws IOException {
+ AgentWsEnvelope env = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", "cmd-99");
+
+ assertEquals(Type.hello, env.getType());
+ assertEquals("i-123", env.getInstanceId());
+ assertEquals("job-1", env.getJobId());
+ assertEquals("sess-1", env.getAgentSessionId());
+ assertEquals("cmd-99", env.getLastAppliedCommandId());
+ assertEquals(AgentWsEnvelope.PROTOCOL_VERSION, env.getProtocolVersion());
+ assertTrue(env.getSentAtMs() > 0);
+ }
+
+ @Test
+ public void testHelloFactoryWithCapacity() throws IOException {
+ AgentWsEnvelope env = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", "cmd-99", 4000);
+
+ assertEquals(Type.hello, env.getType());
+ assertEquals("i-123", env.getInstanceId());
+ assertEquals("job-1", env.getJobId());
+ assertEquals("sess-1", env.getAgentSessionId());
+ assertEquals("cmd-99", env.getLastAppliedCommandId());
+ assertEquals(4000, env.getCapacity());
+
+ String json = env.toJson();
+ AgentWsEnvelope parsed = AgentWsEnvelope.fromJson(json);
+ assertEquals(4000, parsed.getCapacity());
+ }
+
+ @Test
+ public void testCommandFactory() throws IOException {
+ AgentWsEnvelope env = AgentWsEnvelope.command("cmd-1", "i-123", "job-1", "start");
+
+ assertEquals(Type.command, env.getType());
+ assertEquals("cmd-1", env.getCommandId());
+ assertEquals("i-123", env.getInstanceId());
+ assertEquals("job-1", env.getJobId());
+ assertEquals("start", env.getCommand());
+ }
+
+ @Test
+ public void testAckFactory() {
+ AgentWsEnvelope env = AgentWsEnvelope.ack("i-123", "command", "cmd-1", AckStatus.ok);
+
+ assertEquals(Type.ack, env.getType());
+ assertEquals("i-123", env.getInstanceId());
+ assertEquals("command", env.getAckForType());
+ assertEquals("cmd-1", env.getAckForId());
+ assertEquals(AckStatus.ok, env.getStatus());
+ }
+
+ @Test
+ public void testPingFactory() {
+ AgentWsEnvelope env = AgentWsEnvelope.ping("ping-1");
+
+ assertEquals(Type.ping, env.getType());
+ assertEquals("ping-1", env.getPingId());
+ }
+
+ @Test
+ public void testPongFactory() {
+ AgentWsEnvelope env = AgentWsEnvelope.pong("i-123", "sess-1", "ping-1", "cmd-5");
+
+ assertEquals(Type.pong, env.getType());
+ assertEquals("i-123", env.getInstanceId());
+ assertEquals("sess-1", env.getAgentSessionId());
+ assertEquals("ping-1", env.getPingId());
+ assertEquals("cmd-5", env.getLastAppliedCommandId());
+ }
+
+ @Test
+ public void testCloseFactory() {
+ AgentWsEnvelope env = AgentWsEnvelope.close("i-123", "shutdown", "Agent shutting down");
+
+ assertEquals(Type.close, env.getType());
+ assertEquals("i-123", env.getInstanceId());
+ assertEquals("shutdown", env.getReasonCode());
+ assertEquals("Agent shutting down", env.getReason());
+ }
+
+ @Test
+ public void testJsonRoundTrip() throws IOException {
+ AgentWsEnvelope original = AgentWsEnvelope.command("cmd-1", "i-123", "job-1", "stop");
+ String json = original.toJson();
+
+ assertNotNull(json);
+ assertTrue(json.contains("\"type\":\"command\""));
+ assertTrue(json.contains("\"commandId\":\"cmd-1\""));
+ assertTrue(json.contains("\"command\":\"stop\""));
+
+ AgentWsEnvelope parsed = AgentWsEnvelope.fromJson(json);
+ assertEquals(Type.command, parsed.getType());
+ assertEquals("cmd-1", parsed.getCommandId());
+ assertEquals("i-123", parsed.getInstanceId());
+ assertEquals("job-1", parsed.getJobId());
+ assertEquals("stop", parsed.getCommand());
+ }
+
+ @Test
+ public void testJsonNullFieldsOmitted() throws IOException {
+ AgentWsEnvelope env = AgentWsEnvelope.ping("ping-1");
+ String json = env.toJson();
+
+ // Null fields should not be present
+ assertFalse(json.contains("\"instanceId\""));
+ assertFalse(json.contains("\"jobId\""));
+ assertFalse(json.contains("\"commandId\""));
+ }
+
+ @Test
+ public void testFromJsonUnknownFieldsIgnored() throws IOException {
+ String json = "{\"type\":\"hello\",\"instanceId\":\"i-1\",\"jobId\":\"j-1\",\"agentSessionId\":\"s-1\",\"unknownField\":\"value\",\"sentAtMs\":1000,\"protocolVersion\":1}";
+ AgentWsEnvelope env = AgentWsEnvelope.fromJson(json);
+
+ assertEquals(Type.hello, env.getType());
+ assertEquals("i-1", env.getInstanceId());
+ assertEquals("j-1", env.getJobId());
+ }
+
+ @Test
+ public void testFromJsonInvalidType() {
+ String json = "{\"type\":\"bogus\",\"instanceId\":\"i-1\"}";
+ assertThrows(IOException.class, () -> AgentWsEnvelope.fromJson(json));
+ }
+
+ @Test
+ public void testFromJsonMalformed() {
+ assertThrows(IOException.class, () -> AgentWsEnvelope.fromJson("not json"));
+ }
+
+ @Test
+ public void testFromJsonEmptyObject() throws IOException {
+ AgentWsEnvelope env = AgentWsEnvelope.fromJson("{}");
+ assertNull(env.getType());
+ assertNull(env.getInstanceId());
+ }
+
+ @Test
+ public void testAckStatusValues() {
+ assertEquals(4, AckStatus.values().length);
+ assertNotNull(AckStatus.valueOf("ok"));
+ assertNotNull(AckStatus.valueOf("duplicate"));
+ assertNotNull(AckStatus.valueOf("failed"));
+ assertNotNull(AckStatus.valueOf("unsupported"));
+ }
+
+ @Test
+ public void testTypeValues() {
+ assertEquals(6, Type.values().length);
+ assertNotNull(Type.valueOf("hello"));
+ assertNotNull(Type.valueOf("command"));
+ assertNotNull(Type.valueOf("ack"));
+ assertNotNull(Type.valueOf("ping"));
+ assertNotNull(Type.valueOf("pong"));
+ assertNotNull(Type.valueOf("close"));
+ }
+
+ @Test
+ public void testAckRoundTrip() throws IOException {
+ AgentWsEnvelope ack = AgentWsEnvelope.ack("i-123", "command", "cmd-1", AckStatus.duplicate);
+ ack.setError("already applied");
+ ack.setAgentSessionId("sess-1");
+
+ String json = ack.toJson();
+ AgentWsEnvelope parsed = AgentWsEnvelope.fromJson(json);
+
+ assertEquals(Type.ack, parsed.getType());
+ assertEquals(AckStatus.duplicate, parsed.getStatus());
+ assertEquals("already applied", parsed.getError());
+ assertEquals("sess-1", parsed.getAgentSessionId());
+ }
+}
diff --git a/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java b/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java
index 23065e86a..6e119e915 100644
--- a/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java
+++ b/api/src/test/java/com/intuit/tank/vm/settings/AgentConfigCpTest.java
@@ -480,4 +480,17 @@ public void testSetResultsTypeMap_1()
fixture.setResultsTypeMap(resultsTypeMap);
}
+
+ @Test
+ public void testWsConfigDefaults() throws Exception {
+ AgentConfig fixture = new AgentConfig(new BasicConfigurationBuilder<>(XMLConfiguration.class).getConfiguration());
+
+ assertFalse(fixture.isCommandWsEnabled());
+ assertTrue(fixture.isCommandWsHttpFallbackEnabled());
+ assertEquals(3000L, fixture.getCommandWsAckTimeoutMillis());
+ assertEquals("/v2/agent/ws/control", fixture.getCommandWsPath());
+ assertFalse(fixture.isControllerInitiatedWsEnabled());
+ assertTrue(fixture.isControllerInitiatedWsDisableAgentHttp());
+ assertEquals("script.xml", fixture.getControllerInitiatedWsScriptPath());
+ }
}
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 24a06fca5..c5056e1a3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -769,6 +769,11 @@
spring-webmvc
${version.spring-webmvc}
+
+ org.springframework
+ spring-websocket
+ ${version.spring-webmvc}
+
org.wiremock
wiremock
diff --git a/rest-mvc/impl/pom.xml b/rest-mvc/impl/pom.xml
index e562670d4..a7e5eeff0 100644
--- a/rest-mvc/impl/pom.xml
+++ b/rest-mvc/impl/pom.xml
@@ -47,6 +47,10 @@
tank-script-processor
${project.version}
+
+ org.springframework
+ spring-websocket
+
org.apache.tomcat
tomcat-coyote
diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketConfig.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketConfig.java
new file mode 100644
index 000000000..689c40a60
--- /dev/null
+++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketConfig.java
@@ -0,0 +1,34 @@
+package com.intuit.tank.rest.mvc;
+
+import com.intuit.tank.vm.settings.TankConfig;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.socket.config.annotation.EnableWebSocket;
+import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
+import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
+
+@Configuration
+@EnableWebSocket
+public class AgentCommandWebSocketConfig implements WebSocketConfigurer {
+
+ @Bean
+ public AgentCommandWebSocketHandler agentCommandWebSocketHandler() {
+ AgentCommandWebSocketHandler handler = new AgentCommandWebSocketHandler();
+ AgentWsCommandSenderHolder.set(handler);
+ return handler;
+ }
+
+ @Override
+ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
+ String path = "/v2/agent/ws/control";
+ try {
+ path = new TankConfig().getAgentConfig().getCommandWsPath();
+ } catch (Exception e) {
+ // Fall back to default path if config not available during Spring init
+ }
+ // Agents connect via JDK HttpClient (no Origin header), not browsers.
+ // Allow all origins since auth is handled via bearer token in handshake.
+ registry.addHandler(agentCommandWebSocketHandler(), path)
+ .setAllowedOriginPatterns("*");
+ }
+}
diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java
new file mode 100644
index 000000000..7bdee2ce8
--- /dev/null
+++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java
@@ -0,0 +1,204 @@
+package com.intuit.tank.rest.mvc;
+
+import com.intuit.tank.vm.agent.messages.AgentWsCommandSender;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.Type;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ObjectMessage;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.handler.TextWebSocketHandler;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+
+public class AgentCommandWebSocketHandler extends TextWebSocketHandler implements AgentWsCommandSender {
+
+ private static final Logger LOG = LogManager.getLogger(AgentCommandWebSocketHandler.class);
+
+ // instanceId -> active WS session
+ private final ConcurrentHashMap agentSessions = new ConcurrentHashMap<>();
+
+ // sessionId -> bound instanceId (one identity per session, immutable after hello)
+ private final ConcurrentHashMap sessionIdentity = new ConcurrentHashMap<>();
+
+ // instanceId -> last seen timestamp
+ private final ConcurrentHashMap agentLastSeen = new ConcurrentHashMap<>();
+
+ // commandId -> ack future (for pending ack tracking)
+ private final ConcurrentHashMap> pendingAcks = new ConcurrentHashMap<>();
+
+ @Override
+ public void afterConnectionEstablished(WebSocketSession session) {
+ LOG.info(new ObjectMessage(Map.of("Message", "WS connection established: " + session.getId())));
+ }
+
+ @Override
+ protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
+ AgentWsEnvelope envelope;
+ try {
+ envelope = AgentWsEnvelope.fromJson(message.getPayload());
+ } catch (IOException e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Invalid WS frame, closing session: " + e.getMessage())));
+ session.close(CloseStatus.BAD_DATA);
+ return;
+ }
+
+ if (envelope.getType() == null) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS frame missing type, closing session")));
+ session.close(CloseStatus.BAD_DATA);
+ return;
+ }
+
+ // Require hello before any other frame type
+ if (envelope.getType() != Type.hello && !sessionIdentity.containsKey(session.getId())) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS frame received before hello, closing session")));
+ session.close(CloseStatus.POLICY_VIOLATION);
+ return;
+ }
+
+ switch (envelope.getType()) {
+ case hello -> handleHello(session, envelope);
+ case ack -> handleAck(session, envelope);
+ case pong -> handlePong(session, envelope);
+ case close -> handleClose(session, envelope);
+ default -> {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Unexpected WS frame type from agent: " + envelope.getType())));
+ sendAck(session, envelope.getInstanceId(), "hello", session.getId(), AckStatus.unsupported);
+ }
+ }
+ }
+
+ private void handleHello(WebSocketSession session, AgentWsEnvelope envelope) throws IOException {
+ String instanceId = envelope.getInstanceId();
+ if (instanceId == null || envelope.getJobId() == null || envelope.getAgentSessionId() == null) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Hello frame missing required fields, closing")));
+ session.close(CloseStatus.BAD_DATA);
+ return;
+ }
+
+ // Check if this session already has a bound identity — reject rebind
+ String existingIdentity = sessionIdentity.get(session.getId());
+ if (existingIdentity != null && !existingIdentity.equals(instanceId)) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Session " + session.getId() + " already bound to " + existingIdentity
+ + ", rejecting rebind to " + instanceId)));
+ session.close(CloseStatus.POLICY_VIOLATION);
+ return;
+ }
+
+ // Bind identity to session
+ sessionIdentity.put(session.getId(), instanceId);
+
+ // Replace old session if agent reconnects from a different connection
+ WebSocketSession oldSession = agentSessions.put(instanceId, session);
+ if (oldSession != null && oldSession.isOpen() && !oldSession.getId().equals(session.getId())) {
+ LOG.info(new ObjectMessage(Map.of("Message", "Replacing old WS session for agent " + instanceId)));
+ sessionIdentity.remove(oldSession.getId());
+ try { oldSession.close(CloseStatus.GOING_AWAY); } catch (IOException ignored) {}
+ }
+
+ agentLastSeen.put(instanceId, System.currentTimeMillis());
+ LOG.info(new ObjectMessage(Map.of("Message", "Agent " + instanceId + " registered via WS for job " + envelope.getJobId()
+ + " (agentSession=" + envelope.getAgentSessionId() + ")")));
+
+ sendAck(session, instanceId, "hello", envelope.getAgentSessionId(), AckStatus.ok);
+ }
+
+ private void handleAck(WebSocketSession session, AgentWsEnvelope envelope) {
+ String ackForId = envelope.getAckForId();
+ if (ackForId != null) {
+ CompletableFuture future = pendingAcks.remove(ackForId);
+ if (future != null) {
+ future.complete(envelope);
+ }
+ }
+ String boundId = sessionIdentity.get(session.getId());
+ if (boundId != null) {
+ agentLastSeen.put(boundId, System.currentTimeMillis());
+ }
+ }
+
+ private void handlePong(WebSocketSession session, AgentWsEnvelope envelope) {
+ String boundId = sessionIdentity.get(session.getId());
+ if (boundId != null) {
+ agentLastSeen.put(boundId, System.currentTimeMillis());
+ }
+ }
+
+ private void handleClose(WebSocketSession session, AgentWsEnvelope envelope) throws IOException {
+ String boundId = sessionIdentity.remove(session.getId());
+ if (boundId != null) {
+ LOG.info(new ObjectMessage(Map.of("Message", "Agent " + boundId + " sent close: " + envelope.getReason())));
+ agentSessions.remove(boundId, session);
+ agentLastSeen.remove(boundId);
+ }
+ session.close(CloseStatus.NORMAL);
+ }
+
+ @Override
+ public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
+ String boundId = sessionIdentity.remove(session.getId());
+ if (boundId != null) {
+ agentSessions.remove(boundId, session);
+ agentLastSeen.remove(boundId);
+ }
+ LOG.info(new ObjectMessage(Map.of("Message", "WS session closed: " + session.getId() + " status=" + status)));
+ }
+
+ @Override
+ public void handleTransportError(WebSocketSession session, Throwable exception) {
+ LOG.error(new ObjectMessage(Map.of("Message", "WS transport error for session " + session.getId() + ": " + exception.getMessage())), exception);
+ String boundId = sessionIdentity.remove(session.getId());
+ if (boundId != null) {
+ agentSessions.remove(boundId, session);
+ agentLastSeen.remove(boundId);
+ }
+ }
+
+ @Override
+ public boolean hasSession(String instanceId) {
+ WebSocketSession session = agentSessions.get(instanceId);
+ return session != null && session.isOpen();
+ }
+
+ @Override
+ public boolean sendCommand(String instanceId, String jobId, String command, long ackTimeoutMillis) {
+ WebSocketSession session = agentSessions.get(instanceId);
+ if (session == null || !session.isOpen()) {
+ return false;
+ }
+
+ String commandId = UUID.randomUUID().toString();
+ CompletableFuture ackFuture = new CompletableFuture<>();
+ pendingAcks.put(commandId, ackFuture);
+
+ try {
+ AgentWsEnvelope cmdEnvelope = AgentWsEnvelope.command(commandId, instanceId, jobId, command);
+ synchronized (session) {
+ session.sendMessage(new TextMessage(cmdEnvelope.toJson()));
+ }
+ LOG.info(new ObjectMessage(Map.of("Message", "Sent WS command " + command + " (id=" + commandId + ") to agent " + instanceId)));
+
+ AgentWsEnvelope ack = ackFuture.get(ackTimeoutMillis, TimeUnit.MILLISECONDS);
+ return ack != null && ack.getStatus() == AckStatus.ok;
+ } catch (Exception e) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS command " + command + " to agent " + instanceId + " failed: " + e.getMessage())));
+ pendingAcks.remove(commandId);
+ return false;
+ }
+ }
+
+ private void sendAck(WebSocketSession session, String instanceId, String ackForType, String ackForId, AckStatus status) throws IOException {
+ AgentWsEnvelope ack = AgentWsEnvelope.ack(instanceId, ackForType, ackForId, status);
+ synchronized (session) {
+ session.sendMessage(new TextMessage(ack.toJson()));
+ }
+ }
+}
diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderHolder.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderHolder.java
new file mode 100644
index 000000000..2ad5da825
--- /dev/null
+++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderHolder.java
@@ -0,0 +1,20 @@
+package com.intuit.tank.rest.mvc;
+
+import com.intuit.tank.vm.agent.messages.AgentWsCommandSender;
+
+/**
+ * Static holder so the Spring-managed WS handler can be accessed by CDI producer.
+ * Set by AgentCommandWebSocketConfig at Spring init time.
+ */
+public class AgentWsCommandSenderHolder {
+
+ private static volatile AgentWsCommandSender instance;
+
+ public static void set(AgentWsCommandSender sender) {
+ instance = sender;
+ }
+
+ public static AgentWsCommandSender get() {
+ return instance;
+ }
+}
diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderProducer.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderProducer.java
new file mode 100644
index 000000000..e607b39ab
--- /dev/null
+++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentWsCommandSenderProducer.java
@@ -0,0 +1,18 @@
+package com.intuit.tank.rest.mvc;
+
+import com.intuit.tank.vm.agent.messages.AgentWsCommandSender;
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.enterprise.inject.Produces;
+
+/**
+ * CDI producer that bridges the Spring-managed AgentCommandWebSocketHandler
+ * into the CDI container so JobManager can inject it via Instance.
+ */
+@ApplicationScoped
+public class AgentWsCommandSenderProducer {
+
+ @Produces
+ public AgentWsCommandSender getAgentWsCommandSender() {
+ return AgentWsCommandSenderHolder.get();
+ }
+}
diff --git a/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java
new file mode 100644
index 000000000..2d1631098
--- /dev/null
+++ b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java
@@ -0,0 +1,241 @@
+package com.intuit.tank.rest.mvc;
+
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.Type;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.springframework.web.socket.CloseStatus;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketSession;
+
+import java.io.IOException;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+public class AgentCommandWebSocketHandlerTest {
+
+ private AgentCommandWebSocketHandler handler;
+ private WebSocketSession session;
+
+ @BeforeEach
+ void setUp() {
+ handler = new AgentCommandWebSocketHandler();
+ session = mock(WebSocketSession.class);
+ when(session.getId()).thenReturn("test-session-1");
+ when(session.isOpen()).thenReturn(true);
+ }
+
+ @Test
+ void testHelloRegistersSession() throws Exception {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+
+ assertTrue(handler.hasSession("i-123"));
+
+ // Verify ack was sent
+ verify(session).sendMessage(argThat(msg -> {
+ String payload = ((TextMessage) msg).getPayload();
+ return payload.contains("\"type\":\"ack\"") && payload.contains("\"status\":\"ok\"");
+ }));
+ }
+
+ @Test
+ void testHelloReRegisterReplacesSession() throws Exception {
+ AgentWsEnvelope hello1 = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello1.toJson()));
+
+ WebSocketSession session2 = mock(WebSocketSession.class);
+ when(session2.getId()).thenReturn("test-session-2");
+ when(session2.isOpen()).thenReturn(true);
+
+ AgentWsEnvelope hello2 = AgentWsEnvelope.hello("i-123", "job-1", "sess-2", null);
+ handler.handleTextMessage(session2, new TextMessage(hello2.toJson()));
+
+ assertTrue(handler.hasSession("i-123"));
+ // Old session should have been closed
+ verify(session).close(CloseStatus.GOING_AWAY);
+ }
+
+ @Test
+ void testHelloMissingFieldsClosesSession() throws Exception {
+ // Missing jobId and agentSessionId
+ AgentWsEnvelope badHello = new AgentWsEnvelope();
+ badHello.setType(Type.hello);
+ badHello.setInstanceId("i-123");
+ badHello.setSentAtMs(System.currentTimeMillis());
+
+ handler.handleTextMessage(session, new TextMessage(badHello.toJson()));
+
+ assertFalse(handler.hasSession("i-123"));
+ verify(session).close(CloseStatus.BAD_DATA);
+ }
+
+ @Test
+ void testInvalidJsonClosesSession() throws Exception {
+ handler.handleTextMessage(session, new TextMessage("not valid json"));
+
+ verify(session).close(CloseStatus.BAD_DATA);
+ }
+
+ @Test
+ void testMissingTypeClosesSession() throws Exception {
+ handler.handleTextMessage(session, new TextMessage("{\"instanceId\":\"i-123\"}"));
+
+ verify(session).close(CloseStatus.BAD_DATA);
+ }
+
+ @Test
+ void testPongUpdatesLastSeen() throws Exception {
+ // Register first
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+
+ // Send pong
+ AgentWsEnvelope pong = AgentWsEnvelope.pong("i-123", "sess-1", "ping-1", "cmd-5");
+ handler.handleTextMessage(session, new TextMessage(pong.toJson()));
+
+ // Session should still be active
+ assertTrue(handler.hasSession("i-123"));
+ }
+
+ @Test
+ void testCloseUnregistersSession() throws Exception {
+ // Register
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+ assertTrue(handler.hasSession("i-123"));
+
+ // Close
+ AgentWsEnvelope close = AgentWsEnvelope.close("i-123", "shutdown", "done");
+ handler.handleTextMessage(session, new TextMessage(close.toJson()));
+
+ assertFalse(handler.hasSession("i-123"));
+ }
+
+ @Test
+ void testAfterConnectionClosedCleansUp() throws Exception {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+ assertTrue(handler.hasSession("i-123"));
+
+ handler.afterConnectionClosed(session, CloseStatus.NORMAL);
+
+ assertFalse(handler.hasSession("i-123"));
+ }
+
+ @Test
+ void testHasSessionReturnsFalseForUnknown() {
+ assertFalse(handler.hasSession("i-unknown"));
+ }
+
+ @Test
+ void testHasSessionReturnsFalseForClosedSession() throws Exception {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+
+ when(session.isOpen()).thenReturn(false);
+ assertFalse(handler.hasSession("i-123"));
+ }
+
+ @Test
+ void testSendCommandReturnsFalseWhenNoSession() {
+ boolean result = handler.sendCommand("i-unknown", "job-1", "start", 1000);
+ assertFalse(result);
+ }
+
+ @Test
+ void testSendCommandReturnsFalseWhenSessionClosed() throws Exception {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+
+ when(session.isOpen()).thenReturn(false);
+
+ boolean result = handler.sendCommand("i-123", "job-1", "start", 1000);
+ assertFalse(result);
+ }
+
+ @Test
+ void testSendCommandTimesOutWithNoAck() throws Exception {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+
+ // Send command with very short timeout - no ack will come
+ boolean result = handler.sendCommand("i-123", "job-1", "start", 50);
+ assertFalse(result);
+ }
+
+ @Test
+ void testSendCommandSucceedsWithAck() throws Exception {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+
+ // Simulate the ack arriving when command is sent
+ doAnswer(invocation -> {
+ TextMessage msg = invocation.getArgument(0);
+ AgentWsEnvelope cmd = AgentWsEnvelope.fromJson(msg.getPayload());
+ if (cmd.getType() == Type.command) {
+ // Simulate agent ack
+ AgentWsEnvelope ack = AgentWsEnvelope.ack("i-123", "command", cmd.getCommandId(), AckStatus.ok);
+ handler.handleTextMessage(session, new TextMessage(ack.toJson()));
+ }
+ return null;
+ }).when(session).sendMessage(any(TextMessage.class));
+
+ boolean result = handler.sendCommand("i-123", "job-1", "start", 5000);
+ assertTrue(result);
+ }
+
+ @Test
+ void testTransportErrorCleansUpSession() throws Exception {
+ AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello.toJson()));
+ assertTrue(handler.hasSession("i-123"));
+
+ handler.handleTransportError(session, new RuntimeException("connection lost"));
+
+ assertFalse(handler.hasSession("i-123"));
+ }
+
+ @Test
+ void testFrameBeforeHelloRejected() throws Exception {
+ // Send pong without hello first
+ AgentWsEnvelope pong = AgentWsEnvelope.pong("i-123", "sess-1", "ping-1", null);
+ handler.handleTextMessage(session, new TextMessage(pong.toJson()));
+
+ verify(session).close(CloseStatus.POLICY_VIOLATION);
+ }
+
+ @Test
+ void testIdentityRebindRejected() throws Exception {
+ // Register with one instanceId
+ AgentWsEnvelope hello1 = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello1.toJson()));
+ assertTrue(handler.hasSession("i-123"));
+
+ // Try to rebind same session to different instanceId
+ AgentWsEnvelope hello2 = AgentWsEnvelope.hello("i-999", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello2.toJson()));
+
+ // Session should be closed due to policy violation
+ verify(session).close(CloseStatus.POLICY_VIOLATION);
+ // Original identity should NOT be replaced
+ assertFalse(handler.hasSession("i-999"));
+ }
+
+ @Test
+ void testSameIdentityReHelloAllowed() throws Exception {
+ // Register
+ AgentWsEnvelope hello1 = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null);
+ handler.handleTextMessage(session, new TextMessage(hello1.toJson()));
+
+ // Re-hello with same instanceId on same session (e.g., agent restart detection)
+ AgentWsEnvelope hello2 = AgentWsEnvelope.hello("i-123", "job-1", "sess-2", null);
+ handler.handleTextMessage(session, new TextMessage(hello2.toJson()));
+
+ assertTrue(handler.hasSession("i-123"));
+ // Should NOT have been closed with policy violation
+ verify(session, never()).close(CloseStatus.POLICY_VIOLATION);
+ }
+}
diff --git a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java
new file mode 100644
index 000000000..212c1ca52
--- /dev/null
+++ b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java
@@ -0,0 +1,215 @@
+package com.intuit.tank.perfManager.workLoads;
+
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ObjectMessage;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.WebSocket;
+import java.nio.ByteBuffer;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+@ApplicationScoped
+public class ControllerInitiatedAgentWsClient {
+
+ private static final Logger LOG = LogManager.getLogger(ControllerInitiatedAgentWsClient.class);
+
+ private final HttpClient httpClient = HttpClient.newHttpClient();
+
+ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>();
+ private final ConcurrentHashMap> pendingAcks = new ConcurrentHashMap<>();
+
+ public Optional connect(String instanceId, String wsUrl, String token, long helloTimeoutMillis) {
+ try {
+ SessionContext existing = sessions.get(instanceId);
+ if (existing != null && existing.isOpen()) {
+ return Optional.empty();
+ }
+
+ CompletableFuture helloFuture = new CompletableFuture<>();
+ Listener listener = new Listener(instanceId, helloFuture);
+
+ WebSocket webSocket = httpClient.newWebSocketBuilder()
+ .header("Authorization", "bearer " + token)
+ .buildAsync(URI.create(wsUrl), listener)
+ .join();
+
+ SessionContext context = new SessionContext(webSocket, helloFuture);
+ SessionContext previous = sessions.put(instanceId, context);
+ if (previous != null) {
+ previous.close();
+ }
+
+ AgentWsEnvelope hello = helloFuture.get(helloTimeoutMillis, TimeUnit.MILLISECONDS);
+ if (hello != null) {
+ LOG.info(new ObjectMessage(Map.of("Message",
+ "Controller initiated WS connected to agent " + instanceId + " via " + wsUrl)));
+ return Optional.of(hello);
+ }
+ } catch (Exception e) {
+ LOG.warn(new ObjectMessage(Map.of("Message",
+ "Failed to connect WS to agent " + instanceId + ": " + e.getMessage())));
+ }
+
+ SessionContext failed = sessions.remove(instanceId);
+ if (failed != null) {
+ failed.close();
+ }
+ return Optional.empty();
+ }
+
+ public boolean hasSession(String instanceId) {
+ SessionContext context = sessions.get(instanceId);
+ return context != null && context.isOpen();
+ }
+
+ public boolean sendCommand(String instanceId, String jobId, String command, long ackTimeoutMillis) {
+ SessionContext context = sessions.get(instanceId);
+ if (context == null || !context.isOpen()) {
+ return false;
+ }
+
+ String commandId = UUID.randomUUID().toString();
+ CompletableFuture ackFuture = new CompletableFuture<>();
+ pendingAcks.put(commandId, ackFuture);
+
+ try {
+ AgentWsEnvelope cmd = AgentWsEnvelope.command(commandId, instanceId, jobId, command);
+ context.send(cmd.toJson());
+
+ AgentWsEnvelope ack = ackFuture.get(ackTimeoutMillis, TimeUnit.MILLISECONDS);
+ return ack != null && ack.getStatus() == AckStatus.ok;
+ } catch (Exception e) {
+ LOG.warn(new ObjectMessage(Map.of("Message",
+ "WS command " + command + " to " + instanceId + " failed: " + e.getMessage())));
+ return false;
+ } finally {
+ pendingAcks.remove(commandId);
+ }
+ }
+
+ private void handleText(String instanceId, String text, CompletableFuture helloFuture) {
+ try {
+ AgentWsEnvelope envelope = AgentWsEnvelope.fromJson(text);
+ if (envelope.getType() == null) {
+ return;
+ }
+ switch (envelope.getType()) {
+ case hello -> helloFuture.complete(envelope);
+ case ack -> {
+ String ackForId = envelope.getAckForId();
+ if (ackForId != null) {
+ CompletableFuture future = pendingAcks.remove(ackForId);
+ if (future != null) {
+ future.complete(envelope);
+ }
+ }
+ }
+ default -> {
+ // no-op for pong/close/ping in PoC client path
+ }
+ }
+ } catch (IOException ignored) {
+ }
+ }
+
+ private void onClosed(String instanceId) {
+ SessionContext context = sessions.remove(instanceId);
+ if (context != null) {
+ context.markClosed();
+ }
+ }
+
+ private class Listener implements WebSocket.Listener {
+ private final String instanceId;
+ private final CompletableFuture helloFuture;
+ private final StringBuilder messageBuffer = new StringBuilder();
+
+ private Listener(String instanceId, CompletableFuture helloFuture) {
+ this.instanceId = instanceId;
+ this.helloFuture = helloFuture;
+ }
+
+ @Override
+ public void onOpen(WebSocket webSocket) {
+ webSocket.request(1);
+ }
+
+ @Override
+ public CompletionStage> onText(WebSocket webSocket, CharSequence data, boolean last) {
+ messageBuffer.append(data);
+ if (last) {
+ String fullMessage = messageBuffer.toString();
+ messageBuffer.setLength(0);
+ handleText(instanceId, fullMessage, helloFuture);
+ }
+ webSocket.request(1);
+ return null;
+ }
+
+ @Override
+ public CompletionStage> onPing(WebSocket webSocket, ByteBuffer message) {
+ webSocket.request(1);
+ return null;
+ }
+
+ @Override
+ public CompletionStage> onClose(WebSocket webSocket, int statusCode, String reason) {
+ onClosed(instanceId);
+ return null;
+ }
+
+ @Override
+ public void onError(WebSocket webSocket, Throwable error) {
+ LOG.warn(new ObjectMessage(Map.of("Message",
+ "Controller initiated WS listener error for " + instanceId + ": " + error.getMessage())));
+ onClosed(instanceId);
+ }
+ }
+
+ private static class SessionContext {
+ private final WebSocket webSocket;
+ private final CompletableFuture helloFuture;
+ private final AtomicBoolean closed = new AtomicBoolean(false);
+
+ private SessionContext(WebSocket webSocket, CompletableFuture helloFuture) {
+ this.webSocket = webSocket;
+ this.helloFuture = helloFuture;
+ }
+
+ private boolean isOpen() {
+ return !closed.get() && !webSocket.isInputClosed() && !webSocket.isOutputClosed();
+ }
+
+ private void send(String text) {
+ webSocket.sendText(text, true).join();
+ }
+
+ private void close() {
+ if (closed.compareAndSet(false, true)) {
+ try {
+ webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Closing previous session");
+ } catch (Exception ignored) {
+ }
+ helloFuture.completeExceptionally(new IllegalStateException("Closed"));
+ }
+ }
+
+ private void markClosed() {
+ closed.set(true);
+ helloFuture.completeExceptionally(new IllegalStateException("Closed"));
+ }
+ }
+}
diff --git a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java
index 3e57328ea..9f534e855 100644
--- a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java
+++ b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java
@@ -57,6 +57,7 @@
import com.intuit.tank.vm.agent.messages.DataFileRequest;
import com.intuit.tank.vm.agent.messages.StandaloneAgentRequest;
import com.intuit.tank.vm.perfManager.StandaloneAgentTracker;
+import com.intuit.tank.vm.settings.AgentConfig;
import com.intuit.tank.vm.settings.TankConfig;
import com.intuit.tank.vm.vmManager.JobRequest;
import com.intuit.tank.vm.vmManager.JobVmCalculator;
@@ -93,6 +94,12 @@ public class JobManager implements Serializable {
@Inject
private TankConfig tankConfig;
+ @Inject
+ private jakarta.enterprise.inject.Instance wsCommandSenderInstance;
+
+ @Inject
+ private jakarta.enterprise.inject.Instance controllerInitiatedWsClientInstance;
+
/**
* @param id
* @throws Exception
@@ -214,17 +221,65 @@ private void startTest(final JobInfo info) {
LOG.info(new ObjectMessage(Map.of("Message", "Start agents command received - Sending start commands for job " + jobId + " asynchronously to following agents: " +
info.agentData.stream().collect(Collectors.toMap(AgentData::getInstanceId, AgentData::getInstanceUrl)))));
}
- LOG.info(new ObjectMessage(Map.of("Message", "Sending START commands to " + info.agentData.size() +
+ LOG.info(new ObjectMessage(Map.of("Message", "Sending START commands to " + info.agentData.size() +
" agents for job " + jobId)));
+
+ AgentConfig agentConfig = tankConfig != null ? tankConfig.getAgentConfig() : null;
+ boolean wsEnabled = agentConfig != null && agentConfig.isCommandWsEnabled();
+ boolean controllerInitiatedWsEnabled = agentConfig != null && agentConfig.isControllerInitiatedWsEnabled();
+ boolean httpFallback = agentConfig == null || agentConfig.isCommandWsHttpFallbackEnabled();
+ long ackTimeout = agentConfig != null ? agentConfig.getCommandWsAckTimeoutMillis() : 3000L;
+ com.intuit.tank.vm.agent.messages.AgentWsCommandSender wsSender = getWsCommandSender();
+ ControllerInitiatedAgentWsClient controllerWsClient = getControllerInitiatedWsClient();
+
info.agentData.parallelStream()
- .map(agentData -> {
+ .forEach(agentData -> {
+ String instanceId = agentData.getInstanceId();
+
+ if (controllerInitiatedWsEnabled && controllerWsClient != null
+ && controllerWsClient.hasSession(instanceId)) {
+ boolean acked = controllerWsClient.sendCommand(instanceId, jobId, AgentCommand.start.name(), ackTimeout);
+ if (acked) {
+ LOG.info(new ObjectMessage(Map.of("Message",
+ "Controller-initiated WS START command to agent " + instanceId + " was SUCCESSFUL for job " + jobId)));
+ return;
+ }
+ LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS START command to agent "
+ + instanceId + " failed, " + (httpFallback ? "falling back to HTTP" : "no fallback"))));
+ if (!httpFallback) {
+ return;
+ }
+ } else if (controllerInitiatedWsEnabled && !httpFallback) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS enabled but no session for agent "
+ + instanceId + " and fallback disabled, skipping")));
+ return;
+ }
+
+ // Try WS first if enabled
+ if (wsEnabled) {
+ if (wsSender != null && wsSender.hasSession(instanceId)) {
+ boolean acked = wsSender.sendCommand(instanceId, jobId, AgentCommand.start.name(), ackTimeout);
+ if (acked) {
+ LOG.info(new ObjectMessage(Map.of("Message", "WS START command to agent " + instanceId + " was SUCCESSFUL for job " + jobId)));
+ return;
+ }
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS START command to agent " + instanceId + " failed, " +
+ (httpFallback ? "falling back to HTTP" : "no fallback"))));
+ if (!httpFallback) {
+ return;
+ }
+ } else if (!httpFallback) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS enabled but " +
+ (wsSender == null ? "sender unavailable" : "no session") +
+ " for agent " + instanceId + " and fallback disabled, skipping")));
+ return;
+ }
+ }
+
+ // HTTP fallback (or WS not enabled)
String url = agentData.getInstanceUrl() + AgentCommand.start.getPath();
LOG.info(new ObjectMessage(Map.of("Message", "Sending command to url " + url)));
- return url;
- })
- .map(URI::create)
- .map(uri -> sendCommand(uri, MAX_RETRIES))
- .forEach(future -> {
+ CompletableFuture> future = sendCommand(URI.create(url), MAX_RETRIES);
HttpResponse response = (HttpResponse) future.join();
if (response != null && Set.of(HttpStatus.SC_OK, HttpStatus.SC_ACCEPTED).contains(response.statusCode())) {
LOG.info(new ObjectMessage(Map.of(
@@ -246,16 +301,110 @@ private CloudVmStatus createFailureStatus(AgentData data) {
/**
* Convert the List of instanceIds to instanceUrls to instanceUris and pass it to sendCommand(List, retry) to send asynchttp commands.
+ * Uses WS transport when available and enabled, falls back to HTTP.
* @param instanceIds
* @param cmd
* @return Array of CompletableFuture that will probably never be looked at.
*/
public List> sendCommand(List instanceIds, AgentCommand cmd) {
+ AgentConfig agentConfig = tankConfig != null ? tankConfig.getAgentConfig() : null;
+ boolean wsEnabled = agentConfig != null && agentConfig.isCommandWsEnabled();
+ boolean controllerInitiatedWsEnabled = agentConfig != null && agentConfig.isControllerInitiatedWsEnabled();
+ boolean httpFallback = agentConfig == null || agentConfig.isCommandWsHttpFallbackEnabled();
+ long ackTimeout = agentConfig != null ? agentConfig.getCommandWsAckTimeoutMillis() : 3000L;
+
+ com.intuit.tank.vm.agent.messages.AgentWsCommandSender wsSender = getWsCommandSender();
+ ControllerInitiatedAgentWsClient controllerWsClient = getControllerInitiatedWsClient();
+
+ // Resolve instanceId -> jobId mapping for WS commands
+ Map instanceJobMap = new HashMap<>();
+ if ((wsEnabled && wsSender != null) || controllerInitiatedWsEnabled) {
+ for (String instanceId : instanceIds) {
+ for (JobInfo info : jobInfoMapLocalCache.values()) {
+ for (AgentData data : info.agentData) {
+ if (instanceId.equals(data.getInstanceId())) {
+ instanceJobMap.put(instanceId, info.jobRequest.getId());
+ }
+ }
+ }
+ }
+ }
+
List instanceUrls = getInstanceUrl(instanceIds);
+
+ // Build a parallel list of instanceId -> instanceUrl for fallback
+ List orderedInstanceIds = new ArrayList<>(instanceIds);
+
return instanceUrls.parallelStream()
- .map(instanceUrl -> instanceUrl + cmd.getPath())
- .map(URI::create)
- .map(uri -> sendCommand(uri, 0))
+ .map(instanceUrl -> {
+ // Find matching instanceId for this URL
+ String matchedInstanceId = null;
+ for (int i = 0; i < orderedInstanceIds.size(); i++) {
+ String candidateId = orderedInstanceIds.get(i);
+ // Match by checking if this URL belongs to this instanceId
+ for (JobInfo info : jobInfoMapLocalCache.values()) {
+ for (AgentData data : info.agentData) {
+ if (candidateId.equals(data.getInstanceId()) && instanceUrl.equals(data.getInstanceUrl())) {
+ matchedInstanceId = candidateId;
+ }
+ }
+ }
+ }
+
+ // Try WS first if enabled
+ if (controllerInitiatedWsEnabled && controllerWsClient != null && matchedInstanceId != null) {
+ String jobId = instanceJobMap.get(matchedInstanceId);
+ if (jobId != null && controllerWsClient.hasSession(matchedInstanceId)) {
+ boolean acked = controllerWsClient.sendCommand(matchedInstanceId, jobId, cmd.name(), ackTimeout);
+ if (acked) {
+ LOG.info(new ObjectMessage(Map.of("Message", "Controller-initiated WS command " + cmd + " to agent " + matchedInstanceId + " succeeded")));
+ return CompletableFuture.completedFuture(null);
+ }
+ LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS command " + cmd + " to agent " + matchedInstanceId + " failed, " +
+ (httpFallback ? "falling back to HTTP" : "no fallback"))));
+ if (!httpFallback) {
+ return CompletableFuture.completedFuture(null);
+ }
+ } else if (!httpFallback) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS enabled but no session for agent "
+ + matchedInstanceId + " and fallback disabled, skipping")));
+ return CompletableFuture.completedFuture(null);
+ }
+ } else if (controllerInitiatedWsEnabled && !httpFallback) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "Controller-initiated WS enabled but no matched instanceId and fallback disabled, skipping")));
+ return CompletableFuture.completedFuture(null);
+ }
+
+ if (wsEnabled) {
+ if (wsSender != null && matchedInstanceId != null) {
+ String jobId = instanceJobMap.get(matchedInstanceId);
+ if (jobId != null && wsSender.hasSession(matchedInstanceId)) {
+ boolean acked = wsSender.sendCommand(matchedInstanceId, jobId, cmd.name(), ackTimeout);
+ if (acked) {
+ LOG.info(new ObjectMessage(Map.of("Message", "WS command " + cmd + " to agent " + matchedInstanceId + " succeeded")));
+ return CompletableFuture.completedFuture(null);
+ }
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS command " + cmd + " to agent " + matchedInstanceId + " failed, " +
+ (httpFallback ? "falling back to HTTP" : "no fallback"))));
+ if (!httpFallback) {
+ return CompletableFuture.completedFuture(null);
+ }
+ } else if (!httpFallback) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS enabled but no session for agent " + matchedInstanceId + " and fallback disabled, skipping")));
+ return CompletableFuture.completedFuture(null);
+ }
+ } else if (!httpFallback) {
+ LOG.warn(new ObjectMessage(Map.of("Message", "WS enabled but " +
+ (wsSender == null ? "sender unavailable" : "no matched instanceId") +
+ " and fallback disabled, skipping")));
+ return CompletableFuture.completedFuture(null);
+ }
+ }
+
+ // HTTP fallback (or WS not enabled)
+ URI uri = URI.create(instanceUrl + cmd.getPath());
+ return sendCommand(uri, 0);
+ })
.collect(Collectors.toList());
}
@@ -364,6 +513,20 @@ private int getNumberOfLines(Integer dataFileId) {
return ret;
}
+ private com.intuit.tank.vm.agent.messages.AgentWsCommandSender getWsCommandSender() {
+ if (wsCommandSenderInstance != null && wsCommandSenderInstance.isResolvable()) {
+ return wsCommandSenderInstance.get();
+ }
+ return null;
+ }
+
+ private ControllerInitiatedAgentWsClient getControllerInitiatedWsClient() {
+ if (controllerInitiatedWsClientInstance != null && controllerInitiatedWsClientInstance.isResolvable()) {
+ return controllerInitiatedWsClientInstance.get();
+ }
+ return null;
+ }
+
public void startAgents(String jobId){
LOG.info(new ObjectMessage(Map.of("Message","Sending start agents command to start test for job " + jobId)));
if(!jobInfoMapLocalCache.get(jobId).isStarted()){
diff --git a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java
index f30f33460..edcfabb80 100644
--- a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java
+++ b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/VmMessageProcessorImpl.java
@@ -17,6 +17,9 @@
import jakarta.inject.Inject;
import com.amazonaws.xray.AWSXRay;
+import com.intuit.tank.perfManager.workLoads.ControllerInitiatedAgentWsClient;
+import com.intuit.tank.perfManager.workLoads.JobManager;
+import com.intuit.tank.vm.settings.TankConfig;
import com.intuit.tank.vm.vmManager.*;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -42,6 +45,15 @@ public class VmMessageProcessorImpl implements VmMessageProcessor {
@Inject
private VMTracker vmTracker;
+ @Inject
+ private ControllerInitiatedAgentWsClient controllerInitiatedAgentWsClient;
+
+ @Inject
+ private JobManager jobManager;
+
+ @Inject
+ private TankConfig tankConfig;
+
/**
* @param messageObject
*/
@@ -52,7 +64,8 @@ public void handleVMRequest(VMRequest messageObject) {
CreateInstance instance = new CreateInstance((VMInstanceRequest) messageObject, vmTracker);
Objects.requireNonNull(AWSXRay.getGlobalRecorder().getTraceEntity()).run(instance);
} else if (messageObject instanceof VMJobRequest) {
- JobRequest instance = new JobRequest((VMJobRequest) messageObject, vmTracker);
+ JobRequest instance = new JobRequest((VMJobRequest) messageObject, vmTracker,
+ controllerInitiatedAgentWsClient, jobManager, tankConfig);
Objects.requireNonNull(AWSXRay.getGlobalRecorder().getTraceEntity()).run(instance);
} else if (messageObject instanceof VMKillRequest) {
logger.debug("vmManager received VMKillRequest");
diff --git a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java
index e8a189d6f..2eb04665d 100644
--- a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java
+++ b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/JobRequest.java
@@ -15,11 +15,13 @@
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import com.intuit.tank.vm.api.enumerated.IncrementStrategy;
import com.intuit.tank.logging.ControllerLoggingConfig;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
+import org.apache.commons.lang3.StringUtils;
import com.intuit.tank.vm.vmManager.VMTracker;
import com.intuit.tank.vm.vmManager.models.CloudVmStatus;
@@ -29,6 +31,12 @@
import com.intuit.tank.vm.api.enumerated.JobStatus;
import com.intuit.tank.vm.api.enumerated.VMImageType;
import com.intuit.tank.vm.api.enumerated.VMProvider;
+import com.intuit.tank.vm.api.enumerated.VMRegion;
+import com.intuit.tank.vm.agent.messages.AgentData;
+import com.intuit.tank.vm.agent.messages.AgentWsEnvelope;
+import com.intuit.tank.vm.settings.TankConfig;
+import com.intuit.tank.perfManager.workLoads.ControllerInitiatedAgentWsClient;
+import com.intuit.tank.perfManager.workLoads.JobManager;
import com.intuit.tank.vm.vmManager.JobVmCalculator;
import com.intuit.tank.vm.vmManager.VMInformation;
import com.intuit.tank.vm.vmManager.VMInstanceRequest;
@@ -43,10 +51,24 @@ public class JobRequest implements Runnable {
private VMJobRequest request = null;
private VMTracker vmTracker;
+ private ControllerInitiatedAgentWsClient controllerInitiatedAgentWsClient;
+ private JobManager jobManager;
+ private TankConfig tankConfig;
public JobRequest(VMJobRequest request, VMTracker tracker) {
+ this(request, tracker, null, null, new TankConfig());
+ }
+
+ public JobRequest(VMJobRequest request,
+ VMTracker tracker,
+ ControllerInitiatedAgentWsClient controllerInitiatedAgentWsClient,
+ JobManager jobManager,
+ TankConfig tankConfig) {
this.request = request;
this.vmTracker = tracker;
+ this.controllerInitiatedAgentWsClient = controllerInitiatedAgentWsClient;
+ this.jobManager = jobManager;
+ this.tankConfig = tankConfig;
}
@Override
@@ -95,6 +117,7 @@ private void persistInstances(VMInstanceRequest instanceRequest, List hello = Optional.empty();
+ int attempts = 0;
+ while (hello.isEmpty() && attempts < 8) {
+ attempts++;
+ hello = controllerInitiatedAgentWsClient.connect(
+ info.getInstanceId(),
+ wsUrl,
+ tankConfig.getAgentConfig().getAgentToken(),
+ 10000L);
+ if (hello.isPresent()) {
+ break;
+ }
+ try {
+ Thread.sleep(5000L);
+ } catch (InterruptedException ignored) {
+ Thread.currentThread().interrupt();
+ break;
+ }
+ }
+
+ if (hello.isEmpty()) {
+ logger.warn("Controller-initiated WS hello not received for {}", info.getInstanceId());
+ return;
+ }
+
+ int capacity = hello.get().getCapacity() != null ? hello.get().getCapacity() : instanceRequest.getNumUsersPerAgent();
+ String instanceUrl = buildHttpUrl(instanceRequest, info);
+ AgentData agentData = new AgentData(instanceRequest.getJobId(), info.getInstanceId(), instanceUrl,
+ capacity, instanceRequest.getRegion(), "unknown");
+ jobManager.registerAgentForJob(agentData);
+ }
+
+ private String buildWsUrl(VMInstanceRequest instanceRequest, VMInformation info) {
+ String host = selectHostForRegion(instanceRequest.getRegion(), info);
+ if (StringUtils.isBlank(host)) {
+ return null;
+ }
+ String wsPath = tankConfig.getAgentConfig().getCommandWsPath();
+ if (!wsPath.startsWith("/")) {
+ wsPath = "/" + wsPath;
+ }
+ return "ws://" + host + ":" + tankConfig.getAgentConfig().getAgentPort() + wsPath;
+ }
+
+ private String buildHttpUrl(VMInstanceRequest instanceRequest, VMInformation info) {
+ String host = selectHostForRegion(instanceRequest.getRegion(), info);
+ if (StringUtils.isBlank(host)) {
+ host = "localhost";
+ }
+ return "http://" + host + ":" + tankConfig.getAgentConfig().getAgentPort();
+ }
+
+ private String selectHostForRegion(VMRegion region, VMInformation info) {
+ if (region == VMRegion.US_EAST || region == VMRegion.US_EAST_2) {
+ return firstNonBlank(info.getPrivateIp(), info.getPrivateDNS(), info.getPublicIp(), info.getPublicDNS());
+ }
+ return firstNonBlank(info.getPublicDNS(), info.getPublicIp(), info.getPrivateIp(), info.getPrivateDNS());
+ }
+
+ private String firstNonBlank(String... values) {
+ for (String value : values) {
+ if (StringUtils.isNotBlank(value)) {
+ return value;
+ }
+ }
+ return null;
+ }
+
/**
* @param req
* @param info
diff --git a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java
index ad226044e..198140a96 100644
--- a/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java
+++ b/tank_vmManager/src/main/java/com/intuit/tank/vmManager/environment/amazon/AmazonInstance.java
@@ -184,6 +184,12 @@ public List create(VMRequest request) {
instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_URL, config.getControllerBase());
instanceRequest.addUserData(TankConstants.KEY_AGENT_TOKEN, config.getAgentConfig().getAgentToken());
instanceRequest.addUserData(TankConstants.KEY_NUM_USERS_PER_AGENT, Integer.toString(instanceRequest.getNumUsersPerAgent()));
+ instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_INITIATED_WS_ENABLED,
+ Boolean.toString(config.getAgentConfig().isControllerInitiatedWsEnabled()));
+ instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_INITIATED_WS_DISABLE_AGENT_HTTP,
+ Boolean.toString(config.getAgentConfig().isControllerInitiatedWsDisableAgentHttp()));
+ instanceRequest.addUserData(TankConstants.KEY_CONTROLLER_INITIATED_WS_SCRIPT_PATH,
+ config.getAgentConfig().getControllerInitiatedWsScriptPath());
if (instanceRequest.isUseEips()) {
instanceRequest.addUserData(TankConstants.KEY_USING_BIND_EIP, Boolean.TRUE.toString());
diff --git a/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java
new file mode 100644
index 000000000..35fb1604f
--- /dev/null
+++ b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java
@@ -0,0 +1,20 @@
+package com.intuit.tank.perfManager.workLoads;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+
+public class ControllerInitiatedAgentWsClientTest {
+
+ @Test
+ public void testHasSessionFalseByDefault() {
+ ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient();
+ assertFalse(client.hasSession("i-missing"));
+ }
+
+ @Test
+ public void testSendCommandWithoutSessionReturnsFalse() {
+ ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient();
+ assertFalse(client.sendCommand("i-missing", "job-1", "start", 1000L));
+ }
+}
diff --git a/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java
index 6dd4db42c..3e80430bd 100644
--- a/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java
+++ b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/JobManagerTest.java
@@ -14,19 +14,32 @@
*/
import java.util.Collections;
+import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
+import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.FutureTask;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import jakarta.enterprise.inject.Instance;
import org.junit.jupiter.api.*;
import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.*;
import com.intuit.tank.vm.agent.messages.AgentData;
import com.intuit.tank.vm.agent.messages.AgentTestStartData;
+import com.intuit.tank.vm.agent.messages.AgentWsCommandSender;
import com.intuit.tank.vm.api.enumerated.VMRegion;
import com.intuit.tank.vm.api.enumerated.AgentCommand;
+import com.intuit.tank.vm.api.enumerated.IncrementStrategy;
+import com.intuit.tank.vm.settings.AgentConfig;
+import com.intuit.tank.vm.settings.TankConfig;
+import com.intuit.tank.vm.vmManager.JobRequest;
/**
* The class JobManagerTest contains tests for the class {@link JobManager}.
@@ -101,4 +114,133 @@ public void testGetInstanceUrl() {
assertTrue(result.isEmpty());
}
+
+ /**
+ * When WS sender is not injected (null), sendCommand should use HTTP path only.
+ * Uses empty string instanceId (same pattern as existing tests) to avoid
+ * hitting findAgent which requires tankConfig.
+ */
+ @Test
+ public void testSendCommandWithNoWsSender() {
+ JobManager fixture = new JobManager();
+ // wsCommandSenderInstance is null (not injected) — should fall through to HTTP
+ // Empty string is filtered out by StringUtils.isNotEmpty in getInstanceUrl
+ List instanceIds = Collections.singletonList("");
+ AgentCommand cmd = AgentCommand.stop;
+
+ // Should not throw, should return empty (no resolvable URL)
+ List> result = fixture.sendCommand(instanceIds, cmd);
+ assertEquals(0, result.size());
+ }
+
+ /**
+ * When WS sender is not injected, empty instanceId list should return empty results.
+ */
+ @Test
+ public void testSendCommandEmptyListWithNoWsSender() {
+ JobManager fixture = new JobManager();
+ List> result = fixture.sendCommand(Collections.emptyList(), AgentCommand.start);
+ assertEquals(0, result.size());
+ }
+
+ @Test
+ public void testSendCommandUsesControllerInitiatedWsWhenEnabled() throws Exception {
+ JobManager fixture = new JobManager();
+ String jobId = "job-1";
+ String instanceId = "i-12345";
+ String instanceUrl = "http://127.0.0.1:8090";
+ seedJobInfoCache(fixture, jobId, instanceId, instanceUrl);
+
+ AgentConfig agentConfig = mock(AgentConfig.class);
+ when(agentConfig.isCommandWsEnabled()).thenReturn(false);
+ when(agentConfig.isControllerInitiatedWsEnabled()).thenReturn(true);
+ when(agentConfig.isCommandWsHttpFallbackEnabled()).thenReturn(false);
+ when(agentConfig.getCommandWsAckTimeoutMillis()).thenReturn(1000L);
+
+ TankConfig tankConfig = mock(TankConfig.class);
+ when(tankConfig.getAgentConfig()).thenReturn(agentConfig);
+ setField(fixture, "tankConfig", tankConfig);
+
+ ControllerInitiatedAgentWsClient wsClient = mock(ControllerInitiatedAgentWsClient.class);
+ when(wsClient.hasSession(instanceId)).thenReturn(true);
+ when(wsClient.sendCommand(instanceId, jobId, AgentCommand.stop.name(), 1000L)).thenReturn(true);
+
+ @SuppressWarnings("unchecked")
+ Instance wsClientInstance = mock(Instance.class);
+ when(wsClientInstance.isResolvable()).thenReturn(true);
+ when(wsClientInstance.get()).thenReturn(wsClient);
+ setField(fixture, "controllerInitiatedWsClientInstance", wsClientInstance);
+
+ List> result = fixture.sendCommand(Collections.singletonList(instanceId), AgentCommand.stop);
+
+ assertEquals(1, result.size());
+ assertTrue(result.get(0).isDone());
+ assertNull(result.get(0).join());
+ verify(wsClient).sendCommand(eq(instanceId), eq(jobId), eq(AgentCommand.stop.name()), eq(1000L));
+ }
+
+ @Test
+ public void testSendCommandSkipsWhenNoControllerWsSessionAndNoFallback() throws Exception {
+ JobManager fixture = new JobManager();
+ String jobId = "job-2";
+ String instanceId = "i-54321";
+ String instanceUrl = "http://127.0.0.1:8090";
+ seedJobInfoCache(fixture, jobId, instanceId, instanceUrl);
+
+ AgentConfig agentConfig = mock(AgentConfig.class);
+ when(agentConfig.isCommandWsEnabled()).thenReturn(false);
+ when(agentConfig.isControllerInitiatedWsEnabled()).thenReturn(true);
+ when(agentConfig.isCommandWsHttpFallbackEnabled()).thenReturn(false);
+ when(agentConfig.getCommandWsAckTimeoutMillis()).thenReturn(1000L);
+
+ TankConfig tankConfig = mock(TankConfig.class);
+ when(tankConfig.getAgentConfig()).thenReturn(agentConfig);
+ setField(fixture, "tankConfig", tankConfig);
+
+ ControllerInitiatedAgentWsClient wsClient = mock(ControllerInitiatedAgentWsClient.class);
+ when(wsClient.hasSession(instanceId)).thenReturn(false);
+
+ @SuppressWarnings("unchecked")
+ Instance wsClientInstance = mock(Instance.class);
+ when(wsClientInstance.isResolvable()).thenReturn(true);
+ when(wsClientInstance.get()).thenReturn(wsClient);
+ setField(fixture, "controllerInitiatedWsClientInstance", wsClientInstance);
+
+ List> result = fixture.sendCommand(Collections.singletonList(instanceId), AgentCommand.stop);
+
+ assertEquals(1, result.size());
+ assertTrue(result.get(0).isDone());
+ assertNull(result.get(0).join());
+ verify(wsClient, never()).sendCommand(anyString(), anyString(), anyString(), anyLong());
+ }
+
+ private void seedJobInfoCache(JobManager fixture, String jobId, String instanceId, String instanceUrl) throws Exception {
+ JobRequest jobRequest = mock(JobRequest.class);
+ when(jobRequest.getIncrementStrategy()).thenReturn(IncrementStrategy.increasing);
+ when(jobRequest.getRegions()).thenReturn(Collections.emptySet());
+ when(jobRequest.getScriptsXmlUrl()).thenReturn("script.xml");
+ when(jobRequest.getId()).thenReturn(jobId);
+ when(jobRequest.getNumUsersPerAgent()).thenReturn(1);
+
+ Class> jobInfoClass = Class.forName("com.intuit.tank.perfManager.workLoads.JobManager$JobInfo");
+ Constructor> ctor = jobInfoClass.getDeclaredConstructor(JobRequest.class);
+ ctor.setAccessible(true);
+ Object jobInfo = ctor.newInstance(jobRequest);
+
+ Field agentDataField = jobInfoClass.getDeclaredField("agentData");
+ agentDataField.setAccessible(true);
+ @SuppressWarnings("unchecked")
+ Set agentData = (Set) agentDataField.get(jobInfo);
+ agentData.add(new AgentData(jobId, instanceId, instanceUrl, 1, VMRegion.US_EAST_2, "zone-a"));
+
+ Map cache = new HashMap<>();
+ cache.put(jobId, jobInfo);
+ setField(fixture, "jobInfoMapLocalCache", cache);
+ }
+
+ private void setField(JobManager fixture, String fieldName, Object value) throws Exception {
+ Field field = JobManager.class.getDeclaredField(fieldName);
+ field.setAccessible(true);
+ field.set(fixture, value);
+ }
}
\ No newline at end of file
diff --git a/tank_vmManager/src/test/java/com/intuit/tank/vmManager/environment/JobRequestControllerInitiatedWsTest.java b/tank_vmManager/src/test/java/com/intuit/tank/vmManager/environment/JobRequestControllerInitiatedWsTest.java
new file mode 100644
index 000000000..5dcf33bcc
--- /dev/null
+++ b/tank_vmManager/src/test/java/com/intuit/tank/vmManager/environment/JobRequestControllerInitiatedWsTest.java
@@ -0,0 +1,54 @@
+package com.intuit.tank.vmManager.environment;
+
+import com.intuit.tank.vm.api.enumerated.IncrementStrategy;
+import com.intuit.tank.vm.api.enumerated.VMRegion;
+import com.intuit.tank.vm.vmManager.VMInformation;
+import com.intuit.tank.vm.vmManager.VMJobRequest;
+import com.intuit.tank.vmManager.VMTrackerImpl;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.Method;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class JobRequestControllerInitiatedWsTest {
+
+ @Test
+ public void testSelectHostForEastPrefersPrivateAddress() throws Exception {
+ JobRequest jobRequest = createFixture();
+ VMInformation vmInfo = buildVmInfo();
+
+ Method method = JobRequest.class.getDeclaredMethod("selectHostForRegion", VMRegion.class, VMInformation.class);
+ method.setAccessible(true);
+ String host = (String) method.invoke(jobRequest, VMRegion.US_EAST_2, vmInfo);
+
+ assertEquals("10.0.0.5", host);
+ }
+
+ @Test
+ public void testSelectHostForWestPrefersPublicAddress() throws Exception {
+ JobRequest jobRequest = createFixture();
+ VMInformation vmInfo = buildVmInfo();
+
+ Method method = JobRequest.class.getDeclaredMethod("selectHostForRegion", VMRegion.class, VMInformation.class);
+ method.setAccessible(true);
+ String host = (String) method.invoke(jobRequest, VMRegion.US_WEST_2, vmInfo);
+
+ assertEquals("ec2-public.example.com", host);
+ }
+
+ private JobRequest createFixture() {
+ VMJobRequest vmJobRequest = new VMJobRequest("job-1", "none", "STANDARD", 1,
+ VMRegion.US_EAST_2, IncrementStrategy.increasing, "END_OF_SCRIPT", "m.xlarge", 4000);
+ return new JobRequest(vmJobRequest, new VMTrackerImpl());
+ }
+
+ private VMInformation buildVmInfo() {
+ VMInformation info = new VMInformation();
+ info.setPrivateIp("10.0.0.5");
+ info.setPrivateDNS("ip-10-0-0-5.ec2.internal");
+ info.setPublicIp("54.12.34.56");
+ info.setPublicDNS("ec2-public.example.com");
+ return info;
+ }
+}