from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any, List, Optional, cast
import networkx
import pandas as pd
from IPython.core.display import HTML, DisplayHandle, display
from pydantic import ValidationError
from retentioneering import RETE_CONFIG
from retentioneering.backend import JupyterServer, ServerManager
from retentioneering.backend.callback import list_dataprocessor, list_dataprocessor_mock
from retentioneering.backend.tracker import (
collect_data_performance,
time_performance,
track,
tracker,
)
from retentioneering.eventstream.types import EventstreamType
from retentioneering.exceptions.server import ServerErrorWithResponse
from retentioneering.exceptions.widget import WidgetParseError
from retentioneering.preprocessing_graph.interface import (
CombineHandlerPayload,
CreateNodeErrorDesc,
FieldErrorDesc,
NodeLink,
Payload,
PreprocessingGraphConfig,
)
from retentioneering.preprocessing_graph.nodes import (
EventsNode,
MergeNode,
Node,
SourceNode,
build_node,
)
from retentioneering.templates import PreprocessingGraphRenderer
class PreprocessingGraph:
    """
    Collection of methods for preprocessing graph construction and calculation.

    Parameters
    ----------
    source_stream : EventstreamType
        Source eventstream.

    Notes
    -----
    See :doc:`Preprocessing user guide</user_guides/preprocessing>` for the details.
    """

    DEFAULT_GRAPH_URL = "https://static.server.retentioneering.com/package/@rete/preprocessing-graph/version/2/dist/preprocessing-graph.umd.js"
    DEFAULT_GRAPH_STYLE_URL = (
        "https://static.server.retentioneering.com/package/@rete/preprocessing-graph/version/2/dist/style.css"
    )

    root: SourceNode
    _combine_result: EventstreamType | None
    _ngraph: networkx.DiGraph
    # Created lazily on the first ``display`` call and reused by later calls.
    __server_manager: ServerManager | None = None
    __server: JupyterServer | None = None

    @time_performance(  # type: ignore
        scope="preprocessing_graph",
        event_name="init",
    )
    def __init__(self, source_stream: EventstreamType) -> None:
        self._source_stream = source_stream
        self.root = SourceNode(source=source_stream)
        self._combine_result = None
        self._ngraph = networkx.DiGraph()
        self._ngraph.add_node(self.root)

    def add_node(self, node: Node, parents: List[Node]) -> None:
        """
        Add node to ``PreprocessingGraph`` instance.

        Parameters
        ----------
        node : Node
            An instance of either ``EventsNode`` or ``MergeNode``.
        parents : list of Nodes
            - If ``node`` is ``EventsNode`` - only 1 parent must be defined.
            - If ``node`` is ``MergeNode`` - at least 2 parents have to be defined.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``node`` is already in the graph, if any of ``parents`` is not in the graph,
            or if a non-merge node is given more than one parent.

        See Also
        --------
        PreprocessingGraph.combine : Start PreprocessingGraph recalculation.
        .EventsNode : Regular nodes of a preprocessing graph.
        .MergeNode : Merge nodes of a preprocessing graph.
        """
        self.__validate_already_exists(node)
        self.__validate_not_found(parents)
        if node.events is not None:
            # NOTE(review): the boolean result of the schema comparison is discarded,
            # so a schema mismatch is currently not reported. Kept as-is to avoid
            # rejecting graphs that work today; consider raising on mismatch.
            self.__validate_schema(node.events)
        if not isinstance(node, MergeNode) and len(parents) > 1:
            raise ValueError("multiple parents are only allowed for merge nodes!")

        self._ngraph.add_node(node)
        for parent in parents:
            self._ngraph.add_edge(parent, node)

    @time_performance(
        scope="preprocessing_graph",
        event_name="combine",
    )
    def combine(self, node: Node) -> EventstreamType:
        """
        Run calculations from the ``SourceNode`` up to the specified ``node``.

        Parameters
        ----------
        node : Node
            Instance of either ``SourceNode``, ``EventsNode`` or ``MergeNode``.

        Returns
        -------
        EventstreamType
            ``Eventstream`` with all changes applied by data processors.

        Raises
        ------
        ValueError
            If ``node`` is not part of the graph or the graph structure is invalid.
        """
        self.__validate_not_found([node])

        if isinstance(node, SourceNode):
            result_eventstream = node.events.copy()
        elif isinstance(node, EventsNode):
            result_eventstream = self._combine_events_node(node)
        else:
            result_eventstream = self._combine_merge_node(node)

        # Looked up only after combining: merge nodes store their result in ``node.events``.
        node_eventstream = self.__get_node_eventstream(node)
        node_eventstream_index = getattr(node_eventstream, "_eventstream_index", None)
        collect_data_performance(
            scope="preprocessing_graph",
            called_params={"node": node},
            performance_data={
                "parent": {
                    "index": node_eventstream_index,
                    "hash": getattr(node_eventstream, "_hash", None),
                },
                "child": {"index": result_eventstream._eventstream_index, "hash": result_eventstream._hash},
            },
            eventstream_index=node_eventstream_index,
            parent_eventstream_index=node_eventstream_index,
            child_eventstream_index=result_eventstream._eventstream_index,
        )
        return result_eventstream

    def __get_node_eventstream(self, node: Node) -> EventstreamType | None:
        """Return ``node.events`` if the node has that attribute, otherwise ``None``."""
        # No truthiness test here: the truth value of an eventstream / pandas object can be
        # ambiguous (raises) or False when it is empty, which would wrongly drop it.
        return getattr(node, "events", None)

    def _combine_events_node(self, node: EventsNode) -> EventstreamType:
        """Combine the single parent of ``node`` and apply the node's data processor to it."""
        parent = self._get_events_node_parent(node)
        parent_events = self.combine(parent)
        return node.processor._get_new_data(eventstream=parent_events)

    def _combine_merge_node(self, node: MergeNode) -> EventstreamType:
        """
        Combine every parent of a merge node and append them into one eventstream.

        The first parent's result is used as the accumulator; the merged result is also
        stored in ``node.events``.
        """
        parents = self._get_merge_node_parents(node)
        curr_eventstream: Optional[EventstreamType] = None
        for parent_node in parents:
            parent_eventstream = self.combine(parent_node)
            if curr_eventstream is None:
                curr_eventstream = parent_eventstream
            else:
                curr_eventstream.index_order = parent_eventstream.index_order
                curr_eventstream.append_eventstream(parent_eventstream)
        node.events = curr_eventstream
        return cast(EventstreamType, curr_eventstream)

    def get_parents(self, node: Node) -> List[Node]:
        """
        Show parents of the specified ``node``.

        Parameters
        ----------
        node : Node
            Instance of one of the classes SourceNode, EventsNode or MergeNode.

        Returns
        -------
        list of Nodes
        """
        self.__validate_not_found([node])
        return list(self._ngraph.predecessors(node))

    def _get_merge_node_parents(self, node: MergeNode) -> List[Node]:
        """Return the parents of a merge node, raising ``ValueError`` if it has none."""
        parents = self.get_parents(node)
        if len(parents) == 0:
            raise ValueError("orphan merge node!")
        return parents

    def _get_events_node_parent(self, node: EventsNode) -> Node:
        """Return the only parent of an events node, raising ``ValueError`` otherwise."""
        parents = self.get_parents(node)
        if not parents:
            # Previously surfaced as a bare IndexError from ``parents[0]``.
            raise ValueError("invalid preprocessing graph: events node has no parent")
        if len(parents) > 1:
            raise ValueError("invalid preprocessing graph: events node has more than 1 parent")
        return parents[0]

    def __validate_schema(self, eventstream: EventstreamType) -> bool:
        """Return whether ``eventstream`` has the same schema as the root eventstream."""
        return self.root.events.schema.is_equal(eventstream.schema)

    def __validate_already_exists(self, node: Node) -> None:
        """Raise ``ValueError`` if ``node`` is already part of the graph."""
        if node in self._ngraph.nodes:
            raise ValueError("node already exists!")

    def __validate_not_found(self, nodes: List[Node]) -> None:
        """Raise ``ValueError`` if any of ``nodes`` is not part of the graph."""
        for node in nodes:
            if node not in self._ngraph.nodes:
                raise ValueError("node not found!")

    @property
    def graph_url(self) -> str:
        """URL of the graph widget script; overridable via ``RETE_PREPROCESSING_GRAPH_URL``."""
        return os.getenv("RETE_PREPROCESSING_GRAPH_URL", "") or self.DEFAULT_GRAPH_URL

    @property
    def graph_style_url(self) -> str:
        """URL of the graph widget stylesheet; overridable via ``RETE_PREPROCESSING_GRAPH_STYLE_URL``."""
        return os.getenv("RETE_PREPROCESSING_GRAPH_STYLE_URL", "") or self.DEFAULT_GRAPH_STYLE_URL

    @time_performance(
        scope="preprocessing_graph",
        event_name="display",
    )
    def display(self, width: int = 960, height: int = 600) -> DisplayHandle:
        """
        Show constructed ``PreprocessingGraph``.

        Parameters
        ----------
        width : int, default 960
            Width of plot in pixels.
        height : int, default 600
            Height of plot in pixels.

        Returns
        -------
        Rendered preprocessing graph.
        """
        if self.__server_manager is None:
            self.__server_manager = ServerManager()

        if self.__server is None:
            self.__server = self.__server_manager.create_server()
            # Actions are registered once, when the server is first created.
            self.__server.register_action("list-dataprocessor-mock", list_dataprocessor_mock)
            self.__server.register_action("list-dataprocessor", list_dataprocessor)
            self.__server.register_action("set-graph", self._set_graph_handler)
            self.__server.register_action("get-graph", self._export_handler)
            self.__server.register_action("combine", self._combine_handler)

        render = PreprocessingGraphRenderer()
        collect_data_performance(
            scope="preprocessing_graph",
            event_name="metadata",
            called_params={"width": width, "height": height},
            eventstream_index=self.root.events._eventstream_index,
        )
        # ``display`` here resolves to the module-level IPython function, not this method.
        return display(
            HTML(
                render.show(
                    server_id=self.__server.pk,
                    env=self.__server_manager.check_env(),
                    width=width,
                    height=height,
                    graph_url=self.graph_url,
                    graph_style_url=self.graph_style_url,
                    kernel_id=self.__server.kernel_id,
                    tracking_hardware_id=RETE_CONFIG.user.pk,
                    tracking_eventstream_index=self._source_stream._eventstream_index,
                    tracking_scope="preprocessing_graph",
                )
            )
        )

    def _export_handler(self, payload: dict[str, Any]) -> dict:
        """
        Show ``PreprocessingGraph`` as a dict.

        Parameters
        ----------
        payload : dict
            Unused; present so the method can be registered as a server action.

        Returns
        -------
        dict
            ``{"directed": bool, "nodes": [...], "links": [{"source": pk, "target": pk}, ...]}``.
        """
        graph = self._ngraph
        return {
            "directed": graph.is_directed(),
            "nodes": [node.export() for node in graph],
            "links": [{"source": source.pk, "target": target.pk} for source, target in graph.edges()],
        }

    def __resolve_path(self, relative_path: str) -> str:
        """Return ``relative_path`` resolved to an absolute path string."""
        return str(Path(relative_path).resolve())

    @time_performance(
        scope="preprocessing_graph",
        event_name="export_to_file",
    )
    def export_to_file(self, filename: str) -> None:
        """
        Export constructed ``PreprocessingGraph`` to a json file.

        Parameters
        ----------
        filename : str
            Path to the json file.

        Returns
        -------
        None

        Raises
        ------
        PermissionError
            If the target directory is not writable.
        """
        # validate access rights to file
        graph_path = self.__resolve_path(filename)
        if not os.access(os.path.dirname(graph_path), os.W_OK):
            raise PermissionError(f"{os.path.dirname(graph_path)} is not writable")

        # export preprocessing_graph
        data = self._export_handler({})
        with open(graph_path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    @time_performance(
        scope="preprocessing_graph",
        event_name="import_from_file",
    )
    def import_from_file(self, filename: str) -> None:
        """
        Import constructed ``PreprocessingGraph`` from a json file.

        Parameters
        ----------
        filename : str
            Path to the json file.

        Returns
        -------
        None

        Raises
        ------
        PermissionError
            If the source directory is not readable.
        ValueError
            If the file content does not match the preprocessing graph config.
        """
        graph_path = self.__resolve_path(filename)
        # validate access rights to file
        if not os.access(os.path.dirname(graph_path), os.R_OK):
            raise PermissionError(f"{os.path.dirname(graph_path)} is not readable")

        with open(graph_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if validation_error := self._validate_payload(payload=data):
            raise ValueError(f"Invalid json file: {validation_error}") from validation_error

        payload: Payload = Payload(**data)  # type: ignore
        self._set_graph_handler(payload=payload)

    def _combine_handler(self, payload: CombineHandlerPayload) -> None:
        """Server action: combine the node with ``payload["node_pk"]`` and keep the result."""
        node = self._find_node(payload["node_pk"])
        if node is None:
            raise ServerErrorWithResponse(message="node not found!", type="unexpected_error")

        self._combine_result = self.combine(node)

    @property
    @time_performance(
        scope="preprocessing_graph",
        event_name="combine_result",
    )
    def combine_result(self) -> Optional[EventstreamType]:
        """
        Keep and get the last combining result from preprocessing graph GUI.

        Returns
        -------
        EventstreamType or None
            Preprocessed eventstream.
        """
        return self._combine_result

    def _set_graph_handler(self, payload: Payload) -> dict:
        """
        Server action: replace the graph with ``payload`` and return the exported graph.

        On any failure the previous graph and root are restored. Errors that are not
        already ``ServerErrorWithResponse`` are wrapped into one with ``type="unexpected_error"``.
        """
        current_graph = self._ngraph
        current_root = self.root

        def restore_graph() -> None:
            self._ngraph = current_graph
            self.root = current_root

        try:
            self._set_graph(payload=payload)
            return self._export_handler({})
        except ServerErrorWithResponse:
            restore_graph()
            raise
        except Exception as err:
            restore_graph()
            raise ServerErrorWithResponse(message=str(err), type="unexpected_error") from err

    def _validate_payload(self, payload: Payload) -> None | Exception:
        """Return the exception raised by ``PreprocessingGraphConfig(**payload)``, or ``None`` if valid."""
        try:
            PreprocessingGraphConfig(**payload)
        except Exception as err:
            return err
        return None

    def _set_graph(self, payload: Payload) -> None:
        """
        Rebuild the graph from a serialized payload.

        All nodes are built first; if any fails, a ``ServerErrorWithResponse`` of type
        ``create_nodes_error`` listing every node error is raised and the graph is untouched.
        Otherwise the graph is replaced and links are added; an unknown link endpoint
        raises ``ServerErrorWithResponse`` of type ``create_link_error``.

        Payload example:
        {
            "nodes": [
                {
                    "name": "SourceNode",
                    "pk": "0dc3b706-e6cc-401e-96f7-6a45d3947d5c"
                },
                {
                    "name": "EventsNode",
                    "pk": "07921cb0-60b8-45af-928d-272d1b622b25",
                    "processor": {
                        "name": "SimpleGroup",
                        "values": {"event_name": "add_to_cart", "event_type": "group_alias"},
                    },
                },
                {
                    "name": "EventsNode",
                    "pk": "114251ae-0f03-45e6-a163-af51bb02dfd5",
                    "processor": {
                        "name": "SimpleGroup",
                        "values": {"event_name": "logout", "event_type": "group_alias"},
                    },
                },
            ],
            "links": [
                {
                    'source': '0dc3b706-e6cc-401e-96f7-6a45d3947d5c',
                    'target': '07921cb0-60b8-45af-928d-272d1b622b25'
                },
                {
                    'source': '07921cb0-60b8-45af-928d-272d1b622b25',
                    'target': '114251ae-0f03-45e6-a163-af51bb02dfd5'
                }
            ]
        }
        """
        errors: List[CreateNodeErrorDesc] = []
        nodes: List[Node] = []

        # create nodes & validate params
        for node in payload["nodes"]:
            node_pk = node["pk"]
            processor = node.get("processor", {})
            processor_name = processor.get("name", None) if processor else None
            processor_params = processor.get("values", None) if processor else None
            description = node.get("description", None)
            try:
                actual_node = build_node(
                    source_stream=self.root.events,
                    pk=node_pk,
                    node_name=node["name"],
                    processor_name=processor_name,
                    processor_params=processor_params,
                    description=description,
                )
                nodes.append(actual_node)
            except Exception as error:
                # Collect every node error so the GUI can report them all at once.
                errors.append(self._build_node_error_desc(node_pk=node_pk, error=error))

        if errors:
            raise ServerErrorWithResponse(
                message="set preprocessing graph error", type="create_nodes_error", errors=errors
            )

        self._ngraph = networkx.DiGraph()

        # add nodes
        for created_node in nodes:
            if isinstance(created_node, SourceNode):
                self.root = created_node
            self._ngraph.add_node(created_node)

        # add links
        # @TODO: validate links (graph structure)
        for link in payload["links"]:
            source = self._find_node(link["source"])
            target = self._find_node(link["target"])
            if not source:
                raise ServerErrorWithResponse(message="source not found", type="create_link_error")
            if not target:
                raise ServerErrorWithResponse(message="target not found", type="create_link_error")
            self._ngraph.add_edge(source, target)

    def _build_node_error_desc(self, node_pk: str, error: Exception) -> CreateNodeErrorDesc:
        """
        Convert an exception raised while building node ``node_pk`` into an error description.

        Pydantic validation errors are expanded per field; a ``WidgetParseError`` with a
        ``field_name`` yields one field error; anything else yields no field errors.
        """
        if isinstance(error, ValidationError):
            return self._build_pydantic_error_desc(
                node_pk=node_pk,
                validation_error_exception=error,
            )

        field_errors: List[FieldErrorDesc] = []
        if isinstance(error, WidgetParseError) and error.field_name:
            field_errors.append({"field": error.field_name, "msg": str(error)})

        return {
            "type": "node_error",
            "msg": str(error),
            "node_pk": node_pk,
            "fields_errors": field_errors,
        }

    def _build_pydantic_error_desc(
        self, node_pk: str, validation_error_exception: ValidationError
    ) -> CreateNodeErrorDesc:
        """Build a node error description with one field error per pydantic error entry."""
        result_errors: List[FieldErrorDesc] = []
        for raw_err in validation_error_exception.errors():
            # The first ``loc`` element names the offending top-level field.
            field = next(iter(raw_err.get("loc", ())), None)
            result_errors.append(
                {
                    "field": str(field),
                    "msg": raw_err.get("msg"),
                }
            )

        return {"type": "node_error", "node_pk": node_pk, "msg": "node error", "fields_errors": result_errors}

    def _find_parents_by_links(self, target_node: str, link_list: list[NodeLink]) -> list[Node]:
        """
        Return the nodes linked as sources to ``target_node`` in ``link_list``.

        Entries are ``None`` for source pks not present in the graph.
        """
        parent_pks = [link["source"] for link in link_list if link["target"] == target_node]
        return [self._find_node(parent_pk) for parent_pk in parent_pks]  # type: ignore

    def _find_node(self, pk: str) -> Node | None:
        """Return the first graph node whose ``pk`` equals ``pk``, or ``None`` if there is none."""
        return next((node for node in self._ngraph if node.pk == pk), None)