Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions src/pyinfra/connectors/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,18 +280,37 @@ def extract_control_arguments(arguments: "ConnectorArguments") -> "ConnectorArgu
return control_arguments


def _ensure_sudo_askpass_set_for_host(host: "Host"):
return _ensure_askpass_set_for_host(host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR)
def _ensure_sudo_askpass_set_for_host(host: "Host", temp_dir: Optional[str] = None):
return _ensure_askpass_set_for_host(
host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR, temp_dir=temp_dir
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add the path to the key here or we'll incorrectly cache the wrong askpass path if the value is changed (very unlikely, but worth guarding against).

Suggested change
host, "sudo_askpass_path", SUDO_ASKPASS_ENV_VAR, temp_dir=temp_dir
host, f"sudo_askpass_path:{temp_dir}", SUDO_ASKPASS_ENV_VAR, temp_dir=temp_dir

)


def _ensure_su_askpass_set_for_host(host: "Host"):
return _ensure_askpass_set_for_host(host, "su_askpass_path", SU_ASKPASS_ENV_VAR)
def _ensure_su_askpass_set_for_host(host: "Host", temp_dir: Optional[str] = None):
return _ensure_askpass_set_for_host(
host, "su_askpass_path", SU_ASKPASS_ENV_VAR, temp_dir=temp_dir
)


def _ensure_askpass_set_for_host(host: "Host", key: str, env_var: str):
if host.connector_data.get(key):
def _ensure_askpass_set_for_host(
host: "Host", key: str, env_var: str, temp_dir: Optional[str] = None
):
# Operation-level _temp_dir (if any) overrides the host-level/global
# temp directory resolution so `server.shell(..., _temp_dir=X)` places
# the askpass script under X rather than /tmp.
effective_temp_dir = temp_dir or host.get_temp_dir_config()

# Invalidate the cache if the resolved temp_dir changed since the path
# was created, otherwise we'd hand out a stale path under the wrong dir.
# If the tracker is missing (older code path or external population),
# trust the existing path to preserve backward compatibility.
temp_dir_cache_key = "{0}_temp_dir".format(key)
existing_path = host.connector_data.get(key)
existing_temp_dir = host.connector_data.get(temp_dir_cache_key)
if existing_path and (existing_temp_dir is None or existing_temp_dir == effective_temp_dir):
return
ok, output = host.run_shell_command(ASKPASS_COMMAND.format(host.get_temp_dir_config(), env_var))

ok, output = host.run_shell_command(ASKPASS_COMMAND.format(effective_temp_dir, env_var))

if not ok:
raise PyinfraError("Failed to create sudo_askpass command: {0}".format(output.output))
Expand All @@ -304,6 +323,7 @@ def _ensure_askpass_set_for_host(host: "Host", key: str, env_var: str):
)

host.connector_data[key] = output.stdout_lines[0]
host.connector_data[temp_dir_cache_key] = effective_temp_dir


def make_unix_command_for_host(
Expand All @@ -312,6 +332,11 @@ def make_unix_command_for_host(
command: StringCommand,
**command_arguments,
) -> StringCommand:
# Operation-level temp directory override, if any. Passed through to the
# askpass helpers so the generated SUDO_ASKPASS / SU_ASKPASS script lands
# under the same directory the operation asked for.
op_temp_dir = command_arguments.get("_temp_dir")

# Handle sudo password
if command_arguments.get("_sudo"):
# If the sudo password is not set in the direct arguments,
Expand All @@ -321,13 +346,13 @@ def make_unix_command_for_host(

if command_arguments.get("_sudo_password"):
# Ensure the askpass path is correctly set and passed through
_ensure_sudo_askpass_set_for_host(host)
_ensure_sudo_askpass_set_for_host(host, temp_dir=op_temp_dir)
command_arguments["_sudo_askpass_path"] = host.connector_data["sudo_askpass_path"]

# Handle su password
if command_arguments.get("_su_user"):
if command_arguments.get("_su_password"):
_ensure_su_askpass_set_for_host(host)
_ensure_su_askpass_set_for_host(host, temp_dir=op_temp_dir)
command_arguments["_su_askpass_path"] = host.connector_data["su_askpass_path"]

return make_unix_command(command, **command_arguments)
Expand Down
100 changes: 99 additions & 1 deletion tests/test_connectors/test_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# encoding: utf-8

from unittest import TestCase
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

from pyinfra.api import Config, State
from pyinfra.connectors.util import (
CommandOutput,
OutputLine,
_ensure_askpass_set_for_host,
make_unix_command,
make_unix_command_for_host,
remove_any_sudo_askpass_file,
Expand Down Expand Up @@ -280,3 +283,98 @@ def test_noop_when_no_state(self):
remove_any_sudo_askpass_file(host)

host.run_shell_command.assert_not_called()


class TestEnsureAskpassTempDir(TestCase):
"""
The askpass helper must honour the resolved temp directory (issue #1623):
operation-level ``_temp_dir`` > ``config.TEMP_DIR`` > ``config.DEFAULT_TEMP_DIR``.
Per-host defaults for ``_temp_dir`` come through the standard global-argument
cascade (``host.data._temp_dir``), not via a separate code path here.
"""

_counter = 0

@classmethod
def _next_host(cls):
cls._counter += 1
return "askpass-temp-dir-test-host-{0}".format(cls._counter)

def _make_host(self, config=None):
name = self._next_host()
state = State(make_inventory(hosts=(name,)), config or Config())
host = state.inventory.get_host(name)
host.init(state)
return host

def _captured_script_run(self, host, temp_dir=None, stdout="/some/askpass/path"):
"""
Call ``_ensure_askpass_set_for_host`` with ``host.run_shell_command``
patched so we can assert on the remote script text (whose first
argument to the mkstemp template is the temp directory).
"""
captured = {}

def fake_run(command, *args, **kwargs):
captured["command"] = command
return (True, CommandOutput([OutputLine("stdout", stdout)]))

host.run_shell_command = fake_run # type: ignore[method-assign]
_ensure_askpass_set_for_host(
host,
key="sudo_askpass_path",
env_var="PYINFRA_SUDO_PASSWORD",
temp_dir=temp_dir,
)
return captured["command"]

def test_default_temp_dir(self):
host = self._make_host()
script = self._captured_script_run(host)
assert "${TMPDIR:=/tmp}" in script

def test_config_temp_dir(self):
host = self._make_host(Config(TEMP_DIR="/var/tmp"))
script = self._captured_script_run(host)
assert "${TMPDIR:=/var/tmp}" in script

def test_op_temp_dir_wins_over_config(self):
host = self._make_host(Config(TEMP_DIR="/var/tmp"))
script = self._captured_script_run(host, temp_dir="/dev/shm/pyinfra")
assert "${TMPDIR:=/dev/shm/pyinfra}" in script

def test_cache_invalidates_when_temp_dir_changes(self):
host = self._make_host()
first = self._captured_script_run(host, temp_dir="/a", stdout="/a/askpass")
assert "${TMPDIR:=/a}" in first
second = self._captured_script_run(host, temp_dir="/b", stdout="/b/askpass")
assert "${TMPDIR:=/b}" in second
assert host.connector_data["sudo_askpass_path"] == "/b/askpass"

def test_make_unix_command_for_host_threads_temp_dir(self):
host = self._make_host()
host.connector_data["prompted_sudo_password"] = "supersecret"

captured = {}

def fake_run(command, *args, **kwargs):
captured["command"] = command
return (
True,
CommandOutput([OutputLine("stdout", "/op/tmp/pyinfra-askpass-XYZ")]),
)

host.run_shell_command = fake_run # type: ignore[method-assign]

with patch("pyinfra.connectors.util.make_unix_command") as fake_make:
fake_make.return_value = "mocked"
make_unix_command_for_host(
host.state,
host,
"uptime",
_sudo=True,
_sudo_password="supersecret",
_temp_dir="/op/tmp",
)

assert "${TMPDIR:=/op/tmp}" in captured["command"]
Loading