"""Base adapter class.
This module provides the base adapter class for the DaggerML utilities.
Adapters are used to connect DaggerML to various execution environments,
such as AWS Lambda or local runners.
"""
import logging
import logging.config
import os
import re
import sys
import time
from argparse import ArgumentParser
from dataclasses import dataclass
from urllib.parse import urlparse
import boto3
import botocore
from botocore.exceptions import BotoCoreError, NoRegionError
from dml_util.aws.s3 import S3Store
from dml_util.core.config import EnvConfig
from dml_util.core.daggerml import Error, Resource
logger = logging.getLogger(__name__)
try:
    import watchtower
    class SafeCloudWatchLogHandler(watchtower.CloudWatchLogHandler):
        def __init__(self, *args, boto3_client=None, **kwargs):
            """Initialize the CloudWatch log handler with safe region detection."""
            self._enabled = False
            if not boto3_client:
                boto3_client = boto3.client("logs", region_name=self._detect_region())
            try:
                super().__init__(*args, boto3_client=boto3_client, **kwargs)
                self._enabled = True
            except Exception as e:
                logger.warning(f"CloudWatch logging disabled: {e}")
        def emit(self, record):
            if not self._enabled:
                return
            try:
                super().emit(record)
            except Exception as e:
                logger.debug(f"Failed to emit to CloudWatch: {e}")

        def _detect_region(self):
            # Priority: AWS_REGION > AWS_DEFAULT_REGION > boto3 config > fail
            env_region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")
            if env_region:
                return env_region
            try:
                session = botocore.session.get_session()
                return session.get_config_variable("region")
            except (BotoCoreError, NoRegionError):
                return None
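            # Worked example (illustrative values): with AWS_REGION="us-west-2" and
            # AWS_DEFAULT_REGION="eu-west-1" both set, _detect_region returns
            # "us-west-2"; with neither set, it falls back to the botocore session's
            # configured region, and returns None if that lookup fails as well.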
except ModuleNotFoundError:
    watchtower = None


def _read_data(file):
    """Read data from a file, stdin, or S3."""
    if not isinstance(file, str):
        return file.read()
    if urlparse(file).scheme == "s3":
        return S3Store().get(file).decode()
    with open(file) as f:
        data = f.read()
    return data.strip()
def _write_data(data, to, mode="w"):
"""Write data to a file, stdout, or S3."""
if not isinstance(to, str):
return print(data, file=to, flush=True)
if urlparse(to).scheme == "s3":
return S3Store().put(data.encode(), uri=to)
with open(to, mode) as f:
f.write(data + ("\n" if mode == "a" else ""))
f.flush()
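
# How the two helpers above dispatch, in brief (illustrative; the S3 URI and file
# names below are made-up examples, not values used anywhere by dml-util):
#
#   _read_data(sys.stdin)                -> file-like object: read it directly
#   _read_data("s3://bucket/key.json")   -> fetched via S3Store().get(...)
#   _read_data("input.json")             -> local file contents, stripped
#   _write_data("ok", sys.stdout)        -> printed to stdout
#   _write_data("ok", "run.log", "a")    -> appended to the file with a newline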
class VerboseArgumentParser(ArgumentParser):
    def error(self, message):
        # Customize this however you want
        self.print_usage(sys.stderr)
        self.exit(
            2,
            f"\nError: {message}\n\n"
            f"Hint: Run with '--help' to see usage and examples.\n"
        )
@dataclass
class AdapterBase:
"""Base class for DaggerML adapters.
This class provides a CLI interface for executing DaggerML functions iteratively,
passing environment variables along. It supports different adapters for remote
execution, such as AWS Lambda or local runners.
Attributes
----------
ADAPTER : str
The name of the adapter (to be defined in subclasses).
ADAPTERS : dict
"""
ADAPTER = None # to be defined in subclasses
ADAPTERS = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        AdapterBase.ADAPTERS[re.sub(r"adapter$", "", cls.__name__.lower())] = cls
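
    # Illustration of the registry populated by __init_subclass__: a hypothetical
    # subclass named "LambdaAdapter" would be stored under the key "lambda"
    # (class name lowercased, trailing "adapter" stripped), i.e.
    # AdapterBase.ADAPTERS["lambda"] is LambdaAdapter. The class name is invented
    # for the example; real adapters are defined in their own modules.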

    @staticmethod
    def _setup(config):
        """Set up logging configuration for the run."""
        _config = {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "simple": {
                    "format": f"[{config.run_id}] %(levelname)1s %(name)s: %(message)s",
                }
            },
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "formatter": "simple",
                    "level": (logging.DEBUG if config.debug else logging.WARNING),
                }
            },
            "loggers": {
                "dml_util": {
                    "handlers": ["console"],
                    "level": logging.DEBUG,
                    "propagate": False,
                },
                "": {
                    "handlers": ["console"],
                    "level": logging.WARNING,
                },
            },
        }
        logging.config.dictConfig(_config)
        if watchtower:
            # region detection is handled inside SafeCloudWatchLogHandler
            try:
                handler = SafeCloudWatchLogHandler(
                    log_group_name=config.log_group,
                    log_stream_name="/adapter",
                    create_log_stream=True,
                    create_log_group=False,
                    level=logging.DEBUG,
                )
                if handler._enabled:
                    logging.getLogger("dml_util").addHandler(handler)
                    logging.getLogger("").addHandler(handler)
                    logger.debug("added watchtower handler %r", handler)
            except Exception as e:
                logger.error("Error setting up watchtower handler: %s", e)

    @staticmethod
    def _teardown():
        if watchtower:
            # iterate over a copy: removeHandler mutates the handler list
            for handler in list(logging.getLogger("dml_util").handlers):
                if isinstance(handler, watchtower.CloudWatchLogHandler):
                    logging.getLogger("").removeHandler(handler)
                    logger.debug("removing watchtower handler %r", handler)
                    logging.getLogger("dml_util").removeHandler(handler)
                    handler.flush()
                    handler.close()

    @classmethod
    def cli(cls, args=None):
        """
        Command-line interface for the adapter.

        This method reads input data from a file or stdin, sends it to a remote service
        specified by the URI, and writes the response to an output file or stdout.
        If an error occurs, it writes the error message to an error file or stderr.

        CLI Parameters
        --------------
        uri : str
            URI of the function to invoke.
        --input FILE, -i FILE : path, optional (default: STDIN)
            Input data file, stdin, or S3 location to read the dump from.
        --output FILE, -o FILE : path, optional (default: STDOUT)
            Output location for the response data (file, stdout, or S3 location).
        --error FILE, -e FILE : path, optional (default: STDERR)
            Error output location (file, stderr, or S3 location).
        --n-iters N, -n N : int, optional (default: 1)
            Number of iterations to run. Set to 0 to run indefinitely.
        --debug : flag, optional
            Enables debug logging.

        Returns
        -------
        int
            Exit code: 0 on success, 1 on error.
        """
        if args is None:
            parser = VerboseArgumentParser(description=f"DaggerML {cls.__name__} CLI")
            parser.add_argument("uri")
            parser.add_argument("-i", "--input", default=sys.stdin)
            parser.add_argument("-o", "--output", default=sys.stdout)
            parser.add_argument("-e", "--error", default=sys.stderr)
            parser.add_argument("-n", "--n-iters", default=1, type=int)
            parser.add_argument("--debug", action="store_true")
            args = parser.parse_args()
        config = EnvConfig.from_env(debug=args.debug)
        cls._setup(config)
        try:
            n_iters = args.n_iters if args.n_iters > 0 else float("inf")
            logger.debug("reading data from %r", args.input)
            dump = _read_data(args.input)
            while n_iters > 0:
                resp, msg = cls.send_to_remote(args.uri, config, dump)
                _write_data(msg, args.error, mode="a")
                if resp:
                    _write_data(resp, args.output)
                    return 0
                n_iters -= 1
                if n_iters > 0:
                    time.sleep(0.2)
            return 0
        except Exception as e:
            logger.exception("Error in adapter")
            try:
                _write_data(str(Error(e)), args.error)
            except Exception:
                logger.exception("cannot write to %r", args.error)
            return 1
        finally:
            cls._teardown()
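
    # Example invocation of the CLI above (illustrative: the console-script name,
    # URI, and file names are assumptions, not defined by this module):
    #
    #   my-adapter some-function-uri -i dump.json -o result.json -n 0 --debug
    #
    # With -n 0 the loop keeps calling send_to_remote (sleeping 0.2s between
    # attempts) until it returns a truthy response, which is then written to
    # result.json; status messages are appended to stderr or the --error target.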

    @classmethod
    def funkify(cls, uri, data):
        return Resource(uri, data=data, adapter=cls.ADAPTER)
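
    # Illustration: for a hypothetical subclass whose ADAPTER is "dml-util-lambda-adapter",
    # SomeAdapter.funkify("my-function-uri", {"timeout": 60}) returns a Resource
    # recording the uri, the data payload, and the adapter name, so DaggerML can
    # route calls to that function back through this adapter. All names and values
    # here are invented for the example.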

    @classmethod
    def send_to_remote(cls, uri, config: EnvConfig, dump: str) -> tuple[str, str]:
        """Send data to a remote service specified by the URI.

        Parameters
        ----------
        uri : str
            The URI of the remote service.
        config : EnvConfig
            Configuration for the run, including cache path, cache key, S3 bucket, etc.
        dump : str
            The opaque blob to send to the remote service.

        Returns
        -------
        tuple[str, str]
            A tuple of (response, message). If the response is truthy, it is passed
            on to the caller via the --output target; the message is written to the
            --error target.
        """
        raise NotImplementedError("send_to_remote not implemented for this adapter")