Source code for floweaver.sankey_definition

from __future__ import annotations
from textwrap import dedent
from pprint import pformat
from collections import OrderedDict
from collections.abc import Iterable, Mapping
from typing import Any, TypeAlias, TypeVar

import attrs
from attrs import field

from . import sentinel
from .ordering import Ordering
from .utils import pairwise

from typing import Literal, TYPE_CHECKING

if TYPE_CHECKING:
    from .partition import Partition


# adapted from https://stackoverflow.com/a/47663099/1615465
def no_default_vals_in_repr(cls):
    """Class decorator that installs a terse ``__repr__`` on an attrs class.

    The generated repr lists only the attributes whose current value differs
    from the attribute's declared default, in declaration order.

    Defaults are read from ``cls.__attrs_attrs__`` once, when the decorator
    runs. Factory defaults are called at that point to obtain a comparison
    value; factories that take ``self`` are not supported.

    Returns the same class object, with ``__repr__`` replaced.
    """
    # adapted from https://stackoverflow.com/a/47663099/1615465
    defaults = OrderedDict()
    for attribute in cls.__attrs_attrs__:
        default = attribute.default
        if hasattr(default, "factory"):
            assert not default.takes_self, "not implemented"
            default = default.factory()
        defaults[attribute.name] = default

    def repr_(self):
        real_cls = self.__class__
        qualname = getattr(real_cls, "__qualname__", None)
        if qualname is None:
            class_name = real_cls.__name__
        else:
            # Drop any "outer.<locals>." prefix of classes defined in functions.
            class_name = qualname.rsplit(">.", 1)[-1]

        shown = []
        for name, default in defaults.items():
            value = getattr(self, name)
            if value != default:
                shown.append(f"{name}={value!r}")
        return f"{class_name}({', '.join(shown)})"

    cls.__repr__ = repr_
    return cls


# SankeyDefinition


def _convert_bundles_to_dict(bundles):
    if not isinstance(bundles, dict):
        bundles = {k: v for k, v in enumerate(bundles)}
    return bundles


def _convert_ordering(ordering):
    """Return `ordering` as an ``Ordering`` instance.

    An existing ``Ordering`` is passed through untouched; any other value is
    handed to the ``Ordering`` constructor.
    """
    if not isinstance(ordering, Ordering):
        ordering = Ordering(ordering)
    return ordering


def _validate_bundles(instance, attribute, bundles):
    # Check bundles
    for k, b in bundles.items():
        if not b.from_elsewhere:
            if b.source not in instance.nodes:
                raise ValueError('Unknown source "{}" in bundle {}'.format(b.source, k))
            if not isinstance(instance.nodes[b.source], ProcessGroup):
                raise ValueError("Source of bundle {} is not a process group".format(k))
        if not b.to_elsewhere:
            if b.target not in instance.nodes:
                raise ValueError('Unknown target "{}" in bundle {}'.format(b.target, k))
            if not isinstance(instance.nodes[b.target], ProcessGroup):
                raise ValueError("Target of bundle {} is not a process group".format(k))
        for u in b.waypoints:
            if u not in instance.nodes:
                raise ValueError('Unknown waypoint "{}" in bundle {}'.format(u, k))
            if not isinstance(instance.nodes[u], Waypoint):
                raise ValueError(
                    'Waypoint "{}" of bundle {} is not a waypoint'.format(u, k)
                )


def _validate_ordering(instance, attribute, ordering):
    for layer_bands in ordering.layers:
        for band_nodes in layer_bands:
            for u in band_nodes:
                if u not in instance.nodes:
                    raise ValueError('Unknown node "{}" in ordering'.format(u))


# Type of the keys of ``SankeyDefinition.bundles``.
#
# Historically a mix of ints (automatic keys produced when a list of bundles
# is converted to a dict) and strings (implicit elsewhere bundles, and
# user-provided dictionaries) has been used. Strictly this could probably be
# more flexible and allow other hashable, sortable, serialisable types.
BundleID: TypeAlias = str | int
# Type variable with BundleID as its upper bound.
# NOTE(review): not used in this part of the file; presumably for generic
# helpers elsewhere that need to preserve the concrete key type -- confirm.
BundleID_T = TypeVar("BundleID_T", bound=BundleID)


@attrs.define(slots=True, frozen=True)
class SankeyDefinition(object):
    """Immutable description of a Sankey diagram: nodes, bundles and ordering.

    Attributes
    ----------
    nodes : dict
        Maps each node id to its :class:`ProcessGroup` or :class:`Waypoint`.
    bundles : dict
        Maps bundle id to :class:`Bundle`. An iterable of bundles (e.g. a
        list) is also accepted and is keyed by position. On construction every
        bundle's source, target and waypoints are checked against `nodes`.
    ordering : Ordering
        Arrangement of node ids; a value that is not already an
        :class:`Ordering` is passed to the ``Ordering`` constructor. Every
        node id it mentions must be present in `nodes`.
    flow_selection : string, optional
        Stored as given; not validated by this class.
    flow_partition : Partition, optional
        Stored as given; not validated by this class.
    time_partition : Partition, optional
        Stored as given; not validated by this class.

    Raises
    ------
    ValueError
        If a bundle or the ordering refers to an unknown node, or a bundle
        endpoint / waypoint refers to a node of the wrong kind.
    """

    nodes: dict[str, ProcessGroup | Waypoint]
    bundles: dict[BundleID, Bundle] = field(
        converter=_convert_bundles_to_dict, validator=_validate_bundles
    )
    ordering: Ordering = field(
        converter=_convert_ordering, validator=_validate_ordering
    )
    flow_selection: str | None = None
    flow_partition: Partition | None = None
    time_partition: Partition | None = None

    # Define this explicitly to help type checkers
    def __init__(
        self,
        nodes: dict[str, ProcessGroup | Waypoint],
        bundles: Iterable[Bundle] | Mapping[Any, Bundle],
        ordering: Ordering | Iterable,
        flow_selection: str | None = None,
        flow_partition: Partition | None = None,
        time_partition: Partition | None = None,
    ):
        self.__attrs_init__(  # type: ignore
            nodes,
            _convert_bundles_to_dict(bundles),
            _convert_ordering(ordering),
            flow_selection,
            flow_partition,
            time_partition,
        )

    def copy(self):
        """Return a new definition with shallow copies of `nodes` and `bundles`.

        The ordering, flow selection and partitions are shared with the
        original rather than copied.
        """
        return self.__class__(
            self.nodes.copy(),
            self.bundles.copy(),
            self.ordering,
            self.flow_selection,
            self.flow_partition,
            self.time_partition,
        )

    def to_code(self):
        """Return Python source text that rebuilds this definition as ``sdd``.

        The generated script imports the floweaver names it needs, defines
        ``nodes``, ``ordering`` and ``bundles`` literals (bundle keys are not
        emitted; bundles are written as a list), defines ``flow_selection``,
        ``flow_partition`` and ``time_partition`` only when they are set, and
        finally calls ``SankeyDefinition(...)``.
        """
        nodes = "\n".join(
            "    %s: %s," % (repr(k), pformat(v)) for k, v in self.nodes.items()
        )

        ordering = "\n".join(
            "    %s," % repr([list(x) for x in layer])
            for layer in self.ordering.layers
            # convert to list just because it looks neater
        )

        bundles = "\n".join(
            "    %s," % pformat(bundle) for bundle in self.bundles.values()
        )

        # Each optional setting becomes an assignment statement in the
        # generated script, or the empty string when it is unset.
        if self.flow_selection is not None:
            flow_selection = "flow_selection = %s\n\n" % pformat(self.flow_selection)
        else:
            flow_selection = ""

        if self.flow_partition is not None:
            flow_partition = "flow_partition = %s\n\n" % pformat(self.flow_partition)
        else:
            flow_partition = ""

        if self.time_partition is not None:
            time_partition = "time_partition = %s\n\n" % pformat(self.time_partition)
        else:
            time_partition = ""

        code = dedent("""
        from floweaver import (
            ProcessGroup,
            Waypoint,
            Partition,
            Group,
            Elsewhere,
            Bundle,
            SankeyDefinition,
        )

        nodes = {
        %s
        }

        ordering = [
        %s
        ]

        bundles = [
        %s
        ]

        %s%s%ssdd = SankeyDefinition(nodes, bundles, ordering%s%s%s)
        """) % (
            nodes,
            ordering,
            bundles,
            flow_selection,
            flow_partition,
            time_partition,
            (", flow_selection=flow_selection" if flow_selection else ""),
            (", flow_partition=flow_partition" if flow_partition else ""),
            # The keyword value must match the variable assigned above; it
            # was previously misspelled, making the generated code fail with
            # NameError.
            (", time_partition=time_partition" if time_partition else ""),
        )

        return code
# ProcessGroup def _validate_direction(instance, attribute, value): if value not in "LR": raise ValueError("direction must be L or R")
@no_default_vals_in_repr
@attrs.define(slots=True)
class ProcessGroup(object):
    """A group of processes from the underlying dataset.

    Which processes belong to the group is given by `selection`. Unless a
    `partition` is supplied, the whole group is drawn as a single node in the
    diagram.

    Attributes
    ----------
    selection : list or string
        A list of strings is taken as process ids. A single string is taken
        as a Pandas query string run against the process table. Any other
        iterable passed to the constructor is converted to a list.
    partition : Partition, optional
        Defines how to split the ProcessGroup into subgroups.
    direction : 'R' or 'L'
        Direction of flow, default 'R' (left-to-right).
    title : string, optional
        Label for the ProcessGroup. If not set, the ProcessGroup id will be
        used.
    """

    selection: list[str] | str
    partition: Partition | None = None
    direction: Literal["R", "L"] = field(default="R", validator=_validate_direction)
    title: str | None = field(
        validator=attrs.validators.optional(attrs.validators.instance_of(str)),
        default=None,
    )

    # Define this explicitly to help type checkers
    def __init__(
        self,
        selection: Iterable[str] | str,
        partition: Partition | None = None,
        direction: Literal["R", "L"] = "R",
        title: str | None = None,
    ):
        # A lone string is a query and is kept as-is; everything else is
        # materialised as a list of ids.
        normalised = selection if isinstance(selection, str) else list(selection)
        self.__attrs_init__(  # type: ignore
            normalised, partition, direction, title
        )
# Waypoint
@no_default_vals_in_repr
@attrs.define(slots=True)
class Waypoint(object):
    """A control point along a :class:`Bundle` of flows.

    Waypoints serve two purposes: they steer the route a :class:`Bundle`
    takes through the diagram, and, when given a `partition`, they split the
    flows passing through according to some attributes.

    Attributes
    ----------
    partition : Partition, optional
        Defines how to split the Waypoint into subgroups.
    direction : 'R' or 'L'
        Direction of flow, default 'R' (left-to-right).
    title : string, optional
        Label for the Waypoint. If not set, the Waypoint id will be used.
    """

    partition: Partition | None = None
    direction: Literal["R", "L"] = field(default="R", validator=_validate_direction)
    title: str | None = field(
        validator=attrs.validators.optional(attrs.validators.instance_of(str)),
        default=None,
    )
# Bundle

# Marker object used as a Bundle's source or target; Bundle compares against
# it by identity.
Elsewhere = sentinel.create("Elsewhere")


def _validate_flow_selection(instance, attribute, value):
    """attrs validator: a bundle from a node to itself needs a flow selection.

    Raises ``ValueError`` when ``instance.source`` equals ``instance.target``
    and `value` is empty or ``None``.
    """
    same_endpoints = instance.source == instance.target
    if same_endpoints and not value:
        raise ValueError(
            "flow_selection is required for bundle with same source and target"
        )
@no_default_vals_in_repr
@attrs.define(frozen=True, slots=True)
class Bundle(object):
    """A Bundle represents a set of flows between two :class:`ProcessGroup` s.

    Attributes
    ----------
    source : string or Elsewhere
        The id of the :class:`ProcessGroup` at the start of the Bundle.
    target : string or Elsewhere
        The id of the :class:`ProcessGroup` at the end of the Bundle.
    waypoints : tuple of strings
        Optional ids of :class:`Waypoint` s the Bundle should pass through.
        Any iterable is accepted and stored as a tuple.
    flow_selection : string, optional
        Query string to filter the flows included in this Bundle. Required
        when `source` and `target` are equal.
    flow_partition : Partition, optional
        Defines how to split the flows in the Bundle into sub-flows. Often you
        want the same Partition for all the Bundles in the diagram, see
        :attr:`SankeyDefinition.flow_partition`.
    default_partition : Partition, optional
        Defines the Partition applied to any Waypoints automatically added to
        route the Bundle across layers of the diagram.

    Raises
    ------
    ValueError
        If `source` equals `target` and no `flow_selection` is given.
    """

    source: str | Elsewhere
    target: str | Elsewhere
    waypoints: tuple[str, ...] = field(default=attrs.Factory(tuple), converter=tuple)
    flow_selection: list[str] | str | None = field(
        default=None, validator=_validate_flow_selection
    )
    flow_partition: Partition | None = None
    default_partition: Partition | None = None

    # Define this explicitly to help type checkers
    def __init__(
        self,
        source: str | Elsewhere,
        target: str | Elsewhere,
        waypoints: Iterable[str] = (),
        flow_selection: list[str] | str | None = None,
        flow_partition: Partition | None = None,
        default_partition: Partition | None = None,
    ):
        self.__attrs_init__(  # type: ignore
            source,
            target,
            tuple(waypoints),
            flow_selection,
            flow_partition,
            default_partition,
        )

    @property
    def to_elsewhere(self) -> bool:
        """True if the target of the Bundle is Elsewhere (outside the system
        boundary)."""
        return self.target is Elsewhere

    @property
    def from_elsewhere(self) -> bool:
        """True if the source of the Bundle is Elsewhere (outside the system
        boundary)."""
        return self.source is Elsewhere

    @property
    def segments(self) -> tuple[tuple[str | Elsewhere, str | Elsewhere], ...]:
        """Tuple of pairwise node ids making up the bundle's segments.

        e.g. ((source, waypoint1), (waypoint1, waypoint2), ...
        (waypointN, target))
        """
        # The result is a tuple of (from, to) pairs, not a tuple of strings.
        nodes = (self.source,) + self.waypoints + (self.target,)
        return tuple(pairwise(nodes))