import json
import logging
import shutil
import subprocess
import time
from dataclasses import dataclass, field, fields
from tempfile import TemporaryDirectory
from traceback import format_exception
from typing import Any, Callable, Optional, Union
from daggerml.util import (
BackoffWithJitter,
current_time_millis,
kwargs2opts,
postwalk,
properties,
raise_ex,
replace,
setter,
)
log = logging.getLogger(__name__)
DATA_TYPE = {}
Scalar = Union[str, int, float, bool, type(None), "Resource", "Node"]
Collection = Union[list, tuple, set, dict]
def dml_type(cls=None, **opts):
def decorator(cls):
DATA_TYPE[opts.get("alias", None) or cls.__name__] = cls
return cls
return decorator(cls) if cls else decorator
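# A hedged sketch (not from the library) of how a user-defined dataclass might
# register itself for serialization via `dml_type`; the class name, or the
# optional `alias`, becomes the tag used by `to_data`/`from_data` below.
# `Point` is a hypothetical example type.
#
#     @dml_type(alias="pt")
#     @dataclass(frozen=True)
#     class Point:
#         x: float
#         y: float
#
#     assert DATA_TYPE["pt"] is Point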
def from_data(data):
n, *args = data if isinstance(data, list) else [None, data]
if n is None:
return args[0]
if n == "l":
return [from_data(x) for x in args]
if n == "s":
return {from_data(x) for x in args}
if n == "d":
return {k: from_data(v) for (k, v) in args}
if n in DATA_TYPE:
return DATA_TYPE[n](*[from_data(x) for x in args])
raise ValueError(f"no decoder for type: {n}")
def to_data(obj):
if isinstance(obj, Node):
obj = obj.ref
if isinstance(obj, tuple):
obj = list(obj)
n = obj.__class__.__name__
if isinstance(obj, (type(None), str, bool, int, float)):
return obj
if isinstance(obj, (list, set)):
return [n[0], *[to_data(x) for x in obj]]
if isinstance(obj, dict):
return [n[0], *[[k, to_data(v)] for k, v in obj.items()]]
if n in DATA_TYPE:
return [n, *[to_data(getattr(obj, x.name)) for x in fields(obj)]]
raise ValueError(f"no encoder for type: {n}")
def from_json(text):
return from_data(json.loads(text))
def to_json(obj):
return json.dumps(to_data(obj), separators=(",", ":"))
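# A hedged sketch of the wire format these helpers produce, derived from the
# encoders above: lists/sets/dicts are tagged "l"/"s"/"d", and registered
# dml_type classes are tagged with their class name followed by their fields.
# The URI below is a made-up example value.
#
#     to_json([1, {"a": Resource("s3://bucket/key")}])
#     # -> '["l",1,["d",["a",["Resource","s3://bucket/key",null,null]]]]'
#     from_json(to_json({"a": 1}))
#     # -> {"a": 1}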
@dml_type
@dataclass(frozen=True)
class Ref: # noqa: F811
"""
Reference to a DaggerML object.
Parameters
----------
to : str
Reference identifier
"""
to: str
@dml_type
@dataclass(frozen=True)
class Resource: # noqa: F811
"""
Representation of an externally managed object with an identifier.
Parameters
----------
uri : str
Resource URI
data : str, optional
Associated data
adapter : str, optional
Resource adapter name
"""
uri: str
data: Optional[str] = None
adapter: Optional[str] = None
@dml_type
@dataclass
class Error(Exception): # noqa: F811
"""
Custom error type for DaggerML.
Parameters
----------
message : Union[str, Exception]
Error message or exception
context : dict, optional
Additional error context
code : str, optional
Error code
"""
message: Union[str, Exception]
context: dict = field(default_factory=dict)
code: Optional[str] = None
def __post_init__(self):
if isinstance(self.message, Error):
ex = self.message
self.message = ex.message
self.context = ex.context
self.code = ex.code
elif isinstance(self.message, Exception):
ex = self.message
self.message = str(ex)
self.context = {"trace": format_exception(type(ex), value=ex, tb=ex.__traceback__)}
self.code = type(ex).__name__
else:
self.code = type(self).__name__ if self.code is None else self.code
def __str__(self):
return "".join(self.context.get("trace", [self.message]))
@dataclass
class Dml:
"""
DaggerML cli client wrapper
"""
config_dir: Union[str, None] = None
project_dir: Union[str, None] = None
cache_path: Union[str, None] = None
repo: Union[str, None] = None
user: Union[str, None] = None
branch: Union[str, None] = None
token: Union[str, None] = None
tmpdirs: dict[str, TemporaryDirectory] = field(default_factory=dict)
@property
def kwargs(self) -> dict:
out = {
"config_dir": self.config_dir,
"project_dir": self.project_dir,
"cache_path": self.cache_path,
"repo": self.repo,
"user": self.user,
"branch": self.branch,
}
return {k: v for k, v in out.items() if v is not None}
@classmethod
def temporary(cls, repo="test", user="user", branch="main", cache_path=None, **kwargs) -> "Dml":
"""
Create a temporary Dml instance with specified parameters.
Parameters
----------
repo : str, default="test"
user : str, default="user"
branch : str, default="main"
cache_path : str, optional
Cache path to use; defaults to the dml default.
**kwargs : dict
Additional configuration, e.g. `config_dir` and `project_dir`. If one of these
is provided, no temporary directory is created for it; passing it explicitly as
None falls back to the dml default.
"""
tmpdirs = {k: TemporaryDirectory(prefix="dml-") for k in ["config_dir", "project_dir"] if k not in kwargs}
self = cls(
repo=repo,
user=user,
branch=branch,
cache_path=cache_path,
**{k: v.name for k, v in tmpdirs.items()},
tmpdirs=tmpdirs,
)
if self.kwargs["repo"] not in [x["name"] for x in self("repo", "list")]:
self("repo", "create", self.kwargs["repo"])
return self
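# A hedged usage sketch: `temporary` is intended for tests and scratch work,
# and the instance can be used as a context manager so the temporary
# config/project directories are cleaned up on exit.
#
#     with Dml.temporary() as dml:
#         ...  # e.g. dml("repo", "list"), dml.new(...)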
def cleanup(self):
[x.cleanup() for x in self.tmpdirs.values()]
def __call__(self, *args: str, input=None, as_text: bool = False) -> Any:
path = shutil.which("dml")
argv = [path, *kwargs2opts(**self.kwargs), *args]
resp = subprocess.run(argv, check=False, capture_output=True, text=True, input=input)
if resp.returncode != 0:
raise_ex(Error(resp.stderr or "DML command failed", code="DmlError"))
log.debug("dml command stderr: %s", resp.stderr)
if resp.stderr:
log.error(resp.stderr.rstrip())
try:
resp = resp.stdout or "" if as_text else json.loads(resp.stdout or "null")
except json.decoder.JSONDecodeError:
pass
return resp
def __getattr__(self, name: str):
def invoke(*args, **kwargs):
opargs = to_json([name, args, kwargs])
token = self.token or to_json([])
return raise_ex(from_data(self("api", "invoke", token, input=opargs)))
return invoke
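# A hedged sketch of the two call paths above: calling the instance runs the
# `dml` CLI directly, while unknown attribute access proxies to `dml api
# invoke` (this is how methods such as `put_literal` and `get_names`, used by
# Dag below, are dispatched).
#
#     dml = Dml.temporary()
#     repos = dml("repo", "list")   # plain CLI call, parsed from JSON stdout
#     dml.cleanup()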
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.cleanup()
@property
def envvars(self):
return {f"DML_{k.upper()}": str(v) for k, v in self.kwargs.items()}
def new(self, name="", message="", data=None, message_handler=None) -> "Dag":
opts = kwargs2opts(dump="-") if data else []
token = self("api", "create", *opts, name, message, input=data, as_text=True)
return Dag(replace(self, token=token), message_handler)
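# A hedged sketch of a minimal dag lifecycle: create it, add named nodes, and
# commit a result. The dag name and values are made up for illustration.
#
#     dml = Dml.temporary()
#     dag = dml.new("my-dag", "initial message")
#     dag.x = 42            # __setattr__ falls through to a literal node
#     dag.result = dag.x    # the result setter commits the dag
#     dml.cleanup()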
def load(self, name: Union[str, "Node"], recurse=False) -> "Dag":
return Dag(replace(self, token=None), _ref=self.get_dag(name, recurse=recurse))
@dataclass
class Boxed:
value: Any
def make_node(dag: "Dag", ref: Ref) -> "Node":
"""
Create a Node from a Dag and Ref.
Parameters
----------
dag : Dag
The parent DAG.
ref : Ref
The reference to the node.
Returns
-------
Node
A Node instance representing the reference in the DAG.
"""
info = dag._dml("node", "describe", ref.to)
if info["data_type"] == "list":
return ListNode(dag, ref, _info=info)
if info["data_type"] == "dict":
return DictNode(dag, ref, _info=info)
if info["data_type"] == "set":
return ListNode(dag, ref, _info=info)
if info["data_type"] == "resource":
return ResourceNode(dag, ref, _info=info)
return Node(dag, ref, _info=info)
@dataclass
class Dag:
_dml: Dml
_message_handler: Optional[Callable] = None
_ref: Optional[Ref] = None
_init_complete: bool = False
def __post_init__(self):
self._init_complete = True
def __hash__(self):
"Useful only for tests."
return 42
def __enter__(self):
"Catch exceptions and commit an Error"
assert not self._ref
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_value is not None:
self._commit(Error(exc_value))
def __getitem__(self, name) -> "Node":
return make_node(self, self._dml.get_node(name, self._ref))
# return Node(self, self._dml.get_node(name, self._ref))
def __setitem__(self, name, value) -> "Node":
assert not self._ref
if isinstance(value, Ref):
return self._dml.set_node(name, value)
return self._put(value, name=name)
def __len__(self) -> int:
return len(self._dml.get_names(self._ref))
def __iter__(self):
for k in self.keys():
yield k
def __setattr__(self, name, value):
priv = name.startswith("_")
flds = name in {x.name for x in fields(self)}
prps = name in properties(self)
init = not self._init_complete
boxd = isinstance(value, Boxed)
if (flds and init) or (not self._ref and ((not flds and not priv) or prps or boxd)):
value = value.value if boxd else value
if flds or (prps and setter(self, name)):
return super(Dag, self).__setattr__(name, value)
elif not prps:
return self.__setitem__(name, value)
raise AttributeError(f"can't set attribute: '{name}'")
def __getattr__(self, name):
return self.__getitem__(name)
@property
def argv(self) -> "Node":
"Access the dag's argv node"
return make_node(self, self._dml.get_argv(self._ref))
# return Node(self, self._dml.get_argv(self._ref))
@property
def result(self) -> "Node":
ref = self._dml.get_result(self._ref)
assert ref, f"'{self.__class__.__name__}' has no attribute 'result'"
return make_node(self, ref)
# return Node(self, ref) if ref else ref
@result.setter
def result(self, value):
return self._commit(value)
@property
def keys(self) -> Callable[[], list[str]]:
return lambda: list(self._dml.get_names(self._ref).keys())
@property
def values(self) -> Callable[[], list["Node"]]:
def result():
nodes = self._dml.get_names(self._ref).values()
return [make_node(self, x) for x in nodes]
return result
def _put(self, value: Union[Scalar, Collection], *, name=None, doc=None) -> "Node":
"""
Add a value to the DAG.
Parameters
----------
value : Union[Scalar, Collection]
Value to add
name : str, optional
Name for the node
doc : str, optional
Documentation
Returns
-------
Node
Node representing the value
"""
value = postwalk(
value,
lambda x: isinstance(x, Node) and x.dag._ref,
lambda x: self._load(x.dag, x.ref),
)
return make_node(self, self._dml.put_literal(value, name=name, doc=doc))
def _load(self, dag_name, node=None, *, name=None, doc=None) -> "Node":
"""
Load a DAG by name.
Parameters
----------
dag_name : str or Dag
Name of the DAG to load, or the Dag instance itself
node : optional
Node within that DAG to load
name : str, optional
Name for the node
doc : str, optional
Documentation
Returns
-------
Node
Node representing the loaded DAG
"""
dag = dag_name if isinstance(dag_name, str) else dag_name._ref
return make_node(self, self._dml.put_load(dag, node, name=name, doc=doc))
def _commit(self, value) -> "Node":
"""
Commit a value to the DAG.
Parameters
----------
value : Union[Node, Error, Any]
Value to commit
"""
value = value if isinstance(value, (Node, Error)) else self._put(value)
dump = self._dml.commit(value)
if self._message_handler:
self._message_handler(dump)
self._ref = Boxed(Ref(json.loads(dump)[-1][1][1]))
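# A hedged sketch of reading values back out of an open dag: item access
# returns Node wrappers (via make_node above), and `.value()` resolves the
# concrete value. The names and values are made up for illustration.
#
#     dag["config"] = {"lr": 0.1, "steps": [1, 2, 3]}
#     node = dag["config"]      # a DictNode
#     node["lr"].value()        # -> 0.1
#     list(dag.keys())          # names defined so far, e.g. ["config"]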
@dataclass(frozen=True)
class Node: # noqa: F811
"""
Representation of a node in a DaggerML DAG.
Parameters
----------
dag : Dag
Parent DAG
ref : Ref
Node reference
"""
dag: Dag
ref: Ref
_info: dict = field(default_factory=dict)
def __repr__(self):
ref_id = self.ref if isinstance(self.ref, Error) else self.ref.to
return f"{self.__class__.__name__}({ref_id})"
def __hash__(self):
return hash(self.ref)
@property
def argv(self) -> list["Node"]:
"Access the node's argv list"
return [make_node(self.dag, x) for x in self.dag._dml.get_argv(self)]
def load(self, *keys: Union[str, int]) -> Union[Dag, "Node"]:
"""
Convenience wrapper around `dml.load(node)`.
If `keys` are provided, this node is treated as a collection and the node that
was originally placed at that position is backtracked to and returned instead.
Parameters
----------
*keys : str or int, optional
Keys to backtrack through. If not provided, the entire DAG is loaded.
Returns
-------
Dag or Node
The dag this node was imported from (or, for a function call, the fndag); if
`keys` are given, the node those keys resolve to.
Examples
--------
>>> dml = Dml.temporary()
>>> dag = dml.new("test", "test")
>>> l0 = dag._put(42)
>>> c0 = dag._put({"a": 1, "b": [l0, "23"]})
>>> assert c0.load("b", 0) == l0
>>> assert c0.load("b").load(0) == l0
>>> assert c0["b"][0] != l0 # this is a different node, not the same as l0
>>> dml.cleanup()
"""
if len(keys) == 0:
return self.dag._dml.load(self)
data = self.dag._dml("node", "backtrack", self.ref.to, *map(str, keys))
return make_node(self.dag, from_data(data))
@property
def type(self):
""" Get the data type of the node."""
return self._info["data_type"]
def value(self):
"""
Get the concrete value of this node.
Returns
-------
Any
The actual value represented by this node
"""
return self.dag._dml.get_node_value(self.ref)
class ResourceNode(Node):
def __call__(self, *args, name=None, doc=None, sleep=None, timeout=0) -> "Node":
"""
Call this node as a function.
Parameters
----------
*args : Any
Arguments to pass to the function
name : str, optional
Name for the result node
doc : str, optional
Documentation
sleep : callable, optional
A nullary function that returns sleep time in milliseconds
timeout : int, default=0
Maximum time to wait in milliseconds; 0 or a negative value waits indefinitely
Returns
-------
Node
Result node
Raises
------
TimeoutError
If the function call exceeds the timeout
Error
If the function returns an error
"""
sleep = sleep or BackoffWithJitter()
args = [self.dag._put(x) for x in args]
end = current_time_millis() + timeout
while timeout <= 0 or current_time_millis() < end:
resp = self.dag._dml.start_fn([self, *args], name=name, doc=doc)
if resp:
return make_node(self.dag, resp)
time.sleep(sleep() / 1000)
raise TimeoutError(f"invoking function: {self.value()}")
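# A hedged sketch: a resource node is called like a function; start_fn is
# polled (sleeping between attempts) until it returns or `timeout` ms elapse.
# The adapter/URI here are hypothetical.
#
#     fn = dag._put(Resource("my-adapter:some-fn", adapter="my-adapter"))
#     out = fn(1, 2, timeout=60_000)   # Node holding the function's result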
class CollectionNode(Node): # noqa: F811
"""
Representation of a collection node in a DaggerML DAG.
Parameters
----------
dag : Dag
Parent DAG
ref : Ref
Node reference
"""
def __getitem__(self, key: Union[slice, str, int, "Node"]) -> "Node":
"""
Get the `key` item. It should be the same as if you were working on the
actual value.
Returns
-------
Node
Node representing the item at `key`
Raises
------
Error
If the node isn't a collection (e.g. list, set, or dict).
Examples
--------
>>> dml = Dml.temporary()
>>> dag = dml.new("test", "test")
>>> node = dag._put({"a": 1, "b": [5, 6]})
>>> nested = node["a"]
>>> isinstance(nested, Node)
True
>>> nested.value()
1
>>> node["b"][0].value() # lists too
5
"""
if isinstance(key, slice):
key = [key.start, key.stop, key.step]
return make_node(self.dag, self.dag._dml.get(self, key))
def contains(self, item, *, name=None, doc=None):
"""
For collection nodes, checks to see if `item` is in `self`
Returns
-------
Node
Node wrapping a boolean indicating whether `item` is in `self`
"""
return make_node(self.dag, self.dag._dml.contains(self, item, name=name, doc=doc))
def __contains__(self, item):
return self.contains(item).value() # has to return boolean
def __len__(self): # python requires this to be an int
"""
Get the node's length
Returns
-------
int
The length of the collection
Raises
------
Error
If the node isn't a collection (e.g. list, set, or dict).
"""
if self._info["length"]:
return self._info["length"]
raise Error(f"Cannot get length of type: {self._info['data_type']}")
def get(self, key, default=None, *, name=None, doc=None):
"""
For a dict node, return the value for key if key exists, else default.
If default is not given, it defaults to None, so that this method never raises a KeyError.
"""
return make_node(self.dag, self.dag._dml.get(self, key, default, name=name, doc=doc))
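# A hedged sketch of the collection helpers above, extending the doctest in
# __getitem__:
#
#     node = dag._put({"a": 1, "b": [5, 6]})
#     "a" in node            # __contains__ -> True
#     node.get("missing")    # Node wrapping the default (None)
#     len(node["b"])         # -> 2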
class ListNode(CollectionNode): # noqa: F811
"""
Representation of a list (or set) node in a DaggerML DAG.
Parameters
----------
dag : Dag
Parent DAG
ref : Ref
Node reference
"""
def __iter__(self):
"""
Iterate over the items of the list node.
Yields
------
Node
Each item of the list, in order.
"""
for i in range(len(self)):
yield self[i]
def conj(self, item, *, name=None, doc=None):
"""
For a list or set node, append an item
Returns
-------
Node
Node containing the new collection
Notes
-----
`append` is an alias for `conj`
"""
return make_node(self.dag, self.dag._dml.conj(self, item, name=name, doc=doc))
def append(self, item, *, name=None, doc=None):
"""
For a list or set node, append an item
Returns
-------
Node
Node containing the new collection
See Also
--------
conj : The main implementation
"""
return self.conj(item, name=name, doc=doc)
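# A hedged sketch: `conj`/`append` return a new list node rather than
# mutating the original.
#
#     xs = dag._put([1, 2])
#     ys = xs.append(3)      # new ListNode
#     ys.value()             # -> [1, 2, 3]
#     xs.value()             # unchanged: [1, 2]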
class DictNode(CollectionNode): # noqa: F811
def keys(self) -> list[str]:
"""
Get the keys of a dictionary node.
Returns
-------
list[str]
List of keys in the dictionary node
"""
return self._info["keys"].copy()
def __iter__(self):
"""
Iterate over the keys of the dictionary node.
Yields
------
str
Each key of the dictionary.
"""
for k in self.keys():
yield k
def items(self):
"""
Iterate over key-value pairs of a dictionary node.
Returns
-------
Iterator[tuple[Node, Node]]
Iterator over (key, value) pairs
"""
if self.type != "dict":
raise Error(f"Cannot iterate items of type: {self.type}")
for k in self:
yield k, self[k]
def values(self) -> list["Node"]:
"""
Get the values of a dictionary node.
Returns
-------
list[Node]
List of values in the dictionary node
"""
return [self[k] for k in self]
def assoc(self, key, value, *, name=None, doc=None):
"""
For a dict node, associate a new value into the map
Returns
-------
Node
Node containing the new dict
"""
return make_node(self.dag, self.dag._dml.assoc(self, key, value, name=name, doc=doc))
def update(self, update):
"""
For a dict node, update like python dicts
Returns
-------
Node
Node containing the new dict
Notes
-----
Calls `assoc` iteratively for each k, v pair in `update`.
See Also
--------
assoc : The main implementation
"""
for k, v in update.items():
self = self.assoc(k, v)
return self
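# A hedged sketch of the dict-node helpers: `assoc` adds a single key,
# `update` folds `assoc` over a plain dict, and `items`/`values` resolve
# through __getitem__. The keys/values are made up for illustration.
#
#     d = dag._put({"a": 1})
#     d2 = d.assoc("b", 2)
#     d3 = d2.update({"c": 3, "d": 4})
#     sorted(d3.keys())                        # -> ["a", "b", "c", "d"]
#     {k: v.value() for k, v in d3.items()}    # -> {"a": 1, "b": 2, ...}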