"""
Implementation of a Lambda function that runs a job on AWS Batch.
Environment variables:
- CPU_QUEUE: The name of the CPU job queue.
- GPU_QUEUE: The name of the GPU job queue.
- BATCH_TASK_ROLE_ARN: The ARN of the IAM role for Batch tasks.
"""
import logging
import os
from typing import TYPE_CHECKING, Optional
from botocore.exceptions import ClientError
from dml_util.aws import get_client
from dml_util.runners.lambda_ import LambdaRunner
if TYPE_CHECKING:
import boto3
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Default container resources (merged with caller-supplied kwargs in submit()).
DFLT_PROP = {"vcpus": 1, "memory": 512}
# AWS Batch job states that mean the job is still in progress.
PENDING_STATES = ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]
# Terminal AWS Batch job states.
SUCCESS_STATE = "SUCCEEDED"
FAILED_STATE = "FAILED"
class BatchRunner(LambdaRunner):
    """Runner that executes a function as an AWS Batch job.

    Each invocation registers a dedicated job definition, submits one
    container job to the CPU or GPU queue (chosen from the requested
    ``resourceRequirements``), polls it via ``describe_jobs``, and tears
    the job definition down during garbage collection.
    """

    # Lazily-created AWS Batch client (see the ``client`` property).
    _client: Optional["boto3.client"] = None

    @property
    def client(self):
        """Return the AWS Batch client, creating it on first access."""
        if self._client is None:
            self._client = get_client("batch")
        return self._client

    def submit(self):
        """Register a job definition and submit the Batch job.

        Returns:
            dict: ``{"job_def": <job definition ARN>, "job_id": <job id>}``.
        """
        sub_adapter, sub_uri, sub_kwargs = self.input.get_sub()
        kw = self.input.kwargs.copy()
        kw.pop("sub")
        image = kw.pop("image")["uri"]
        # Build a fresh dict: the original code mutated the module-level
        # DFLT_PROP in place, leaking caller overrides across invocations.
        container_props = {**DFLT_PROP, **kw}
        needs_gpu = any(
            x["type"] == "GPU"
            for x in container_props.get("resourceRequirements", [])
        )
        logger.info("creating job definition with name: %r", f"fn-{self.input.cache_key}")
        response = self.client.register_job_definition(
            jobDefinitionName=f"fn-{self.input.cache_key}",
            type="container",
            containerProperties={
                "image": image,
                # The container runs the sub-adapter, reading its input from
                # S3 and writing output/error dumps back to S3.
                "command": [
                    sub_adapter,
                    "-n",
                    "-1",
                    "-i",
                    self.s3.put(sub_kwargs.encode(), name="input.dump").uri,
                    "-o",
                    self.s3._name2uri("output.dump"),
                    "-e",
                    self.s3._name2uri("error.dump"),
                    sub_uri,
                ],
                "environment": [
                    {"name": k, "value": v}
                    for k, v in self.config.to_envvars().items()
                ],
                "jobRoleArn": os.environ["BATCH_TASK_ROLE_ARN"],
                **container_props,
            },
        )
        job_def = response["jobDefinitionArn"]
        logger.info("created job definition with arn: %r", job_def)
        response = self.client.submit_job(
            jobName=f"fn-{self.input.cache_key}",
            jobQueue=os.environ["GPU_QUEUE" if needs_gpu else "CPU_QUEUE"],
            jobDefinition=job_def,
        )
        job_id = response["jobId"]
        logger.info("Job submitted: %r", job_id)
        return {"job_def": job_def, "job_id": job_id}

    def describe_job(self, state):
        """Look up the Batch job recorded in *state*.

        Returns:
            tuple: ``(job_id, status)``; both ``None`` when the job is no
            longer known to Batch.

        Side effect: caches the raw job description on ``self.job_desc``.
        """
        job_id = state["job_id"]
        response = self.client.describe_jobs(jobs=[job_id])
        logger.info(
            "Job %r (input.cache_key: %r) description: %r",
            job_id,
            self.input.cache_key,
            response,
        )
        # The original checked ``len(response) == 0``, which is never true
        # for the dict describe_jobs returns; the "jobs" list is what may
        # be empty (e.g. the job has aged out).
        jobs = response.get("jobs", [])
        if not jobs:
            return None, None
        job = jobs[0]
        self.job_desc = job
        status = job["status"]
        return job_id, status

    def update(self, state):
        """Advance the job state machine one step.

        Returns:
            tuple: ``(new_state, message, result)`` — ``new_state`` is
            ``None`` once the job has succeeded and output exists.

        Raises:
            RuntimeError: when the job reaches a terminal state without a
            usable output dump.
        """
        if not state:
            # First call: submit the job and persist its identifiers.
            state = self.submit()
            job_id = state["job_id"]
            return state, f"{job_id = } submitted", {}
        job_id, status = self.describe_job(state)
        msg = f"{job_id = } {status}"
        logger.info(msg)
        if status in PENDING_STATES:
            return state, msg, {}
        # Terminal (or unknown) state: surface any error dump in the message.
        if self.s3.exists("error.dump"):
            err = self.s3.get("error.dump").decode()
            logger.info("%r found with content: %r", self.s3._name2uri("error.dump"), err)
            msg += f"\n\n{err}"
        if status == SUCCESS_STATE and self.s3.exists("output.dump"):
            logger.info("job finished successfully and output was written...")
            js = self.s3.get("output.dump").decode()
            logger.info("dump = %r", js)
            return None, msg, js
        if not self.s3.exists("output.dump"):
            msg = f"{msg} (no output found)"
            logger.info("file: %r does not exist", self.s3._name2uri("output.dump"))
        # job_desc may be unset when describe_job found no job at all.
        desc = getattr(self, "job_desc", None) or {}
        if "statusReason" in desc:
            msg = f"{msg} (reason: {desc['statusReason']})"
        logger.info(msg)
        # Raise the plain message (the original wrapped it in f"{msg = }"
        # debug-repr syntax, garbling the error text).
        raise RuntimeError(msg)

    def gc(self, state):
        """Best-effort cleanup: stop the job and deregister its definition."""
        super().gc(state)
        if not state:
            return
        job_id, status = self.describe_job(state)
        if job_id is not None:
            try:
                if status in ("STARTING", "RUNNING"):
                    # cancel_job is a no-op once a job has started running;
                    # terminate_job is required for STARTING/RUNNING jobs.
                    self.client.terminate_job(jobId=job_id, reason="gc")
                else:
                    self.client.cancel_job(jobId=job_id, reason="gc")
            except ClientError:
                pass
        job_def = state["job_def"]
        try:
            self.client.deregister_job_definition(jobDefinition=job_def)
            logger.info("Successfully deregistered: %r", job_def)
        except ClientError as e:
            err = e.response.get("Error", {})
            # Tolerate only "already deregistered"; re-raise anything else.
            if err.get("Code") != "ClientException":
                raise
            # Message can be absent/None — guard before the substring test
            # (the original could raise TypeError on ``in None``).
            if "DEREGISTERED" not in (err.get("Message") or ""):
                raise