Source code for floweaver.compiler.execute

"""Execute a WeaverSpec against flow data to produce SankeyData.

This module implements the execute_weave() function that takes a compiled
WeaverSpec and executes it against actual flow data (from a Dataset) to
produce SankeyData results.

The spec contains pre-expanded nodes and edges with explicit include/exclude
filters. The executor simply:
1. Filters flows according to each edge's filters
2. Aggregates measures for matching flows
3. Applies display styling (colors, widths)
4. Builds the SankeyData output
"""

from ..sankey_data import SankeyData, SankeyNode, SankeyLink
from ..ordering import Ordering
from .combined_router import route_flows
from .spec import (
    CategoricalColorSpec,
    QuantitativeColorSpec,
)


def execute_weave(spec, dataset):
    """Run a compiled WeaverSpec over flow data and return the SankeyData.

    Parameters
    ----------
    spec : WeaverSpec
        The compiled spec with routing tree.
    dataset : Dataset or DataFrame
        The flow data. An object exposing a ``_table`` attribute has that
        attribute used as the flows table; any other object is used as the
        flows table directly.

    Returns
    -------
    SankeyData
        The resulting Sankey diagram data with nodes and links.
    """
    # Unwrap the flows table when the caller passed a wrapper with `_table`;
    # the original `dataset` is still forwarded so it can be attached to the
    # result.
    flows = getattr(dataset, "_table", dataset)
    return _execute_with_routing_tree(spec, flows, dataset)
def _execute_with_routing_tree(spec, flows, dataset):
    """Execute a spec by routing flows through its routing tree.

    Parameters
    ----------
    spec : WeaverSpec
        The compiled spec; its ``routing_tree``, ``edges``, ``measures``,
        ``display``, ``bundles``, ``nodes``, ``groups`` and ``ordering``
        attributes are read here.
    flows : DataFrame
        The flows table, indexed positionally via ``.iloc``.
    dataset : Dataset or DataFrame
        The object originally passed by the caller. It is attached to the
        result only when it exposes a ``_table`` attribute.

    Returns
    -------
    SankeyData
        Nodes, links, groups and ordering for the diagram.
    """
    # Route all flows to edges: edge index -> positional flow indices.
    edge_flow_map = route_flows(flows, spec.routing_tree)

    # Aggregate flows for each edge.
    links = []
    from_elsewhere = {}  # node_id -> list of links
    to_elsewhere = {}  # node_id -> list of links

    for edge_index, flow_indices in edge_flow_map.items():
        edge = spec.edges[edge_index]
        matching = flows.iloc[flow_indices]
        if len(matching) == 0:
            # Edges that matched no flows produce no link.
            continue

        data = _aggregate(matching, spec.measures)
        link = SankeyLink(
            source=edge.source,
            target=edge.target,
            type=edge.type,
            time=edge.time,
            link_width=data.get(spec.display.link_width, 0.0),
            data=data,
            title=_compute_title(edge, spec.bundles),
            color=_apply_color(edge, data, spec.display),
            opacity=1.0,
            original_flows=flow_indices,
        )

        # A missing endpoint marks an "elsewhere" link, which is attached to
        # the node at the other end instead of the main link list.
        if edge.source is None:
            from_elsewhere.setdefault(edge.target, []).append(link)
        elif edge.target is None:
            to_elsewhere.setdefault(edge.source, []).append(link)
        else:
            links.append(link)

    # Nodes that appear in regular edges (degree > 0).
    nodes_in_regular_edges = set()
    for link in links:
        nodes_in_regular_edges.add(link.source)
        nodes_in_regular_edges.add(link.target)

    # All used nodes, including those with only elsewhere edges.
    used = set(nodes_in_regular_edges)
    used.update(from_elsewhere)
    used.update(to_elsewhere)

    nodes = [
        SankeyNode(
            id=node_id,
            title=node_spec.title,
            direction=node_spec.direction,
            hidden=node_spec.hidden,
            style=node_spec.style,
            from_elsewhere_links=from_elsewhere.get(node_id, []),
            to_elsewhere_links=to_elsewhere.get(node_id, []),
        )
        for node_id, node_spec in spec.nodes.items()
        if node_id in used
    ]

    # Pass nodes_in_regular_edges to filter out nodes with only elsewhere
    # edges (matching old behavior where degree-0 nodes are filtered from
    # groups).
    groups = _build_groups(spec.groups, spec.nodes, nodes_in_regular_edges)
    ordering = _filter_ordering(spec.ordering, used)

    return SankeyData(
        nodes=nodes,
        links=links,
        groups=groups,
        ordering=ordering,
        dataset=dataset if hasattr(dataset, "_table") else None,
    )


def _aggregate(df, measures):
    """Aggregate flow data according to measure specifications.

    Parameters
    ----------
    df : DataFrame
        The matching flows.
    measures : list of MeasureSpec
        Measure specifications; each provides ``column`` and ``aggregation``.

    Returns
    -------
    dict
        Aggregated values keyed by column name. A column absent from ``df``
        yields ``0.0``.

    Raises
    ------
    ValueError
        If a measure whose column is present uses an aggregation other than
        ``"sum"`` or ``"mean"``.
    """
    result = {}
    for m in measures:
        col = m.column
        if col not in df.columns:
            # Missing columns contribute a zero instead of failing.
            result[col] = 0.0
        elif m.aggregation == "sum":
            result[col] = df[col].sum()
        elif m.aggregation == "mean":
            result[col] = df[col].mean()
        else:
            raise ValueError(f"Unknown aggregation: {m.aggregation}")
    return result


def _apply_color(edge, data, display_spec):
    """Compute the color for a link based on the display spec.

    Parameters
    ----------
    edge : EdgeSpec
        The edge specification.
    data : dict
        Aggregated measure values.
    display_spec : DisplaySpec
        Display configuration; its ``link_color`` selects the strategy.

    Returns
    -------
    str
        For a categorical spec, the lookup entry for the stringified
        attribute value (or the spec's default). For a quantitative spec, a
        hex color interpolated from the palette. ``"#cccccc"`` for any other
        kind of color spec.
    """
    color_spec = display_spec.link_color

    if isinstance(color_spec, CategoricalColorSpec):
        attr = color_spec.attribute
        if attr in ("type", "source", "target", "time"):
            value = getattr(edge, attr)
        else:
            # Anything else is assumed to name a measure.
            value = data.get(attr)
        return color_spec.lookup.get(str(value), color_spec.default)

    if isinstance(color_spec, QuantitativeColorSpec):
        value = data.get(color_spec.attribute, 0.0)
        if color_spec.intensity is not None:
            intensity_value = data.get(color_spec.intensity, 1.0)
            # A zero intensity leaves the value un-normalised rather than
            # dividing by zero.
            if intensity_value != 0:
                value = value / intensity_value

        domain = color_spec.domain
        if domain[1] != domain[0]:
            normed = (value - domain[0]) / (domain[1] - domain[0])
        else:
            # Degenerate domain: use the middle of the palette.
            normed = 0.5

        # Clamp to [0, 1]
        normed = max(0.0, min(1.0, normed))
        return _interpolate_color(color_spec.palette, normed)

    return "#cccccc"


def _interpolate_color(palette, t):
    """Interpolate a color from a palette.

    Parameters
    ----------
    palette : list of str
        List of hex color strings.
    t : float
        Position along the palette. Values outside [0, 1] are clamped to the
        nearest end.

    Returns
    -------
    str
        Hex color string; ``"#cccccc"`` when the palette is empty.
    """
    if not palette:
        return "#cccccc"

    # Guard the documented precondition locally: without the clamp, t > 1
    # indexes past the end of the palette and t < 0 produces negative
    # channel values that format as a malformed hex string.
    t = max(0.0, min(1.0, t))

    # Map t to a fractional palette index.
    last = len(palette) - 1
    idx = t * last
    lo = int(idx)
    hi = min(lo + 1, last)
    if lo == hi:
        return palette[lo]

    # Linear interpolation between adjacent colors (channels truncated).
    frac = idx - lo
    c_lo = _hex_to_rgb(palette[lo])
    c_hi = _hex_to_rgb(palette[hi])
    r, g, b = (int(a + frac * (z - a)) for a, z in zip(c_lo, c_hi))
    return f"#{r:02x}{g:02x}{b:02x}"


def _hex_to_rgb(hex_color):
    """Convert a hex color to an ``(r, g, b)`` tuple of ints.

    Accepts an optional leading ``#``. Shorthand forms of 3 or 4 digits are
    expanded by doubling each digit; for longer strings only the first six
    digits are read.

    Raises
    ------
    ValueError
        If the digits read are not valid hexadecimal.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) in (3, 4):
        # "#abc" -> "aabbcc"; previously this raised ValueError.
        hex_color = "".join(ch * 2 for ch in hex_color)
    return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))


def _compute_title(edge, bundle_specs):
    """Compute link title from bundle provenance.

    Parameters
    ----------
    edge : EdgeSpec
        The edge specification.
    bundle_specs : list of BundleSpec
        Bundle specifications for provenance (currently unused).

    Returns
    -------
    str
        The edge's ``type``.
    """
    # For now, just use the type as title.
    # In future could incorporate bundle info.
    return edge.type


def _build_groups(group_specs, node_specs, used_nodes):
    """Build groups in the format expected by SankeyData.

    Parameters
    ----------
    group_specs : list of GroupSpec
        Group specifications from the WeaverSpec.
    node_specs : dict
        Mapping of node IDs to NodeSpec objects.
    used_nodes : set
        Set of node IDs that are actually used.

    Returns
    -------
    list of dict
        One dict per retained group with keys ``"id"``, ``"title"``,
        ``"type"`` and ``"nodes"``.
    """
    groups = []
    for g in group_specs:
        # Keep only the group's nodes that are actually used.
        nodes_in_group = [n for n in g.nodes if n in used_nodes]
        if not nodes_in_group:
            # Skip empty groups
            continue

        # The group type is taken from its first node; all nodes in a group
        # are expected to share a type since they come from the same
        # ProcessGroup/Waypoint.
        first_node = node_specs[nodes_in_group[0]]

        # Only include groups with more than one node, or where the group
        # title is different from the node title
        # (logic from results_graph.py:99). An empty or missing group title
        # is replaced by the group id for this comparison.
        if len(nodes_in_group) == 1:
            group_title = g.title if g.title else g.id
            if first_node.title == group_title:
                continue

        groups.append(
            {
                "id": g.id,
                "title": g.title if g.title is not None else "",
                "type": first_node.type,
                "nodes": nodes_in_group,
            }
        )
    return groups


def _filter_ordering(ordering, used_nodes):
    """Filter ordering to only include used nodes.

    Parameters
    ----------
    ordering : list of list of list of str
        The ordering from the spec (layers of bands of node IDs).
    used_nodes : set
        Set of node IDs that are actually used.

    Returns
    -------
    Ordering
        Filtered ordering. Layers left with no nodes are dropped; empty bands
        inside a retained layer are kept.
    """
    filtered = []
    for layer in ordering:
        filtered_layer = [
            [n for n in band if n in used_nodes] for band in layer
        ]
        if any(filtered_layer):
            filtered.append(filtered_layer)
    return Ordering(filtered)