from __future__ import annotations import logging import os import shlex import subprocess # noqa: S404 from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: from collections.abc import Mapping from collections.abc import Sequence from pathlib import Path logger = logging.getLogger("tussilago.control_plane.host_commands") DEFAULT_INHERITED_ENV_KEYS: frozenset[str] = frozenset( { "HOME", "LANG", "LC_ALL", "LC_CTYPE", "LOGNAME", "PATH", "SSL_CERT_DIR", "SSL_CERT_FILE", "TMPDIR", "USER", "UV_CACHE_DIR", "VIRTUAL_ENV", "XDG_CACHE_HOME", "XDG_RUNTIME_DIR", }, ) @dataclass(frozen=True, slots=True) class HostCommandResult: """Capture output from a completed host-side command.""" args: tuple[str, ...] returncode: int stdout: str stderr: str class HostCommandError(RuntimeError): """Raised when a host-side command fails or times out.""" def __init__( self, message: str, *, args: Sequence[str], returncode: int | None, stdout: str, stderr: str, ) -> None: """Store captured command context for later error reporting.""" super().__init__(message) self.command_args = tuple(args) self.returncode = returncode self.stdout = stdout self.stderr = stderr def build_host_command_env( *, env_overrides: Mapping[str, str] | None = None, allowed_env_keys: frozenset[str] | None = None, inherited_env_keys: frozenset[str] = DEFAULT_INHERITED_ENV_KEYS, ) -> dict[str, str]: """Build a sanitized environment for host-side child processes. Returns: A filtered environment dictionary suitable for subprocess execution. Raises: ValueError: If env overrides are provided without an allowlist. """ resolved_env = {key: value for key, value in os.environ.items() if key in inherited_env_keys} if env_overrides is None: return resolved_env if allowed_env_keys is None: msg = "allowed_env_keys is required when env_overrides are provided" raise ValueError(msg) disallowed_keys = sorted(set(env_overrides).difference(allowed_env_keys)) if disallowed_keys: msg = f"env_overrides contains disallowed keys: {', '.join(disallowed_keys)}" raise ValueError(msg) resolved_env.update(env_overrides) return resolved_env def run_host_command( *, command: Sequence[str], cwd: Path | None = None, env_overrides: Mapping[str, str] | None = None, allowed_env_keys: frozenset[str] | None = None, timeout_seconds: float = 60.0, ) -> HostCommandResult: """Run a host-side command with explicit environment and timeout controls. Returns: A result object containing the command, return code, and captured output. Raises: ValueError: If the command is empty or env overrides are not allowlisted. HostCommandError: If the command fails or times out. """ normalized_command = tuple(command) if not normalized_command: msg = "command must not be empty" raise ValueError(msg) if any(not argument for argument in normalized_command): msg = "command arguments must be non-empty strings" raise ValueError(msg) resolved_env = build_host_command_env( env_overrides=env_overrides, allowed_env_keys=allowed_env_keys, ) logger.debug( "Running host command executable=%s argc=%s (cwd=%s)", shlex.quote(normalized_command[0]), len(normalized_command), cwd, ) try: completed = subprocess.run( # noqa: S603 normalized_command, check=True, capture_output=True, text=True, cwd=cwd, env=resolved_env, timeout=timeout_seconds, ) except subprocess.CalledProcessError as error: msg_0 = "Host command failed." raise HostCommandError( msg_0, args=tuple(str(argument) for argument in error.cmd), returncode=error.returncode, stdout=error.stdout or "", stderr=error.stderr or "", ) from error except subprocess.TimeoutExpired as error: msg_0 = "Host command timed out." raise HostCommandError( msg_0, args=normalized_command, returncode=None, stdout=str(error.stdout) or "", stderr=str(error.stderr) or "", ) from error return HostCommandResult( args=normalized_command, returncode=completed.returncode, stdout=completed.stdout, stderr=completed.stderr, )