# Source code for retentioneering.preprocessing_graph.preprocessing_graph

from __future__ import annotations

import json
from typing import Any, List, Literal, Optional, TypedDict, cast

import networkx
from IPython.core.display import HTML, DisplayHandle, display
from pydantic import ValidationError

from retentioneering.backend import JupyterServer, ServerManager
from retentioneering.backend.callback import list_dataprocessor, list_dataprocessor_mock
from retentioneering.eventstream.types import EventstreamType
from retentioneering.exceptions.server import ServerErrorWithResponse
from retentioneering.exceptions.widget import WidgetParseError
from retentioneering.preprocessing_graph.nodes import (
    EventsNode,
    MergeNode,
    Node,
    SourceNode,
    build_node,
)
from retentioneering.templates import PreprocessingGraphRenderer


class _NodeDataRequired(TypedDict):
    # Keys every node description carries.
    name: str  # node kind, e.g. "SourceNode" or "EventsNode"
    pk: str  # unique node identifier


class NodeData(_NodeDataRequired, total=False):
    """
    Serialized description of a single preprocessing graph node.

    ``name`` and ``pk`` are required. ``description`` and ``processor`` may be
    omitted: ``PreprocessingGraph._set_graph`` reads both with ``.get(...)``,
    and its payload example carries no ``processor`` for the source node.
    """

    description: Optional[str]
    # Data processor settings; the payload example uses {"name": ..., "values": {...}}.
    processor: Optional[dict]


class NodeLink(TypedDict):
    """Directed edge of the serialized graph, expressed through node ``pk`` values."""

    source: str  # pk of the parent node
    target: str  # pk of the child node


class Payload(TypedDict):
    """
    Serialized preprocessing graph.

    Same shape as the dict built by ``PreprocessingGraph.export``;
    ``PreprocessingGraph._set_graph`` reads the ``nodes`` and ``links`` keys.
    """

    directed: bool
    nodes: list[NodeData]
    links: list[NodeLink]


class CombineHandlerPayload(TypedDict):
    """Payload of the ``combine`` server action."""

    node_pk: str  # pk of the node up to which the graph is calculated


class FieldErrorDesc(TypedDict):
    """Error attached to a single field of a node that failed to build."""

    field: str  # name of the offending field
    msg: str  # error text for that field


class CreateNodeErrorDesc(TypedDict):
    """Description of a failure to build one node from a graph payload."""

    type: Literal["node_error"]
    node_pk: str  # pk of the node that could not be created
    msg: Optional[str]  # overall error message
    fields_errors: List[FieldErrorDesc]  # per-field details; may be empty


class PreprocessingGraph:
    """
    Collection of methods for preprocessing graph construction and calculation.

    Parameters
    ----------
    source_stream : EventstreamType
        Source eventstream.

    Notes
    -----
    See :doc:`Preprocessing user guide</user_guides/preprocessing>` for the details.
    """

    # Node wrapping the source eventstream; created in ``__init__`` and
    # replaced by ``_set_graph`` when the payload contains a ``SourceNode``.
    root: SourceNode
    # Result of the last ``combine`` action triggered through the server
    # (``None`` until ``_combine_handler`` runs).
    combine_result: EventstreamType | None
    # Underlying directed graph; edges are added as (parent, child).
    _ngraph: networkx.DiGraph
    # Both are created lazily on the first ``display`` call.
    __server_manager: ServerManager | None = None
    __server: JupyterServer | None = None

    def __init__(self, source_stream: EventstreamType) -> None:
        self.root = SourceNode(source=source_stream)
        self.combine_result = None
        self._ngraph = networkx.DiGraph()
        self._ngraph.add_node(self.root)

    def add_node(self, node: Node, parents: List[Node]) -> None:
        """
        Add node to ``PreprocessingGraph`` instance.

        Parameters
        ----------
        node : Node
            An instance of either ``EventsNode`` or ``MergeNode``.
        parents : list of Nodes

            - If ``node`` is ``EventsNode`` - only 1 parent must be defined.
            - If ``node`` is ``MergeNode`` - at least 2 parents have to be defined.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``node`` is already in the graph, if any of ``parents`` is not
            in the graph, if ``node.events`` is set and its schema differs from
            the source eventstream schema, or if a non-merge node is given
            more than one parent.

        See Also
        --------
        PreprocessingGraph.combine : Start PreprocessingGraph recalculation.
        .EventsNode : Regular nodes of a preprocessing graph.
        .MergeNode : Merge nodes of a preprocessing graph.
        """
        self.__validate_already_exists(node)
        self.__validate_not_found(parents)

        # The schema check returns a bool; its result must be acted upon,
        # otherwise the validation has no effect.
        if node.events is not None and not self.__validate_schema(node.events):
            raise ValueError("node schema does not match the source eventstream schema!")

        if not isinstance(node, MergeNode) and len(parents) > 1:
            raise ValueError("multiple parents are only allowed for merge nodes!")

        self._ngraph.add_node(node)
        for parent in parents:
            self._ngraph.add_edge(parent, node)

    def combine(self, node: Node) -> EventstreamType:
        """
        Run calculations from the ``SourceNode`` up to the specified ``node``.

        Parameters
        ----------
        node : Node
            Instance of either ``SourceNode``, ``EventsNode`` or ``MergeNode``.

        Returns
        -------
        EventstreamType
            ``Eventstream`` with all changes applied by data processors.

        Raises
        ------
        ValueError
            If ``node`` is not in the graph or the graph structure above it
            is invalid (orphan node, events node with several parents).
        """
        self.__validate_not_found([node])

        if isinstance(node, SourceNode):
            # Hand out a copy so that downstream nodes do not modify the source events.
            return node.events.copy()
        if isinstance(node, EventsNode):
            return self._combine_events_node(node)
        return self._combine_merge_node(node)

    def _combine_events_node(self, node: EventsNode) -> EventstreamType:
        """Combine the single parent of ``node`` and join the processor output into it."""
        parent = self._get_events_node_parent(node)
        parent_events = self.combine(parent)
        processed_events = node.processor.apply(parent_events)
        parent_events._join_eventstream(processed_events)
        return parent_events

    def _combine_merge_node(self, node: MergeNode) -> EventstreamType:
        """Combine every parent of ``node`` and append the results into one eventstream."""
        # ``_get_merge_node_parents`` guarantees at least one parent.
        first_parent, *other_parents = self._get_merge_node_parents(node)
        merged = self.combine(first_parent)
        for parent_node in other_parents:
            merged.append_eventstream(self.combine(parent_node))
        node.events = merged
        return merged

    def get_parents(self, node: Node) -> List[Node]:
        """
        Show parents of the specified ``node``.

        Parameters
        ----------
        node : Node
            Instance of one of the classes SourceNode, EventsNode or MergeNode.

        Returns
        -------
        list of Nodes

        Raises
        ------
        ValueError
            If ``node`` is not in the graph.
        """
        self.__validate_not_found([node])
        return list(self._ngraph.predecessors(node))

    def _get_merge_node_parents(self, node: MergeNode) -> List[Node]:
        """Return the parents of a merge node; raise ``ValueError`` if it has none."""
        parents = self.get_parents(node)
        if not parents:
            raise ValueError("orphan merge node!")
        return parents

    def _get_events_node_parent(self, node: EventsNode) -> Node:
        """Return the only parent of an events node; raise ``ValueError`` otherwise."""
        parents = self.get_parents(node)
        if not parents:
            # Report a graph problem explicitly instead of an IndexError below.
            raise ValueError("orphan events node!")
        if len(parents) > 1:
            raise ValueError("invalid graph: events node has more than 1 parent")
        return parents[0]

    def __validate_schema(self, eventstream: EventstreamType) -> bool:
        """Tell whether ``eventstream`` has the same schema as the root eventstream."""
        return self.root.events.schema.is_equal(eventstream.schema)

    def __validate_already_exists(self, node: Node) -> None:
        """Raise ``ValueError`` if ``node`` is already part of the graph."""
        if node in self._ngraph.nodes:
            raise ValueError("node already exists!")

    def __validate_not_found(self, nodes: List[Node]) -> None:
        """Raise ``ValueError`` if any of ``nodes`` is missing from the graph."""
        for node in nodes:
            if node not in self._ngraph.nodes:
                raise ValueError("node not found!")

    def display(self, width: int = 960, height: int = 600) -> DisplayHandle:
        """
        Show constructed ``PreprocessingGraph``.

        Parameters
        ----------
        width : int, default 960
            Width of plot in pixels.
        height : int, default 600
            Height of plot in pixels.

        Returns
        -------
        Rendered preprocessing graph.
        """
        if not self.__server_manager:
            self.__server_manager = ServerManager()

        if not self.__server:
            # Handlers are registered once, together with the server they belong to.
            self.__server = self.__server_manager.create_server()
            self.__server.register_action("list-dataprocessor-mock", list_dataprocessor_mock)
            self.__server.register_action("list-dataprocessor", list_dataprocessor)
            self.__server.register_action("set-graph", self._set_graph_handler)
            self.__server.register_action("get-graph", self.export)
            self.__server.register_action("combine", self._combine_handler)

        render = PreprocessingGraphRenderer()
        return display(
            HTML(
                render.show(
                    server_id=self.__server.pk,
                    env=self.__server_manager.check_env(),
                    width=width,
                    height=height,
                )
            )
        )

    def export(self, payload: dict[str, Any]) -> dict:
        """
        Show ``PreprocessingGraph`` as a dict.

        Parameters
        ----------
        payload : dict
            Not used by the method; it is accepted because ``export`` is also
            registered as the ``get-graph`` server action.

        Returns
        -------
        dict
            Dict with the keys ``directed``, ``nodes`` (exported nodes) and
            ``links`` (``source``/``target`` node ``pk`` pairs).
        """
        graph = self._ngraph
        return {
            "directed": graph.is_directed(),
            "nodes": [node.export() for node in graph],
            "links": [{"source": parent.pk, "target": child.pk} for parent, child in graph.edges()],
        }

    def _export_to_json(self) -> str:
        """Serialize the exported graph to a JSON string."""
        return json.dumps(self.export(payload={}))

    def _combine_handler(self, payload: CombineHandlerPayload) -> None:
        """Server action: combine up to the node with ``payload["node_pk"]`` and store the result."""
        node = self._find_node(payload["node_pk"])
        if node is None:
            raise ServerErrorWithResponse(message="node not found!", type="unexpected_error")
        self.combine_result = self.combine(node)

    def _set_graph_handler(self, payload: Payload) -> dict:
        """
        Server action: rebuild the graph from ``payload`` and return its export.

        On any failure the previous graph and root are restored. Errors other
        than ``ServerErrorWithResponse`` are wrapped into one with the type
        ``unexpected_error``.
        """
        previous_graph = self._ngraph
        previous_root = self.root

        def restore_graph() -> None:
            self._ngraph = previous_graph
            self.root = previous_root

        try:
            self._set_graph(payload=payload)
            return self.export({})
        except ServerErrorWithResponse:
            restore_graph()
            raise
        except Exception as err:
            restore_graph()
            # Keep the original exception as the cause for easier debugging.
            raise ServerErrorWithResponse(message=str(err), type="unexpected_error") from err

    def _set_graph(self, payload: Payload) -> None:
        """
        Replace the current graph with the one described by ``payload``.

        Payload example:

        {
            "nodes": [
                {
                    "name": "SourceNode",
                    "pk": "0dc3b706-e6cc-401e-96f7-6a45d3947d5c"
                },
                {
                    "name": "EventsNode",
                    "pk": "07921cb0-60b8-45af-928d-272d1b622b25",
                    "processor": {
                        "name": "SimpleGroup",
                        "values": {"event_name": "add_to_cart", "event_type": "group_alias"},
                    },
                },
                {
                    "name": "EventsNode",
                    "pk": "114251ae-0f03-45e6-a163-af51bb02dfd5",
                    "processor": {
                        "name": "SimpleGroup",
                        "values": {"event_name": "logout", "event_type": "group_alias"},
                    },
                },
            ],
            "links": [
                {
                    "source": "0dc3b706-e6cc-401e-96f7-6a45d3947d5c",
                    "target": "07921cb0-60b8-45af-928d-272d1b622b25"
                },
                {
                    "source": "07921cb0-60b8-45af-928d-272d1b622b25",
                    "target": "114251ae-0f03-45e6-a163-af51bb02dfd5"
                }
            ]
        }

        Raises
        ------
        ServerErrorWithResponse
            With type ``create_nodes_error`` if at least one node cannot be
            built (all node errors are collected first), or with type
            ``create_link_error`` if a link refers to an unknown node.
        """
        errors: List[CreateNodeErrorDesc] = []
        nodes: List[Node] = []

        # create nodes & validate params
        for node in payload["nodes"]:
            node_pk = node["pk"]
            # A missing, ``None`` or empty processor yields no name and no params.
            processor = node.get("processor") or {}
            processor_name = processor.get("name")
            processor_params = processor.get("values")
            description = node.get("description")
            try:
                actual_node = build_node(
                    source_stream=self.root.events,
                    pk=node_pk,
                    node_name=node["name"],
                    processor_name=processor_name,
                    processor_params=processor_params,
                    # NOTE(review): "descriptionn" looks like a typo, but it has to
                    # match the keyword declared by ``build_node`` - confirm there
                    # before renaming.
                    descriptionn=description,
                )
                nodes.append(actual_node)
            except Exception as error:
                # Collect every node failure so that all of them are reported at once.
                errors.append(self._build_node_error_desc(node_pk=node_pk, error=error))

        if errors:
            raise ServerErrorWithResponse(message="set graph error", type="create_nodes_error", errors=errors)

        self._ngraph = networkx.DiGraph()

        # add nodes
        for created_node in nodes:
            if isinstance(created_node, SourceNode):
                self.root = created_node
            self._ngraph.add_node(created_node)

        # add links
        # @TODO: validate links (graph structure)
        for link in payload["links"]:
            source = self._find_node(link["source"])
            target = self._find_node(link["target"])
            if source is None:
                raise ServerErrorWithResponse(message="source not found", type="create_link_error")
            if target is None:
                raise ServerErrorWithResponse(message="target not found", type="create_link_error")
            self._ngraph.add_edge(source, target)

    def _build_node_error_desc(self, node_pk: str, error: Exception) -> CreateNodeErrorDesc:
        """Convert an exception raised while building a node into an error description."""
        if isinstance(error, ValidationError):
            return self._build_pydantic_error_desc(
                node_pk=node_pk,
                validation_error_exception=error,
            )

        field_errors: List[FieldErrorDesc] = []
        # Only widget parse errors that name a field produce a field-level entry.
        if isinstance(error, WidgetParseError) and error.field_name:
            field_errors.append({"field": error.field_name, "msg": str(error)})

        return {
            "type": "node_error",
            "msg": str(error),
            "node_pk": node_pk,
            "fields_errors": field_errors,
        }

    def _build_pydantic_error_desc(
        self, node_pk: str, validation_error_exception: ValidationError
    ) -> CreateNodeErrorDesc:
        """Convert a pydantic ``ValidationError`` into an error description with per-field entries."""
        result_errors: List[FieldErrorDesc] = []
        for raw_err in validation_error_exception.errors():
            # The first element of ``loc`` is used as the field name.
            field = next(iter(raw_err.get("loc", ())), None)
            result_errors.append(
                {
                    "field": str(field),
                    "msg": raw_err.get("msg"),
                }
            )

        return {"type": "node_error", "node_pk": node_pk, "msg": "node error", "fields_errors": result_errors}

    def _find_parents_by_links(self, target_node: str, link_list: list[NodeLink]) -> list[Node]:
        """Return the graph nodes that ``link_list`` connects to ``target_node`` (a node pk)."""
        parent_pks = [link["source"] for link in link_list if link["target"] == target_node]
        # Entries are ``None`` for pks that are missing from the graph.
        return [self._find_node(parent_pk) for parent_pk in parent_pks]  # type: ignore

    def _find_node(self, pk: str) -> Node | None:
        """Return the graph node with the given ``pk``, or ``None`` if there is none."""
        return next((node for node in self._ngraph if node.pk == pk), None)