171 lines
4.7 KiB
Python
171 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import shlex
|
|
import subprocess # noqa: S404
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Mapping
|
|
from collections.abc import Sequence
|
|
from pathlib import Path
|
|
|
|
logger = logging.getLogger("tussilago.control_plane.host_commands")

# Default allowlist of parent-process environment variable names that
# ``build_host_command_env`` copies into a child's environment; every other
# variable from ``os.environ`` is left out unless explicitly overridden.
DEFAULT_INHERITED_ENV_KEYS: frozenset[str] = frozenset(
    (
        # User identity and home directory.
        "HOME",
        "LOGNAME",
        "USER",
        # Locale selection.
        "LANG",
        "LC_ALL",
        "LC_CTYPE",
        # Executable lookup and Python tooling.
        "PATH",
        "VIRTUAL_ENV",
        "UV_CACHE_DIR",
        # TLS certificate locations.
        "SSL_CERT_DIR",
        "SSL_CERT_FILE",
        # Scratch, cache, and runtime directories.
        "TMPDIR",
        "XDG_CACHE_HOME",
        "XDG_RUNTIME_DIR",
    ),
)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class HostCommandResult:
|
|
"""Capture output from a completed host-side command."""
|
|
|
|
args: tuple[str, ...]
|
|
returncode: int
|
|
stdout: str
|
|
stderr: str
|
|
|
|
|
|
class HostCommandError(RuntimeError):
    """Raised when a host-side command fails or times out.

    Attributes:
        command_args: The command arguments, stored as a tuple.
        returncode: The child's exit status, or ``None`` when no status is
            available (``run_host_command`` passes ``None`` on timeout).
        stdout: Standard output captured from the child.
        stderr: Standard error captured from the child.
    """

    def __init__(
        self,
        message: str,
        *,
        args: Sequence[str],
        returncode: int | None,
        stdout: str,
        stderr: str,
    ) -> None:
        """Store captured command context for later error reporting."""
        super().__init__(message)
        # Stored as ``command_args`` rather than ``args`` because
        # ``BaseException.args`` is reserved and already holds ``(message,)``.
        self.command_args = tuple(args)
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr

    def __reduce__(self) -> tuple[object, tuple[()], dict[str, object]]:
        """Make the exception picklable and copyable.

        The inherited ``BaseException.__reduce__`` re-creates the instance as
        ``cls(*self.args)``, i.e. with only the message. That call fails here
        because ``args``, ``returncode``, ``stdout`` and ``stderr`` are
        required keyword-only parameters, so ``pickle.loads`` / ``copy.copy``
        (and therefore propagation out of process pools) would raise
        ``TypeError``. Rebuild through a partial that supplies them instead.
        """
        import functools  # local import keeps the module's import block unchanged

        message = self.args[0] if self.args else ""
        rebuild = functools.partial(
            type(self),
            message,
            args=self.command_args,
            returncode=self.returncode,
            stdout=self.stdout,
            stderr=self.stderr,
        )
        # Passing __dict__ as state mirrors BaseException's default reduction,
        # so any extra attributes set on the instance survive the round trip.
        return rebuild, (), dict(self.__dict__)
|
|
|
|
|
|
def build_host_command_env(
    *,
    env_overrides: Mapping[str, str] | None = None,
    allowed_env_keys: frozenset[str] | None = None,
    inherited_env_keys: frozenset[str] = DEFAULT_INHERITED_ENV_KEYS,
) -> dict[str, str]:
    """Assemble a minimal environment mapping for a host-side child process.

    Variables from ``os.environ`` are copied only when their name appears in
    ``inherited_env_keys``. Optional ``env_overrides`` are then layered on
    top, but only if every override key is present in ``allowed_env_keys``.

    Returns:
        A new dictionary containing the inherited variables plus any
        accepted overrides.

    Raises:
        ValueError: If ``env_overrides`` is given without ``allowed_env_keys``,
            or if it contains a key outside that allowlist.
    """
    child_env: dict[str, str] = {}
    for name, value in os.environ.items():
        if name in inherited_env_keys:
            child_env[name] = value

    if env_overrides is not None:
        if allowed_env_keys is None:
            msg = "allowed_env_keys is required when env_overrides are provided"
            raise ValueError(msg)

        rejected = sorted(name for name in env_overrides if name not in allowed_env_keys)
        if rejected:
            msg = f"env_overrides contains disallowed keys: {', '.join(rejected)}"
            raise ValueError(msg)

        child_env.update(env_overrides)

    return child_env
|
|
|
|
|
|
def _captured_text(stream: str | bytes | None) -> str:
    """Normalize output captured from a child process to ``str``.

    ``subprocess.TimeoutExpired`` exposes captured output as ``bytes`` even
    when ``text=True`` was requested, and may leave it as ``None`` when
    nothing was observed. Bytes are decoded as UTF-8 with replacement so that
    building an error report can never itself raise.
    """
    if stream is None:
        return ""
    if isinstance(stream, bytes):
        return stream.decode("utf-8", errors="replace")
    return stream


def run_host_command(
    *,
    command: Sequence[str],
    cwd: Path | None = None,
    env_overrides: Mapping[str, str] | None = None,
    allowed_env_keys: frozenset[str] | None = None,
    timeout_seconds: float = 60.0,
) -> HostCommandResult:
    """Run a host-side command with explicit environment and timeout controls.

    The command is executed without a shell, with captured text output, in a
    sanitized environment from ``build_host_command_env``.

    Returns:
        A result object containing the command, return code, and captured output.

    Raises:
        ValueError: If the command is empty, contains an empty argument, or
            env overrides are not allowlisted.
        HostCommandError: If the command exits non-zero or times out. On
            timeout, ``returncode`` is ``None``.

    Note:
        ``OSError`` raised by ``subprocess.run`` itself (for example
        ``FileNotFoundError`` for a missing executable) is not caught here
        and propagates unchanged.
    """
    normalized_command = tuple(command)
    if not normalized_command:
        msg = "command must not be empty"
        raise ValueError(msg)

    if any(not argument for argument in normalized_command):
        msg = "command arguments must be non-empty strings"
        raise ValueError(msg)

    resolved_env = build_host_command_env(
        env_overrides=env_overrides,
        allowed_env_keys=allowed_env_keys,
    )

    # Log only the executable and argument count, not the full argv.
    logger.debug(
        "Running host command executable=%s argc=%s (cwd=%s)",
        shlex.quote(normalized_command[0]),
        len(normalized_command),
        cwd,
    )

    try:
        completed = subprocess.run(  # noqa: S603
            normalized_command,
            check=True,
            capture_output=True,
            text=True,
            cwd=cwd,
            env=resolved_env,
            timeout=timeout_seconds,
        )
    except subprocess.CalledProcessError as error:
        msg = "Host command failed."
        raise HostCommandError(
            msg,
            args=tuple(str(argument) for argument in error.cmd),
            returncode=error.returncode,
            stdout=_captured_text(error.stdout),
            stderr=_captured_text(error.stderr),
        ) from error
    except subprocess.TimeoutExpired as error:
        msg = "Host command timed out."
        raise HostCommandError(
            msg,
            args=normalized_command,
            returncode=None,
            # Previously ``str(error.stdout) or ""``: that turned ``None`` into
            # the literal "None" and bytes into "b'...'".
            stdout=_captured_text(error.stdout),
            stderr=_captured_text(error.stderr),
        ) from error

    return HostCommandResult(
        args=normalized_command,
        returncode=completed.returncode,
        stdout=completed.stdout,
        stderr=completed.stderr,
    )
|