diff --git a/os/windows.go b/os/windows.go
index e3a62ee9..95f157ec 100644
--- a/os/windows.go
+++ b/os/windows.go
@@ -177,11 +177,53 @@ func (c Windows) CommandExist(h Host, cmd string) bool {
 	return h.Execf("where /q %s", cmd) == nil
 }
 
-// Reboot executes the reboot command
+// Reboot triggers an immediate forced restart by creating a SYSTEM-context
+// scheduled task that runs 'shutdown /r /f /t 5', then immediately triggering
+// and deleting it within the 5-second countdown window.
+//
+// Running via a scheduled task bypasses the filtered Administrator token used
+// by WinRM sessions (e.g. AWS EC2) which lacks SeShutdownPrivilege. Issuing
+// 'shutdown /r' directly in the WinRM session is silently ignored in that
+// context.
+//
+// /sc onstart is used instead of /sc once to avoid schtasks writing a
+// stderr warning about the start time being in the past, which rig treats
+// as an error. The task is deleted immediately after triggering (while the
+// 5-second timer counts down) so it does not re-fire on subsequent startups.
+// If triggering the task fails for any other reason than a dropped
+// connection, the task is removed (best effort) before the error is returned.
 func (c Windows) Reboot(h Host) error {
-	if err := h.Exec("shutdown /r /t 5"); err != nil {
-		return fmt.Errorf("failed to reboot: %w", err)
+	const taskName = "RigReboot"
+	// Create a SYSTEM-context ONSTART task that runs 'shutdown /r /f /t 5'.
+	// The 5-second delay gives us time to delete the task before the OS
+	// actually executes the reboot, preventing it from firing again on the
+	// next startup.
+	create := fmt.Sprintf(`schtasks /create /tn "%s" /tr "shutdown /r /f /t 5" /sc onstart /f /ru SYSTEM`, taskName)
+	if err := h.Exec(create); err != nil {
+		return fmt.Errorf("failed to create reboot task: %w", err)
 	}
+	// Prepared before the task is triggered so that every exit path below
+	// can remove the ONSTART task again.
+	del := fmt.Sprintf(`schtasks /delete /tn "%s" /f`, taskName)
+	run := fmt.Sprintf(`schtasks /run /tn "%s"`, taskName)
+	if err := h.Exec(run); err != nil {
+		// Tolerate connection-level errors; the OS may kill WinRM as it starts
+		// rebooting before the run command returns.
+		errMsg := err.Error()
+		if !strings.Contains(errMsg, "connection") && !strings.Contains(errMsg, "closed") && !strings.Contains(errMsg, "EOF") {
+			// The reboot was not triggered: remove the ONSTART task (best
+			// effort) so it cannot force a restart on the next startup.
+			_ = h.Exec(del)
+			return fmt.Errorf("failed to run reboot task: %w", err)
+		}
+	}
+	// Delete the task immediately while the 5-second shutdown timer is still
+	// counting down. This prevents it from re-firing on subsequent startups.
+	// Best-effort: ignore delete errors — if the task fires before we can
+	// delete it, the caller is expected to delete it after reconnecting.
+	_ = h.Exec(del)
+	// Allow Windows time to complete shutdown before waitForHost begins polling.
+	time.Sleep(15 * time.Second)
 	return nil
 }
 
diff --git a/winrm.go b/winrm.go
index e7eb3928..8ad58f5e 100644
--- a/winrm.go
+++ b/winrm.go
@@ -197,17 +197,7 @@ type Command struct {
 }
 
 // Wait blocks until the command finishes
-func (c *Command) Wait() (err error) { //nolint:nonamedreturns // needed for panic recovery
-	defer func() {
-		if r := recover(); err == nil && r != nil {
-			if strings.Contains(fmt.Sprint(r), "close of closed channel") {
-				log.Debugf("recovered from a panic in Command.Wait: %v", r)
-			} else {
-				panic(r)
-			}
-		}
-	}()
-
+func (c *Command) Wait() error {
 	defer c.sh.Close()
 	defer c.cmd.Close()
 
@@ -215,9 +205,9 @@ func (c *Command) Wait() (err error) { //nolint:nonamedreturns // needed for pan
 	c.cmd.Wait()
 	log.Debugf("command finished")
 	if c.cmd.ExitCode() != 0 {
-		err = fmt.Errorf("%w: exit code %d", ErrCommandFailed, c.cmd.ExitCode())
+		return fmt.Errorf("%w: exit code %d", ErrCommandFailed, c.cmd.ExitCode())
 	}
-	return err
+	return nil
 }
 
 // Close terminates the command
@@ -319,26 +309,19 @@ func (c *WinRM) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop
 		}()
 	}
 
+	var errors []string
+
 	wg.Add(1)
 	go func() {
-		// ignore channel close panics
-		defer func() {
-			if r := recover(); r != nil {
-				log.Debugf("recovered from a panic while reading stderr: %v", r)
-			}
-		}()
 		defer wg.Done()
 		if execOpts.Writer == nil {
 			outputScanner := bufio.NewScanner(command.Stdout)
-
 			for outputScanner.Scan() {
 				execOpts.AddOutput(c.String(), outputScanner.Text()+"\n", "")
 			}
-
 			if err := outputScanner.Err(); err != nil {
 				execOpts.LogErrorf("%s: %s", c, err.Error())
 			}
-			command.Stdout.Close()
 		} else {
 			if _, err := io.Copy(execOpts.Writer, command.Stdout); err != nil {
 				execOpts.LogErrorf("%s: failed to stream stdout: %v", c, err)
@@ -346,19 +329,10 @@ func (c *WinRM) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop
 		}
 	}()
 
-	var errors []string
-
 	wg.Add(1)
 	go func() {
-		// ignore channel close panics
-		defer func() {
-			if r := recover(); r != nil {
-				log.Debugf("recovered from a panic while reading stderr: %v", r)
-			}
-		}()
 		defer wg.Done()
 		outputScanner := bufio.NewScanner(command.Stderr)
-
 		for outputScanner.Scan() {
 			msg := outputScanner.Text()
 			if msg != "" {
@@ -366,11 +340,9 @@ func (c *WinRM) Exec(cmd string, opts ...exec.Option) error { //nolint:cyclop
 				execOpts.LogErrorf("%s: %s", c, msg)
 			}
 		}
-
 		if err := outputScanner.Err(); err != nil {
 			execOpts.LogErrorf("%s: %s", c, err.Error())
 		}
-		command.Stderr.Close()
 	}()
 
 	wg.Wait()