Source code for dml_util.core.utils

"""Core utilities with no external dependencies.

This module contains utility functions that are used throughout the DaggerML utilities
package. These functions have minimal or no external dependencies on DaggerML itself,
making them suitable for use in environments where DaggerML is not available.

Functions
---------
tree_map
    Apply a function to elements of a nested data structure that satisfy a predicate.
dict_product
    Generate all combinations of dictionary values for parameter sweeps.
now
    Get current time.
_run_cli
    Run a command-line program and capture output.
if_read_file
    Read a file if it exists.
proc_exists
    Check if a process exists.
js_dump
    Dump data as JSON with consistent formatting.
compute_hash
    Compute hash of a file-like object.
batched
    Split an iterable into tuples of at most n items.
exactly_one
    Ensure exactly one of multiple parameters is provided.
"""

import errno
import hashlib
import json
import logging
import os
import subprocess
from itertools import islice, product
from time import time

logger = logging.getLogger(__name__)


def tree_map(predicate, fn, item):
    """Recursively transform the parts of a nested structure that match a predicate.

    ``fn`` is applied to ``item`` first when ``predicate(item)`` is truthy. The
    (possibly replaced) value is then descended into if it is a ``list`` or a
    ``dict``; every other type is returned unchanged.

    Parameters
    ----------
    predicate : callable
        Called with a node; a truthy result means ``fn`` is applied to that node.
    fn : callable
        Transformation applied to matching nodes.
    item : object
        The value to walk.

    Returns
    -------
    object
        A new list/dict with the transformation applied at every level, or the
        (possibly transformed) leaf value. Dict keys are kept as-is; only dict
        values are visited.
    """
    node = fn(item) if predicate(item) else item
    if isinstance(node, dict):
        return {key: tree_map(predicate, fn, value) for key, value in node.items()}
    if isinstance(node, list):
        return [tree_map(predicate, fn, child) for child in node]
    return node
def dict_product(d):
    """Yield every combination of values from a dictionary of lists.

    Useful for grid searches and parameter sweeps.

    Parameters
    ----------
    d : dict
        Maps parameter names to the list of candidate values for that
        parameter.

    Yields
    ------
    dict
        One combination of parameter values: the same keys as ``d``, each
        mapped to a single value drawn from the corresponding list.

    Examples
    --------
    >>> d = {'a': [1, 2], 'b': ['x', 'y']}
    >>> for combination in dict_product(d):
    ...     print(combination)
    {'a': 1, 'b': 'x'}
    {'a': 1, 'b': 'y'}
    {'a': 2, 'b': 'x'}
    {'a': 2, 'b': 'y'}
    """
    names = tuple(d.keys())
    yield from (dict(zip(names, values)) for values in product(*d.values()))
def now():
    """Return the current time as reported by ``time.time()``.

    Returns
    -------
    float
        Seconds since the epoch.
    """
    timestamp = time()
    return timestamp
def _run_cli(command, capture_output=True, check=True, **kw):
    """Run a command with ``subprocess.run`` and return its stripped stdout.

    The command and each non-empty stderr line are logged at DEBUG level.

    Parameters
    ----------
    command : str or sequence
        Passed to ``subprocess.run`` unchanged.
    capture_output : bool, optional
        Forwarded to ``subprocess.run``; also controls whether stdout/stderr
        are appended to the error message.
    check : bool, optional
        When true, a non-zero exit status raises ``RuntimeError``. When false,
        a non-zero exit status makes the function return ``None``.
    **kw
        Extra keyword arguments forwarded to ``subprocess.run``.

    Returns
    -------
    str or None
        Stripped stdout on success (empty string if nothing was captured), or
        ``None`` when the command failed and ``check`` is false.

    Raises
    ------
    RuntimeError
        If the command exits non-zero and ``check`` is true.
    """
    # NOTE: the local must stay named ``result`` -- the ``{result.returncode = }``
    # f-string below embeds the expression text in the error message.
    result = subprocess.run(command, capture_output=capture_output, text=True, check=False, **kw)

    logger.debug("command: %r", command)
    stderr_text = result.stderr or ""
    for stderr_line in stderr_text.splitlines():
        if not stderr_line:
            continue
        logger.debug("stderr: %r", stderr_line)
    logger.debug("end STDERR for command: %r", command)

    if result.returncode == 0:
        stdout_text = result.stdout or ""
        return stdout_text.strip()

    msg = f"_run_cli: {command}\n{result.returncode = }"
    if capture_output:
        msg += f"\n{result.stdout}\n\n{result.stderr}"
    if not check:
        return None
    raise RuntimeError(msg)
def if_read_file(path):
    """Return the text content of ``path``, or ``None`` if it does not exist.

    Parameters
    ----------
    path : str or os.PathLike
        File to read. It is opened in text mode with the default encoding.

    Returns
    -------
    str or None
        The file's content, or ``None`` when the file is missing -- including
        when it disappears between the existence check and the open.
    """
    # Keep the up-front check so paths that os.path.exists() reports as
    # missing still yield None exactly as before.
    if not os.path.exists(path):
        return None
    try:
        with open(path) as f:
            return f.read()
    except FileNotFoundError:
        # The file was removed after the existence check (e.g. by another
        # process); treat it the same as "never existed" instead of crashing.
        return None
def proc_exists(pid):
    """Check whether a process with the given pid is still alive (POSIX only).

    The process is probed with signal 0. If that succeeds and the process is a
    child of the caller, a non-blocking ``waitpid`` is used to detect -- and,
    as a side effect, reap -- a child that has already terminated.

    Parameters
    ----------
    pid : int
        Process id to check.

    Returns
    -------
    bool
        ``True`` if the process exists (including when it exists but we lack
        permission to signal it); ``False`` if there is no such process or it
        is a terminated child that was just reaped.

    Raises
    ------
    OSError
        Any unexpected error from ``os.kill`` or ``os.waitpid``.
    """
    try:
        # Signal 0 performs the existence/permission check without delivering
        # a signal.
        os.kill(pid, 0)
    except ProcessLookupError:
        return False  # No such process
    except PermissionError:
        return True  # Exists but we don't have permission

    try:
        # With WNOHANG, waitpid returns (0, 0) while the child is still
        # running and (pid, status) once it has terminated.
        reaped_pid, _status = os.waitpid(pid, os.WNOHANG)
    except ChildProcessError:
        # ECHILD: not our child, so we cannot wait on it; the successful
        # kill() probe above is the best information available.
        return True

    # Decide on the returned pid, not the status: a child that exited cleanly
    # has status 0 but is nevertheless gone (it was a zombie until now).
    return reaped_pid == 0
def js_dump(data, **kw):
    """Serialize ``data`` to compact JSON with sorted keys.

    Parameters
    ----------
    data : object
        Any value accepted by ``json.dumps``.
    **kw
        Extra keyword arguments forwarded to ``json.dumps``.

    Returns
    -------
    str
        JSON text with keys sorted and no whitespace after ``,`` or ``:``.
    """
    compact_separators = (",", ":")
    encoded = json.dumps(data, sort_keys=True, separators=compact_separators, **kw)
    return encoded
def compute_hash(obj, chunk_size=8192, hash_algorithm="sha256"):
    """Compute the hex digest of a file-like object's content.

    Data is read in chunks from the object's current position until ``read``
    returns an empty (falsy) chunk. Afterwards the object is rewound to
    offset 0.

    Parameters
    ----------
    obj : file-like
        Must provide ``read(size)`` and ``seek(offset)``.
    chunk_size : int, optional
        Number of bytes requested per ``read`` call.
    hash_algorithm : str, optional
        Algorithm name passed to ``hashlib.new``.

    Returns
    -------
    str
        Hexadecimal digest of the data that was read.
    """
    digest = hashlib.new(hash_algorithm)
    while True:
        block = obj.read(chunk_size)
        if not block:
            break
        digest.update(block)
    obj.seek(0)
    return digest.hexdigest()
def batched(iterable, n, *, strict=False):
    """Yield successive tuples of up to ``n`` items from ``iterable``.

    ``batched('ABCDEFG', 3)`` yields ``ABC``, ``DEF``, ``G``.

    Parameters
    ----------
    iterable : iterable
        Source of items.
    n : int
        Maximum batch length; must be at least 1.
    strict : bool, optional
        When true, a final batch shorter than ``n`` raises instead of being
        yielded.

    Yields
    ------
    tuple
        The next batch of items.

    Raises
    ------
    ValueError
        If ``n`` is less than 1, or if ``strict`` is true and the last batch
        is incomplete. Because this is a generator, the error surfaces when it
        is iterated, not when it is created.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    # Two-argument iter(): keep pulling batches until an empty tuple shows up.
    for batch in iter(lambda: tuple(islice(source, n)), ()):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch
def exactly_one(**kw):
    """Validate that exactly one keyword argument is not ``None``.

    Parameters
    ----------
    **kw
        The mutually exclusive options, passed by name.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If none of the arguments, or more than one of them, is not ``None``.
    """
    provided = [name for name in kw if kw[name] is not None]
    options = sorted(kw)
    if not provided:
        raise ValueError(f"must specify one of: {options}")
    if len(provided) > 1:
        raise ValueError(f"must specify only one of: {options} but {provided} are all not None")