# Source code for dml_util.aws.dynamodb

"""DynamoDB state management.

This module provides state management functionality using AWS DynamoDB.
It includes classes for storing and retrieving state with locking mechanisms
to ensure data consistency in distributed environments.
"""

import json
import logging
import os
from dataclasses import dataclass, field
from uuid import uuid4

import boto3

from dml_util.aws import get_client
from dml_util.core.state import TIMEOUT, State
from dml_util.core.utils import js_dump, now

logger = logging.getLogger(__name__)


@dataclass
class DynamoState(State):
    """DynamoDB-based state management.

    This class implements state management using AWS DynamoDB. It provides
    methods for storing and retrieving state with locking capabilities to
    ensure data consistency in distributed environments.

    Each record is keyed by ``cache_key`` and carries three attributes that
    this class reads or writes: ``lock_key`` (the ``run_id`` of the current
    lock holder), ``update_time`` (timestamp of the last lock/put) and
    ``obj`` (the JSON-serialized payload).

    Parameters
    ----------
    cache_key : str
        Unique identifier for the state record in DynamoDB.
    run_id : str, optional
        Unique identifier for the current run, used for locking.
        Defaults to a random UUID.
    timeout : int, optional
        Lock timeout in seconds. Defaults to TIMEOUT (5 seconds).
    db : boto3.client, optional
        DynamoDB client. Defaults to a new client created using get_client.
    tb : str, optional
        DynamoDB table name. Defaults to the value of the environment
        variable DYNAMODB_TABLE.
    """

    cache_key: str
    run_id: str = field(default_factory=lambda: uuid4().hex)
    timeout: int = field(default=TIMEOUT)
    db: "boto3.client" = field(default_factory=lambda: get_client("dynamodb"))
    tb: str = field(default_factory=lambda: os.getenv("DYNAMODB_TABLE"))

    @staticmethod
    def _is_conditional_check_failure(exc):
        """Return True if *exc* reports a failed DynamoDB condition expression.

        The error code is read from ``exc.response["Error"]["Code"]``. Every
        level is type-checked so that an unrelated exception whose
        ``response`` attribute is ``None`` (or any non-dict object) is simply
        reported as "not a conditional failure" instead of raising an
        ``AttributeError`` that would mask the original error.
        """
        response = getattr(exc, "response", None)
        if not isinstance(response, dict):
            return False
        error = response.get("Error")
        if not isinstance(error, dict):
            return False
        return error.get("Code") == "ConditionalCheckFailedException"

    def _update(self, key=None, **kw):
        """Run ``update_item`` against this table for *key*.

        Parameters
        ----------
        key : str, optional
            Record to update. Falls back to ``self.cache_key`` when falsy.
        **kw
            Passed through unchanged to ``self.db.update_item``.

        Returns
        -------
        The ``update_item`` response, or None when the condition expression
        was not satisfied (i.e. this run does not hold a valid lock).

        Raises
        ------
        Exception
            Any error other than a conditional-check failure is re-raised.
        """
        target = key or self.cache_key
        try:
            return self.db.update_item(
                TableName=self.tb,
                Key={"cache_key": {"S": target}},
                **kw,
            )
        except Exception as e:
            if self._is_conditional_check_failure(e):
                # Log the key that was actually targeted, not always cache_key.
                logger.info("could not update %r (invalid lock)", target)
                return None
            raise

    def get(self, key=None):
        """Acquire (or refresh) the lock and return the stored object.

        The lock is taken when the record has no ``lock_key``, when the lock
        is already held by this ``run_id``, or when the record's
        ``update_time`` is older than ``timeout`` seconds.

        Parameters
        ----------
        key : str, optional
            Record to lock. Falls back to ``self.cache_key`` when falsy.

        Returns
        -------
        None if the lock could not be acquired, ``{}`` if there is no data,
        the JSON-decoded data otherwise.
        """
        logger.info("acquiring lock for %r", key or self.cache_key)
        ut = now()
        resp = self._update(
            key,
            UpdateExpression="SET #lk = :lk, #ut = :ut",
            ConditionExpression="attribute_not_exists(#lk) OR #lk = :lk OR #ut < :to",
            ExpressionAttributeNames={
                "#lk": "lock_key",
                "#ut": "update_time",
            },
            ExpressionAttributeValues={
                ":lk": {"S": self.run_id},
                ":ut": {"N": str(ut)},
                # Locks last touched before this instant count as expired.
                ":to": {"N": str(ut - self.timeout)},
            },
            ReturnValues="ALL_NEW",
        )
        if resp is None:
            return None
        stored = resp["Attributes"].get("obj")
        if not stored:
            return {}
        return json.loads(stored["S"])

    def put(self, obj):
        """Store *obj* on the record, provided this run still holds the lock.

        Parameters
        ----------
        obj
            Value to persist; serialized with ``js_dump``.

        Returns
        -------
        bool
            True if the write succeeded, False if the lock check failed.
        """
        logger.info("putting data for %r", self.cache_key)
        resp = self._update(
            UpdateExpression="SET #obj = :obj, #ut = :ut",
            ConditionExpression="#lk = :lk",
            ExpressionAttributeNames={
                "#lk": "lock_key",
                "#obj": "obj",
                "#ut": "update_time",
            },
            ExpressionAttributeValues={
                ":lk": {"S": self.run_id},
                ":obj": {"S": js_dump(obj)},
                ":ut": {"N": str(round(now(), 2))},
            },
        )
        return resp is not None

    def unlock(self, key=None):
        """Release the lock held by this run (best effort).

        Parameters
        ----------
        key : str, optional
            Record to unlock. Falls back to ``self.cache_key`` when falsy.

        Returns
        -------
        True if the lock was removed, False if this run did not hold it,
        None if the request failed for any other reason. Such failures are
        logged rather than raised so that unlocking never aborts the caller.
        """
        target = key or self.cache_key
        logger.info("releasing lock for %r", target)
        try:
            resp = self._update(
                key,
                UpdateExpression="REMOVE #lk",
                ConditionExpression="#lk = :lk",
                ExpressionAttributeNames={"#lk": "lock_key"},
                ExpressionAttributeValues={":lk": {"S": self.run_id}},
            )
        except Exception:
            # Deliberately best-effort, but leave a trace instead of
            # swallowing the failure silently.
            logger.warning("failed to release lock for %r", target, exc_info=True)
            return None
        return resp is not None

    def delete(self):
        """Delete the record, provided this run holds the lock.

        Returns
        -------
        The ``delete_item`` response, or None when the lock check failed.

        Raises
        ------
        Exception
            Any error other than a conditional-check failure is re-raised.
        """
        try:
            return self.db.delete_item(
                TableName=self.tb,
                Key={"cache_key": {"S": self.cache_key}},
                ConditionExpression="#lk = :lk",
                ExpressionAttributeNames={"#lk": "lock_key"},
                ExpressionAttributeValues={":lk": {"S": self.run_id}},
            )
        except Exception as e:
            if not self._is_conditional_check_failure(e):
                raise
            return None