"""S3 storage utilities."""
import json
import logging
import os
import posixpath
import subprocess
from dataclasses import dataclass, field, replace
from io import BytesIO
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Optional, Union
from urllib.parse import urlparse
import boto3
from dml_util.aws import get_client
from dml_util.core.daggerml import Node, Resource
from dml_util.core.utils import batched, compute_hash, exactly_one, js_dump
logger = logging.getLogger(__name__)
@dataclass
class S3Store:
"""
S3 Store for DML
Parameters
----------
bucket : str
S3 bucket name. Defaults to the value of the environment variable "DML_S3_BUCKET".
prefix : str
S3 prefix. Defaults to the value of the environment variable "DML_S3_PREFIX".
client : boto3.client, optional
Boto3 S3 client. Defaults to a new client created using the `get_client` function.
Notes
-----
- If `prefix` is not provided, "/data" is appended to the `DML_S3_PREFIX` environment variable.
- `prefix` is stripped of leading and trailing slashes, so if you want to use a prefix like "/foo/", you'll need to
handle those uris directly. E.g. to put data at "s3://my-bucket//foo/bar", you would use
`S3Store().put(data, uri="s3://my-bucket//foo/bar")`.
Examples
--------
>>> s3 = S3Store(bucket="my-bucket", prefix="my-prefix")
>>> s3.put(data=b"Hello, World!", name="greeting.txt") # doctest: +SKIP
Resource(uri='s3://my-bucket/my-prefix/greeting.txt')
>>> s3.ls(recursive=True) # doctest: +SKIP
['s3://my-bucket/my-prefix/greeting.txt']
>>> s3.get("greeting.txt") # doctest: +SKIP
b'Hello, World!'
>>> s3.exists("greeting.txt") # doctest: +SKIP
True
>>> s3.rm("greeting.txt") # doctest: +SKIP
>>> s3.exists("greeting.txt") # doctest: +SKIP
False
>>> s3.put_js({"key": "value"}, name="data") # doctest: +SKIP
Resource(uri='s3://my-bucket/my-prefix/data.json')
>>> s3.get_js("data") # doctest: +SKIP
{'key': 'value'}
>>> s3.tar(dml, path="my_data", excludes=["*.tmp"]) # doctest: +SKIP
Resource(uri='s3://my-bucket/my-prefix/my_data.tar')
>>> s3.untar("s3://my-bucket/my-prefix/my_data.tar", dest="my_data") # doctest: +SKIP
# Extracts the tar archive to the local directory "my_data"
>>> s3.cd("new-prefix")
S3Store(bucket='my-bucket', prefix='my-prefix/new-prefix')
>>> s3.cd("..") # Go back to the previous prefix
S3Store(bucket='my-bucket', prefix='')
"""
bucket: str = field(default_factory=lambda: os.getenv("DML_S3_BUCKET"))
    prefix: Optional[str] = None
client: "boto3.client" = field(default_factory=lambda: get_client("s3"), repr=False)
def __post_init__(self):
if self.prefix is None:
self.prefix = os.getenv("DML_S3_PREFIX", "") + "/data"
self.prefix = self.prefix.strip("/")
logger.debug("Initialized S3Store at s3://%s/%s", self.bucket, self.prefix)
def parse_uri(self, name_or_uri):
"""
Parse a URI or name into bucket and key.
Examples
--------
>>> s3 = S3Store(bucket="my-bucket", prefix="my-prefix")
>>> s3.parse_uri("s3://my-other-bucket/my-key")
('my-other-bucket', 'my-key')
>>> s3.parse_uri("my-key")
('my-bucket', 'my-prefix/my-key')
>>> s3.parse_uri(Resource("s3://my-other-bucket/my-key"))
('my-other-bucket', 'my-key')
"""
if isinstance(name_or_uri, Node):
name_or_uri = name_or_uri.value()
if isinstance(name_or_uri, Resource):
name_or_uri = name_or_uri.uri
p = urlparse(name_or_uri)
if p.scheme == "s3":
return p.netloc, p.path[1:]
key = f"{self.prefix}/{name_or_uri}" if self.prefix else name_or_uri
return self.bucket, key
def _name2uri(self, name):
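        """Return the full "s3://<bucket>/<key>" URI for a name or URI."""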
bkt, key = self.parse_uri(name)
return f"s3://{bkt}/{key}"
    def _ls(self, uri=None, recursive=False):
        """Yield S3 URIs under `uri` (or this store's prefix), paginating as needed."""
        kw = {} if recursive else {"Delimiter": "/"}
        bucket, prefix = self.parse_uri(uri or f"s3://{self.bucket}/{self.prefix}")
        prefix = prefix.rstrip("/") + "/" if prefix else ""
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix, **kw):
            for obj in page.get("Contents", []):
                yield f"s3://{bucket}/{obj['Key']}"
def ls(self, s3_root=None, *, recursive=False, lazy=False):
"""
List objects in the S3 bucket.
Parameters
----------
s3_root : str, optional
Name or s3 root to list. Defaults to s3://<bucket>/<prefix>/.
recursive : bool
If True, list all objects recursively. Defaults to False.
lazy : bool
If True, return a generator. Defaults to False.
Returns
-------
generator or list
A generator or list of S3 URIs.
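
        Examples
        --------
        >>> s3 = S3Store(bucket="my-bucket", prefix="my-prefix")
        >>> s3.ls()  # doctest: +SKIP
        ['s3://my-bucket/my-prefix/greeting.txt']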
"""
resp = self._ls(s3_root, recursive=recursive)
if not lazy:
resp = list(resp)
return resp
def exists(self, name_or_uri):
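        """Return True if the object exists in S3, False otherwise."""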
bucket, key = self.parse_uri(name_or_uri)
try:
self.client.head_object(Bucket=bucket, Key=key)
return True
except Exception as e:
if getattr(e, "response", {}).get("Error", {}).get("Code") == "404":
return False
raise
def get(self, name_or_uri):
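        """Fetch an object's contents from S3 as bytes."""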
bucket, key = self.parse_uri(name_or_uri)
resp = self.client.get_object(Bucket=bucket, Key=key)
return resp["Body"].read()
def put(self, data=None, filepath=None, name=None, uri=None, suffix=None):
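        """
        Upload bytes or a local file to S3.

        Exactly one of `data` or `filepath` must be given, along with exactly
        one of `name`, `uri`, or `suffix`. When only `suffix` is given, the key
        is derived from a hash of the content, so identical payloads map to the
        same URI.
        """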
exactly_one(data=data, filepath=filepath)
exactly_one(name=name, uri=uri, suffix=suffix)
# TODO: look for registered serdes through python packaging
data = open(filepath, "rb") if data is None else BytesIO(data)
try:
if uri is None and name is None:
name = compute_hash(data) + (suffix or "")
bucket, key = self.parse_uri(uri or name)
self.client.upload_fileobj(data, bucket, key)
return Resource(f"s3://{bucket}/{key}")
finally:
if filepath is not None:
data.close()
def put_js(self, data, uri=None, **kw) -> Resource:
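        """
        Serialize `data` to JSON and upload it to S3.

        If `uri` is omitted, the key is a hash of the serialized payload with a
        ".json" suffix. Extra keyword arguments are forwarded to `js_dump`.
        """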
suffix = ".json" if uri is None else None
return self.put(js_dump(data, **kw).encode(), uri=uri, suffix=suffix)
def get_js(self, uri):
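        """Fetch an object from S3 and deserialize it from JSON (inverse of `put_js`)."""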
return json.loads(self.get(uri).decode())
def tar(self, dml, path, excludes=()):
"""Create a tar archive and store it in S3."""
        exclude_flags = [y for x in excludes for y in ("--exclude", x)]
with NamedTemporaryFile(suffix=".tar") as tmpf:
dml(
"util",
"tar",
*exclude_flags,
str(path),
tmpf.name,
)
return self.put(filepath=tmpf.name, suffix=".tar")
def untar(self, tar_uri, dest):
"""Extract a tar archive from S3 to a local directory."""
        bucket, key = self.parse_uri(tar_uri)
        with NamedTemporaryFile(suffix=".tar") as tmpf:
            self.client.download_file(bucket, key, tmpf.name)
            subprocess.run(["tar", "-xvf", tmpf.name, "-C", str(dest)], check=True)
def rm(self, *name_or_uris: Union[str, Resource, list[Union[str, Resource]]]):
"""Remove objects from S3."""
if len(name_or_uris) == 1 and isinstance(name_or_uris[0], (list, tuple)):
name_or_uris = name_or_uris[0]
if len(name_or_uris) == 0:
return
        # Group keys by bucket; delete_objects accepts at most 1000 keys per call.
        for bucket, objs in groupby(sorted(map(self.parse_uri, name_or_uris)), key=lambda x: x[0]):
            for batch in batched((x[1] for x in objs), 1000):
self.client.delete_objects(
Bucket=bucket,
Delete={"Objects": [{"Key": x} for x in batch]},
)
def cd(self, new_prefix) -> "S3Store":
"""Change the prefix of the S3 store."""
        # Normalize with POSIX semantics so ".." works without touching the local filesystem.
        new_path = posixpath.normpath(f"/{self.prefix}/{new_prefix}").strip("/")
        return replace(self, prefix=new_path)