Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 68 additions & 26 deletions src/pyinfra/connectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,26 +232,43 @@ def write_stdin(stdin, buffer):
buffer.close()


ASKPASS_PATH_KEYS = ("sudo_askpass_path", "su_askpass_path")


def _iter_askpass_cache_keys(host) -> list[str]:
# Cache keys are either the bare base ("sudo_askpass_path") or the base
# joined with the resolved temp_dir ("sudo_askpass_path__/tmp"). Match
# both so cleanup covers every askpass file ever generated for the host.
return [
cache_key
for cache_key in list(host.connector_data.keys())
for base in ASKPASS_PATH_KEYS
if cache_key == base or cache_key.startswith(base + "__")
]


def remove_any_sudo_askpass_file(host) -> None:
# Best-effort cleanup: this is called from host.disconnect(), and the
# connection may already be broken (e.g. after `server.reboot`). Swallow
# any errors from the remote ``rm`` and still clear the local state so a
# reconnect will regenerate a fresh askpass file.
sudo_askpass_path = host.connector_data.get("sudo_askpass_path")
if sudo_askpass_path:
for cache_key in _iter_askpass_cache_keys(host):
path = host.connector_data.get(cache_key)
if not path:
continue
try:
host.run_shell_command(StringCommand("rm", "-f", QuoteString(sudo_askpass_path)))
host.run_shell_command(StringCommand("rm", "-f", QuoteString(path)))
except Exception as e:
logger.debug("Could not remove sudo askpass file %s: %s", sudo_askpass_path, e)
host.connector_data["sudo_askpass_path"] = None
logger.debug("Could not remove askpass file %s: %s", path, e)
host.connector_data[cache_key] = None

su_askpass_path = host.connector_data.get("su_askpass_path")
if su_askpass_path:
try:
host.run_shell_command(StringCommand("rm", "-f", QuoteString(su_askpass_path)))
except Exception as e:
logger.debug("Could not remove su askpass file %s: %s", su_askpass_path, e)
host.connector_data["su_askpass_path"] = None

def clear_askpass_cache(host) -> None:
# Drop every cached askpass path without touching the remote, used after
# ``server.reboot`` where the previous connection (and therefore any
# askpass scripts under its temp dir) is gone.
for cache_key in _iter_askpass_cache_keys(host):
host.connector_data[cache_key] = None


@memoize
Expand Down Expand Up @@ -280,18 +297,35 @@ def extract_control_arguments(arguments: "ConnectorArguments") -> "ConnectorArgu
return control_arguments


def _ensure_sudo_askpass_set_for_host(host: "Host"):
return _ensure_askpass_set_for_host(host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR)
def _ensure_sudo_askpass_set_for_host(host: "Host", temp_dir: Optional[str] = None) -> str:
return _ensure_askpass_set_for_host(
host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR, temp_dir=temp_dir
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add the path to the key here or we'll incorrectly cache the wrong askpass path if the value is changed (very unlikely, but worth guarding against).

Suggested change
host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR, temp_dir=temp_dir
host, f"sudo_askpass_path:{temp_dir}", SUDO_ASKPASS_ENV_VAR, temp_dir=temp_dir

)


def _ensure_su_askpass_set_for_host(host: "Host", temp_dir: Optional[str] = None) -> str:
return _ensure_askpass_set_for_host(
host, "su_askpass_path", SU_ASKPASS_ENV_VAR, temp_dir=temp_dir
)


def _ensure_su_askpass_set_for_host(host: "Host"):
return _ensure_askpass_set_for_host(host, "su_askpass_path", SU_ASKPASS_ENV_VAR)
def _ensure_askpass_set_for_host(
host: "Host", key: str, env_var: str, temp_dir: Optional[str] = None
) -> str:
# Operation-level _temp_dir (if any) overrides the host-level/global
# temp directory resolution so `server.shell(..., _temp_dir=X)` places
# the askpass script under X rather than /tmp. Encoding the resolved
# temp_dir in the cache key gives every (host, temp_dir) pair its own
# entry, so switching dirs across calls just misses the cache instead
# of needing an explicit invalidation step.
effective_temp_dir = temp_dir or host.get_temp_dir_config()
cache_key = "{0}__{1}".format(key, effective_temp_dir)

cached = host.connector_data.get(cache_key)
if cached:
return cached

def _ensure_askpass_set_for_host(host: "Host", key: str, env_var: str):
if host.connector_data.get(key):
return
ok, output = host.run_shell_command(ASKPASS_COMMAND.format(host.get_temp_dir_config(), env_var))
ok, output = host.run_shell_command(ASKPASS_COMMAND.format(effective_temp_dir, env_var))

if not ok:
raise PyinfraError("Failed to create sudo_askpass command: {0}".format(output.output))
Expand All @@ -303,7 +337,9 @@ def _ensure_askpass_set_for_host(host: "Host", key: str, env_var: str):
)
)

host.connector_data[key] = output.stdout_lines[0]
path = output.stdout_lines[0]
host.connector_data[cache_key] = path
return path


def make_unix_command_for_host(
Expand All @@ -312,6 +348,11 @@ def make_unix_command_for_host(
command: StringCommand,
**command_arguments,
) -> StringCommand:
# Operation-level temp directory override, if any. Passed through to the
# askpass helpers so the generated SUDO_ASKPASS / SU_ASKPASS script lands
# under the same directory the operation asked for.
op_temp_dir = command_arguments.get("_temp_dir")

# Handle sudo password
if command_arguments.get("_sudo"):
# If the sudo password is not set in the direct arguments,
Expand All @@ -320,15 +361,16 @@ def make_unix_command_for_host(
command_arguments["_sudo_password"] = host.connector_data.get("prompted_sudo_password")

if command_arguments.get("_sudo_password"):
# Ensure the askpass path is correctly set and passed through
_ensure_sudo_askpass_set_for_host(host)
command_arguments["_sudo_askpass_path"] = host.connector_data["sudo_askpass_path"]
command_arguments["_sudo_askpass_path"] = _ensure_sudo_askpass_set_for_host(
host, temp_dir=op_temp_dir
)

# Handle su password
if command_arguments.get("_su_user"):
if command_arguments.get("_su_password"):
_ensure_su_askpass_set_for_host(host)
command_arguments["_su_askpass_path"] = host.connector_data["su_askpass_path"]
command_arguments["_su_askpass_path"] = _ensure_su_askpass_set_for_host(
host, temp_dir=op_temp_dir
)

return make_unix_command(command, **command_arguments)

Expand Down
9 changes: 4 additions & 5 deletions src/pyinfra/operations/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pyinfra import host, logger, state
from pyinfra.api import FunctionCommand, OperationError, QuoteString, StringCommand, operation
from pyinfra.api.util import try_int
from pyinfra.connectors.util import remove_any_sudo_askpass_file
from pyinfra.connectors.util import clear_askpass_cache, remove_any_sudo_askpass_file
from pyinfra.facts.files import Directory, FindInFile, Link
from pyinfra.facts.server import (
AuthorizedKeys,
Expand Down Expand Up @@ -93,11 +93,10 @@ def wait_and_reconnect(state, host): # pragma: no cover
max_retries = round(reboot_timeout / interval)

# The remote askpass files (if any) live on a host that has just
# rebooted the SSH session is dead and there is nothing to clean up.
# rebooted, the SSH session is dead and there is nothing to clean up.
# Clear the stored paths before disconnecting so the disconnect path
# does not attempt an ``rm -f`` over the broken connection.
host.connector_data["sudo_askpass_path"] = None
host.connector_data["su_askpass_path"] = None
clear_askpass_cache(host)

host.disconnect() # make sure we are properly disconnected
retries = 0
Expand All @@ -119,7 +118,7 @@ def wait_and_reconnect(state, host): # pragma: no cover

# On certain systems sudo files are lost on reboot
def clean_sudo_info(state, host):
host.connector_data["sudo_askpass_path"] = None
clear_askpass_cache(host)

yield FunctionCommand(clean_sudo_info, (), {})

Expand Down
2 changes: 1 addition & 1 deletion tests/test_connectors/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def test_run_shell_command_retry_for_sudo_password(
state = State(inventory, Config())
host = inventory.get_host("somehost")
host.connect(state)
host.connector_data["sudo_askpass_path"] = "/tmp/pyinfra-sudo-askpass-XXXXXXXXXXXX"
host.connector_data["sudo_askpass_path__/tmp"] = "/tmp/pyinfra-sudo-askpass-XXXXXXXXXXXX"

command = "echo hi"
return_values = [1, 0] # return 0 on the second call
Expand Down
Loading
Loading