Source code for dml_util.runners.remote

"""Remote execution runners.

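This module provides runners for executing tasks in remote environments.
Currently it defines ``SshRunner``, which uploads a small bash wrapper
script to a remote host over SSH and runs a sub-adapter through it, with
the runner's configuration exported as environment variables.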
"""

import logging
import shlex
import subprocess

from dml_util.runners.base import RunnerBase

logger = logging.getLogger(__name__)

# Bash wrapper script that SshRunner uploads to the remote host.
#
# - The shebang must remain the very first line: SshRunner.proc_script splits
#   off the first line and inserts the ``export`` lines directly after it.
# - SshRunner.funkify replaces the text "REPLACE THIS LINE" with a comment
#   header followed by ``. <env_file>`` lines when env files are given.
# - The script takes exactly two arguments (adapter, uri), with the payload
#   arriving on stdin, and ``exec``s them as a command.
SCRIPT_TPL = "\n".join(
    [
        "#!/usr/bin/env bash",
        "set -euo pipefail",
        "",
        "# REPLACE THIS LINE",
        "",
        "# require exactly 2 args",
        'if [ "$#" -ne 2 ]; then',
        '  echo "Usage: echo data | $0 adapter uri" >&2',
        "  exit 1",
        "fi",
        'cmd=( "$@" )',
        'exec "${cmd[@]}"',
    ]
)


class SshRunner(RunnerBase):
    """Runs a command over SSH.

    The runner uploads a bash wrapper script (built from ``SCRIPT_TPL``) to
    a temporary file on the remote host and exports this runner's
    configuration as environment variables at the top of that script. It
    then invokes the script with the sub-adapter and sub-URI, piping the
    sub-adapter kwargs to it on stdin.
    """

    @classmethod
    def funkify(cls, host, sub, flags=None, env_files=None):
        """Build the kwargs dict describing an SSH execution.

        Parameters
        ----------
        host : str
            SSH destination, passed verbatim to ``ssh`` (after ``flags``).
        sub
            The wrapped sub-specification; stored unchanged under ``"sub"``.
        flags : list of str, optional
            Extra arguments inserted between ``ssh`` and ``host``.
            Defaults to no flags.
        env_files : list of str, optional
            Remote files to source (``. <file>``) before the wrapped command
            runs. Paths are written unquoted, so the remote shell expands
            ``~`` / ``$VARS`` in them (and would also word-split on spaces).

        Returns
        -------
        dict
            Keys ``"sub"``, ``"host"``, ``"flags"`` and ``"script"``.
        """
        script = SCRIPT_TPL
        if env_files is not None:
            # "# REPLACE THIS LINE" becomes "# ENV FILES HERE..." followed by
            # one `. <file>` line per env file.
            source_lines = [f". {env_file}" for env_file in env_files]
            script = script.replace(
                "REPLACE THIS LINE",
                "\n".join(["ENV FILES HERE...", *source_lines]),
            )
        return {"sub": sub, "host": host, "flags": flags or [], "script": script}

    def proc_script(self) -> str:
        """Upload the wrapper script to a fresh remote temp file.

        Each entry of ``self.config.to_envvars()`` is inserted as an
        ``export KEY=<shell-quoted value>`` line directly after the shebang.
        If writing or chmod-ing the file fails, the temp file is removed
        (best effort) before the error is re-raised.

        Returns
        -------
        str
            Unquoted path of the executable script on the remote host, as
            printed by ``mktemp``.

        Raises
        ------
        RuntimeError
            If any of the remote ``mktemp`` / ``cat`` / ``chmod`` calls fail.
        """
        tmpf, _ = self._run_cmd("mktemp", "-t", "dml.XXXXXX.sh")
        # ssh hands its arguments to the remote shell as one command line,
        # so the path must be quoted; ">" below is intentionally left bare.
        remote_tmpf = shlex.quote(tmpf)
        shbang, *lines = self.input.kwargs["script"].split("\n")
        # NOTE(review): shlex.quote requires str values; assumes to_envvars()
        # returns a str -> str mapping — confirm.
        env_lines = [
            f"export {k}={shlex.quote(v)}"
            for k, v in self.config.to_envvars().items()
        ]
        script = "\n".join([shbang, *env_lines, *lines])
        try:
            self._run_cmd("cat", ">", remote_tmpf, input=script)
            self._run_cmd("chmod", "+x", remote_tmpf)
        except BaseException:
            # Don't leave a half-written script behind on the remote host.
            self._remove_remote_file(tmpf)
            raise
        return tmpf

    def _remove_remote_file(self, path):
        """Best-effort ``rm -f`` of *path* on the remote host.

        Failures are logged as warnings instead of raised, so that this can
        be used during error handling without masking the original error.
        """
        try:
            self._run_cmd("rm", "-f", shlex.quote(path))
        except (RuntimeError, OSError) as err:
            logger.warning("failed to remove remote temp file %s: %s", path, err)

    def _run_cmd(self, *user_cmd, **kw):
        """Run ``user_cmd`` on the configured host via ``ssh``.

        The words of ``user_cmd`` are appended after the host. ssh joins
        them with spaces into a single command line that the remote shell
        interprets. Callers must therefore ``shlex.quote`` any argument that
        should not be shell-interpreted. Extra keyword arguments (e.g.
        ``input=``) are forwarded to :func:`subprocess.run`. Because
        ``text=True`` is used, ``input`` must be a ``str``.

        Returns
        -------
        tuple[str, str]
            Whitespace-stripped ``(stdout, stderr)`` of the ssh invocation.

        Raises
        ------
        RuntimeError
            If ``ssh`` exits non-zero. The message includes the command,
            stdout and stderr.
        """
        cmd = ["ssh", *self.input.kwargs["flags"], self.input.kwargs["host"], *user_cmd]
        resp = subprocess.run(cmd, capture_output=True, text=True, check=False, **kw)
        if resp.returncode != 0:
            msg = f"Ssh(code:{resp.returncode}) {user_cmd}\nSTDOUT\n{resp.stdout}\n\nSTDERR\n{resp.stderr}"
            raise RuntimeError(msg)
        stderr = resp.stderr.strip()
        logger.debug("SSH STDERR: %s", stderr)
        return resp.stdout.strip(), stderr

    def run(self):
        """Execute the sub-adapter on the remote host.

        Uploads the wrapper script, runs it as ``<script> <adapter> <uri>``
        with the sub-adapter kwargs on stdin, and then deletes the script.

        Returns
        -------
        tuple[str, str]
            Stripped ``(stdout, stderr)`` of the remote run.

        Raises
        ------
        RuntimeError
            If any remote step fails. When the wrapped command itself fails,
            the temp script is still removed (best effort) before the error
            propagates.
        """
        sub_adapter, sub_uri, sub_kwargs = self.input.get_sub()
        tmpf = self.proc_script()
        remote_tmpf = shlex.quote(tmpf)
        try:
            # Quote every argument: the remote shell would otherwise interpret
            # characters such as `&`, `?`, `;` or spaces in the URI.
            stdout, stderr = self._run_cmd(
                remote_tmpf,
                shlex.quote(sub_adapter),
                shlex.quote(sub_uri),
                input=sub_kwargs,
            )
        except BaseException:
            self._remove_remote_file(tmpf)
            raise
        self._run_cmd("rm", remote_tmpf)
        return stdout, stderr