from __future__ import annotations
import json
from typing import Any, List, Literal, Optional, TypedDict, cast
import networkx
from IPython.core.display import HTML, DisplayHandle, display
from pydantic import ValidationError
from retentioneering.backend import JupyterServer, ServerManager
from retentioneering.backend.callback import list_dataprocessor, list_dataprocessor_mock
from retentioneering.eventstream.types import EventstreamType
from retentioneering.exceptions.server import ServerErrorWithResponse
from retentioneering.exceptions.widget import WidgetParseError
from retentioneering.preprocessing_graph.nodes import (
    EventsNode,
    MergeNode,
    Node,
    SourceNode,
    build_node,
)
from retentioneering.templates import PreprocessingGraphRenderer
class NodeData(TypedDict):
    """
    One entry of ``Payload["nodes"]``, as consumed by ``PreprocessingGraph._set_graph``.

    NOTE(review): all keys are declared required (total TypedDict), but
    ``_set_graph`` reads ``description`` and ``processor`` with ``.get``, and its
    own payload example has a ``SourceNode`` entry without ``processor`` - so real
    payloads may omit them. Consider declaring those two keys as not required.
    """

    # Node class name, e.g. "SourceNode" or "EventsNode"; passed to ``build_node`` as ``node_name``.
    name: str
    # Node identifier (a UUID string in the ``_set_graph`` payload example).
    pk: str
    # Optional free-text description forwarded to ``build_node``.
    description: Optional[str]
    # Data processor spec with "name" and "values" keys; empty/absent for nodes without a processor.
    processor: Optional[dict]
class NodeLink(TypedDict):
    """A directed edge of the graph payload; both ends reference nodes by their ``pk``."""

    # ``pk`` of the parent node (edge start).
    source: str
    # ``pk`` of the child node (edge end).
    target: str
class Payload(TypedDict):
    """
    Whole-graph payload accepted by ``PreprocessingGraph._set_graph``.

    Its keys match those of the dict produced by ``PreprocessingGraph.export``.
    """

    # Directedness flag; ``_set_graph`` does not read it.
    directed: bool
    # Nodes to build; see ``NodeData``.
    nodes: list[NodeData]
    # Edges between the nodes above, by ``pk``.
    links: list[NodeLink]
class CombineHandlerPayload(TypedDict):
    """Payload of the ``combine`` server action handled by ``PreprocessingGraph._combine_handler``."""

    # ``pk`` of the node whose eventstream should be combined.
    node_pk: str
class FieldErrorDesc(TypedDict):
    """Error attached to a single named field of a node that failed to build."""

    # Name of the offending field.
    field: str
    # Human-readable error message.
    msg: str
class CreateNodeErrorDesc(TypedDict):
    """
    Description of a node that could not be built by ``PreprocessingGraph._set_graph``.

    A list of these is passed as ``errors`` of the ``create_nodes_error`` server error.
    """

    # Discriminator; always "node_error".
    type: Literal["node_error"]
    # ``pk`` of the failing node, taken from the incoming payload.
    node_pk: str
    # Overall error message.
    msg: Optional[str]
    # Per-field errors; may be empty when the failure is not tied to a field.
    fields_errors: List[FieldErrorDesc]
class PreprocessingGraph:
    """
    Collection of methods for preprocessing graph construction and calculation.

    Parameters
    ----------
    source_stream : EventstreamType
        Source eventstream. It is wrapped into a ``SourceNode`` that becomes the
        root of the graph.

    Notes
    -----
    See :doc:`Preprocessing user guide</user_guides/preprocessing>` for the details.
    """

    root: SourceNode
    # Result of the last ``combine`` requested through the "combine" server action
    # (see ``_combine_handler``); ``None`` until such a request arrives.
    combine_result: EventstreamType | None
    _ngraph: networkx.DiGraph
    # Created lazily by ``display``; the class-level ``None`` means "not created yet".
    __server_manager: ServerManager | None = None
    __server: JupyterServer | None = None

    def __init__(self, source_stream: EventstreamType) -> None:
        self.root = SourceNode(source=source_stream)
        self.combine_result = None
        self._ngraph = networkx.DiGraph()
        self._ngraph.add_node(self.root)

    def add_node(self, node: Node, parents: List[Node]) -> None:
        """
        Add node to ``PreprocessingGraph`` instance.

        Parameters
        ----------
        node : Node
            An instance of either ``EventsNode`` or ``MergeNode``.
        parents : list of Nodes
            - If ``node`` is ``EventsNode`` - only 1 parent must be defined.
            - If ``node`` is ``MergeNode`` - at least 2 parents have to be defined.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``node`` is already in the graph, if any of ``parents`` is not in the
            graph, or if a node other than ``MergeNode`` is given several parents.

        See Also
        --------
        PreprocessingGraph.combine : Start PreprocessingGraph recalculation.
        .EventsNode : Regular nodes of a preprocessing graph.
        .MergeNode : Merge nodes of a preprocessing graph.
        """
        self.__valiate_already_exists(node)
        self.__validate_not_found(parents)

        if node.events is not None:
            # NOTE(review): the schema comparison result is discarded, so a
            # mismatching schema is not rejected here - confirm this is intended.
            self.__validate_schema(node.events)

        if not isinstance(node, MergeNode) and len(parents) > 1:
            raise ValueError("multiple parents are only allowed for merge nodes!")

        self._ngraph.add_node(node)
        for parent in parents:
            self._ngraph.add_edge(parent, node)

    def combine(self, node: Node) -> EventstreamType:
        """
        Run calculations from the ``SourceNode`` up to the specified ``node``.

        Parameters
        ----------
        node : Node
            Instance of either ``SourceNode``, ``EventsNode`` or ``MergeNode``.

        Returns
        -------
        EventstreamType
            ``Eventstream`` with all changes applied by data processors.

        Raises
        ------
        ValueError
            If ``node`` (or any ancestor needed for the calculation) is not in the
            graph, or if the graph structure above ``node`` is invalid.
        """
        self.__validate_not_found([node])

        if isinstance(node, SourceNode):
            # Hand out a copy: ``_combine_events_node`` mutates the eventstream it
            # receives from its parent via ``_join_eventstream``.
            return node.events.copy()

        if isinstance(node, EventsNode):
            return self._combine_events_node(node)

        return self._combine_merge_node(node)

    def _combine_events_node(self, node: EventsNode) -> EventstreamType:
        """Combine the single parent of ``node`` and apply ``node``'s data processor to it."""
        parent = self._get_events_node_parent(node)
        parent_events = self.combine(parent)
        events = node.processor.apply(parent_events)
        # The processor output is joined back into the (freshly combined) parent stream.
        parent_events._join_eventstream(events)
        return parent_events

    def _combine_merge_node(self, node: MergeNode) -> EventstreamType:
        """
        Combine every parent of ``node`` and append them into one eventstream.

        The merged result is also stored on ``node.events``.
        """
        parents = self._get_merge_node_parents(node)
        # ``_get_merge_node_parents`` guarantees at least one parent, so the
        # accumulator is always assigned on the first iteration.
        merged: Optional[EventstreamType] = None
        for parent_node in parents:
            parent_events = self.combine(parent_node)
            if merged is None:
                merged = parent_events
            else:
                merged.append_eventstream(parent_events)

        node.events = merged
        return cast(EventstreamType, merged)

    def get_parents(self, node: Node) -> List[Node]:
        """
        Show parents of the specified ``node``.

        Parameters
        ----------
        node : Node
            Instance of one of the classes SourceNode, EventsNode or MergeNode.

        Returns
        -------
        list of Nodes

        Raises
        ------
        ValueError
            If ``node`` is not in the graph.
        """
        self.__validate_not_found([node])
        return list(self._ngraph.predecessors(node))

    def _get_merge_node_parents(self, node: MergeNode) -> List[Node]:
        """Return the parents of a merge node, rejecting a merge node without parents."""
        parents = self.get_parents(node)
        if len(parents) == 0:
            raise ValueError("orphan merge node!")
        return parents

    def _get_events_node_parent(self, node: EventsNode) -> Node:
        """Return the only parent of an events node; zero or several parents is an error."""
        parents = self.get_parents(node)
        if not parents:
            # ``add_node`` accepts an empty parent list, so report it explicitly
            # instead of failing with an IndexError below.
            raise ValueError("orphan events node!")
        if len(parents) > 1:
            raise ValueError("invalid graph: events node has more than 1 parent")
        return parents[0]

    def __validate_schema(self, eventstream: EventstreamType) -> bool:
        """Return whether ``eventstream`` has the same schema as the root eventstream."""
        return self.root.events.schema.is_equal(eventstream.schema)

    def __valiate_already_exists(self, node: Node) -> None:
        """Raise ``ValueError`` if ``node`` is already part of the graph."""
        if node in self._ngraph.nodes:
            raise ValueError("node already exists!")

    def __validate_not_found(self, nodes: List[Node]) -> None:
        """Raise ``ValueError`` if any of ``nodes`` is not part of the graph."""
        for node in nodes:
            if node not in self._ngraph.nodes:
                raise ValueError("node not found!")

    def display(self, width: int = 960, height: int = 600) -> DisplayHandle:
        """
        Show constructed ``PreprocessingGraph``.

        Parameters
        ----------
        width : int, default 960
            Width of plot in pixels.
        height : int, default 600
            Height of plot in pixels.

        Returns
        -------
        DisplayHandle
            Rendered preprocessing graph.
        """
        if not self.__server_manager:
            self.__server_manager = ServerManager()

        if not self.__server:
            # Created once per graph instance; the registered actions are the
            # handlers the rendered page can invoke by name.
            self.__server = self.__server_manager.create_server()
            self.__server.register_action("list-dataprocessor-mock", list_dataprocessor_mock)
            self.__server.register_action("list-dataprocessor", list_dataprocessor)
            self.__server.register_action("set-graph", self._set_graph_handler)
            self.__server.register_action("get-graph", self.export)
            self.__server.register_action("combine", self._combine_handler)

        render = PreprocessingGraphRenderer()
        return display(
            HTML(
                render.show(
                    server_id=self.__server.pk, env=self.__server_manager.check_env(), width=width, height=height
                )
            )
        )

    def export(self, payload: dict[str, Any]) -> dict:
        """
        Show ``PreprocessingGraph`` as a dict.

        Parameters
        ----------
        payload : dict
            Ignored. Present because this method is registered as the
            ``get-graph`` server action.

        Returns
        -------
        dict
            ``{"directed": ..., "nodes": [...], "links": [{"source": pk, "target": pk}, ...]}``.
        """
        graph = self._ngraph
        return {
            "directed": graph.is_directed(),
            "nodes": [n.export() for n in graph],
            "links": [{"source": u.pk, "target": v.pk} for u, v in graph.edges],
        }

    def _export_to_json(self) -> str:
        """Serialize ``export()`` output to a JSON string."""
        return json.dumps(self.export(payload={}))

    def _combine_handler(self, payload: CombineHandlerPayload) -> None:
        """Server action: combine the node with ``payload["node_pk"]`` into ``combine_result``."""
        node = self._find_node(payload["node_pk"])
        if not node:
            raise ServerErrorWithResponse(message="node not found!", type="unexpected_error")
        self.combine_result = self.combine(node)

    def _set_graph_handler(self, payload: Payload) -> dict:
        """
        Server action: replace the graph with ``payload`` and return the new export.

        On any failure the previous graph and root are restored. Server errors are
        re-raised unchanged; other exceptions are wrapped into an
        ``unexpected_error`` server error with the original exception as cause.
        """
        previous_graph = self._ngraph
        previous_root = self.root

        def restore_graph() -> None:
            # ``_set_graph`` rebinds ``self._ngraph`` instead of mutating it, so
            # putting the old objects back is a full rollback.
            self._ngraph = previous_graph
            self.root = previous_root

        try:
            self._set_graph(payload=payload)
            return self.export({})
        except ServerErrorWithResponse:
            restore_graph()
            raise
        except Exception as err:
            restore_graph()
            raise ServerErrorWithResponse(message=str(err), type="unexpected_error") from err

    def _set_graph(self, payload: Payload) -> None:
        """
        Replace the graph with the nodes and links described by ``payload``.

        Every node is built (against ``self.root.events``) before anything is
        changed; if any node fails, a ``create_nodes_error`` server error listing
        all failing nodes is raised and the graph is left untouched. Otherwise the
        graph is rebuilt from scratch, a ``SourceNode`` in the payload becomes
        ``self.root``, and the links are added. An unknown link end raises a
        ``create_link_error`` after the graph has been replaced (the caller,
        ``_set_graph_handler``, restores the previous graph in that case).

        Payload example:
        {
            "nodes": [
                {
                    "name": "SourceNode",
                    "pk": "0dc3b706-e6cc-401e-96f7-6a45d3947d5c"
                },
                {
                    "name": "EventsNode",
                    "pk": "07921cb0-60b8-45af-928d-272d1b622b25",
                    "processor": {
                        "name": "SimpleGroup",
                        "values": {"event_name": "add_to_cart", "event_type": "group_alias"},
                    },
                },
                {
                    "name": "EventsNode",
                    "pk": "114251ae-0f03-45e6-a163-af51bb02dfd5",
                    "processor": {
                        "name": "SimpleGroup",
                        "values": {"event_name": "logout", "event_type": "group_alias"},
                    },
                },
            ],
            "links": [
                {
                    'source': '0dc3b706-e6cc-401e-96f7-6a45d3947d5c',
                    'target': '07921cb0-60b8-45af-928d-272d1b622b25'
                },
                {
                    'source': '07921cb0-60b8-45af-928d-272d1b622b25',
                    'target': '114251ae-0f03-45e6-a163-af51bb02dfd5'
                }
            ]
        }
        """
        errors: List[CreateNodeErrorDesc] = []
        nodes: List[Node] = []

        # Build every node first and collect all failures, so the caller gets a
        # complete report rather than only the first error.
        for node_data in payload["nodes"]:
            node_pk = node_data["pk"]
            processor = node_data.get("processor", {})
            processor_name = processor.get("name", None) if processor else None
            processor_params = processor.get("values", None) if processor else None
            description = node_data.get("description", None)
            try:
                actual_node = build_node(
                    source_stream=self.root.events,
                    pk=node_pk,
                    node_name=node_data["name"],
                    processor_name=processor_name,
                    processor_params=processor_params,
                    description=description,
                )
            except Exception as error:
                errors.append(self._build_node_error_desc(node_pk=node_pk, error=error))
            else:
                nodes.append(actual_node)

        if errors:
            raise ServerErrorWithResponse(message="set graph error", type="create_nodes_error", errors=errors)

        # Rebind (not mutate) the graph so ``_set_graph_handler`` can roll back.
        self._ngraph = networkx.DiGraph()
        for created_node in nodes:
            if isinstance(created_node, SourceNode):
                self.root = created_node
            self._ngraph.add_node(created_node)

        # @TODO: validate links (graph structure)
        for link in payload["links"]:
            source = self._find_node(link["source"])
            target = self._find_node(link["target"])
            if not source:
                raise ServerErrorWithResponse(message="source not found", type="create_link_error")
            if not target:
                raise ServerErrorWithResponse(message="target not found", type="create_link_error")
            self._ngraph.add_edge(source, target)

    def _build_node_error_desc(self, node_pk: str, error: Exception) -> CreateNodeErrorDesc:
        """Convert an exception raised while building a node into a ``CreateNodeErrorDesc``."""
        if isinstance(error, ValidationError):
            return self._build_pydantic_error_desc(
                node_pk=node_pk,
                validation_error_exception=error,
            )

        if isinstance(error, WidgetParseError):
            field_errors: List[FieldErrorDesc] = (
                [{"field": error.field_name, "msg": str(error)}] if error.field_name else []
            )
            return {
                "type": "node_error",
                "msg": str(error),
                "node_pk": node_pk,
                "fields_errors": field_errors,
            }

        return {
            "type": "node_error",
            "msg": str(error),
            "node_pk": node_pk,
            "fields_errors": [],
        }

    def _build_pydantic_error_desc(
        self, node_pk: str, validation_error_exception: ValidationError
    ) -> CreateNodeErrorDesc:
        """
        Convert a pydantic ``ValidationError`` into a ``CreateNodeErrorDesc``.

        For each raw error, the first element of its ``loc`` is reported as the
        field name (``"None"`` when ``loc`` is empty).
        """
        result_errors: List[FieldErrorDesc] = []
        for raw_err in validation_error_exception.errors():
            field = next(iter(raw_err.get("loc", ())), None)
            result_errors.append({"field": str(field), "msg": raw_err.get("msg")})
        return {"type": "node_error", "node_pk": node_pk, "msg": "node error", "fields_errors": result_errors}

    def _find_parents_by_links(self, target_node: str, link_list: list[NodeLink]) -> list[Node]:
        """
        Return the nodes that ``link_list`` names as sources of ``target_node`` (a pk).

        A source pk that is not in the current graph yields ``None`` at its position.
        """
        parent_pks = [link["source"] for link in link_list if link["target"] == target_node]
        return [self._find_node(pk) for pk in parent_pks]  # type: ignore

    def _find_node(self, pk: str) -> Node | None:
        """Return the graph node whose ``pk`` equals ``pk``, or ``None`` if there is none."""
        return next((node for node in self._ngraph if node.pk == pk), None)