Source code for dml_util.runners.base
"""Runner base class for executing code in different environments.
This module defines the base RunnerBase class for executing code in different
environments. Runners are used to execute tasks in various environments, such as
local, container, or remote environments. They provide a consistent interface
for task execution with state management and logging.
"""
import logging
import os
import re
from dataclasses import dataclass, field
from warnings import warn
from dml_util.core.config import EnvConfig, InputConfig
from dml_util.core.state import LocalState, State
logger = logging.getLogger(__name__)
[docs]
@dataclass
class RunnerBase:
    """Base Runner class for executing code in different environments.

    This class provides a framework for running tasks with state management
    and logging. Subclasses must implement specific methods and adhere to the
    defined interface.

    Every subclass is registered in ``RunnerBase._RUNNERS`` under its
    lower-cased class name with a trailing ``"runner"`` removed (for example
    ``ScriptRunner`` is registered as ``"script"``).

    Notes
    -----
    Subclasses must implement one or more of the following methods:

    - `run`: Executes the primary task logic (e.g., `WrappedRunner`, `SshRunner`).
    - `update`: Updates the state and handles task execution (e.g., `ScriptRunner`).

    The difference is that the inherited `run` method handles all of the
    locking and state management around `update`: it fetches the state,
    calls `update`, persists the returned state and releases (or deletes) the
    state afterwards. A subclass that only overrides `update` therefore
    should not need to call `self.put_state` or `self.state.get` directly.
    A subclass that overrides `run` replaces that handling entirely.

    Examples
    --------
    >>> class MyRunner(RunnerBase):
    ...     def run(self):
    ...         print("Running task:", self.input.cache_key)
    """

    config: EnvConfig
    input: InputConfig
    state: State = field(init=False)

    # Deliberately left un-annotated so that the dataclass machinery treats
    # them as plain class attributes rather than as instance fields.
    # Factory used by ``__post_init__`` to build ``self.state``.
    state_class = LocalState
    # Registry of runner classes, shared by all subclasses.
    _RUNNERS = {}

    def __post_init__(self):
        """Coerce dict arguments into config objects and build the state."""
        if isinstance(self.config, dict):
            self.config = EnvConfig(**self.config)
        if isinstance(self.input, dict):
            self.input = InputConfig(**self.input)
        self.state = self.state_class(self.input.cache_key)

    def __init_subclass__(cls, **kwargs):
        """Register the new subclass in ``RunnerBase._RUNNERS``.

        A ``UserWarning`` is emitted when the key is already taken; the
        newer class then replaces the existing entry.
        """
        super().__init_subclass__(**kwargs)
        key = re.sub(r"runner$", "", cls.__name__.lower())
        if key in RunnerBase._RUNNERS:
            warn(f"Runner {key} already exists, overwriting with {cls.__name__}", UserWarning, stacklevel=2)
        RunnerBase._RUNNERS[key] = cls

    @property
    def clsname(self):
        """Lower-cased name of the concrete runner class."""
        return self.__class__.__name__.lower()

    @property
    def prefix(self):
        """Prefix built from the configured S3 prefix and the runner name."""
        return f"{self.config.s3_prefix}/exec/{self.clsname}"

    def _fmt(self, msg):
        """Log ``msg`` at INFO level and return it tagged with runner and cache key."""
        logger.info(msg)
        return f"{self.clsname} [{self.input.cache_key}] :: {msg}"

    def put_state(self, state):
        """Persist ``state`` through the runner's state backend."""
        self.state.put(state)

    def run(self):
        """Run the task and return the result.

        This method handles acquiring the job lock, updating the state, and
        returning the response and message. The main logic of the task is
        implemented in the `update` method, which must be defined by
        subclasses.

        Returns
        -------
        tuple
            ``(response, message)``. ``response`` is ``None`` when
            ``self.state.get()`` returned ``None`` (reported as a failure to
            acquire the job lock).

        Raises
        ------
        Exception
            Anything raised by `update` or `put_state` is re-raised after
            cleanup. An exception raised by `gc` also propagates, but only
            after the state has been deleted.
        """
        state = self.state.get()
        if state is None:
            return None, self._fmt("Could not acquire job lock")
        delete = False
        try:
            logger.info("getting info from %r", self.state_class.__name__)
            new_state, msg, response = self.update(state)
            if new_state is None:
                # A ``None`` state signals that the job is finished.
                delete = True
            else:
                self.put_state(new_state)
            return response, self._fmt(msg)
        except Exception:
            # A failed update is cleaned up the same way as a finished job.
            delete = True
            raise
        finally:
            if delete:
                try:
                    # Setting DML_NO_GC skips resource cleanup.
                    if not os.getenv("DML_NO_GC"):
                        self.gc(state)
                finally:
                    # Always drop the state, even when ``gc`` raises;
                    # otherwise the job lock would never be released.
                    self.state.delete()
            else:
                self.state.unlock()

    def update(self, state):
        """Update the state and return the new state, message, and response.

        Parameters
        ----------
        state
            The value returned by ``self.state.get()`` in `run`.

        Returns
        -------
        tuple
            ``(new_state, msg, response)``. When ``new_state`` is ``None``,
            `run` calls `gc` (unless ``DML_NO_GC`` is set) and deletes the
            stored state; otherwise it persists ``new_state`` and unlocks.
            `run` performs the same cleanup when this method raises.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("Runner.update must be implemented by subclasses")

    def gc(self, state):
        """Clean up any resources.

        The base implementation does nothing; `run` passes in the state it
        fetched before calling `update`.
        """