from __future__ import annotations
from typing import Any, Dict
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from retentioneering.backend.tracker import (
    collect_data_performance,
    time_performance,
    track,
    tracker,
)
from retentioneering.eventstream.types import EventstreamType
from retentioneering.tooling.mixins.ended_events import EndedEventsMixin
[docs]class StepSankey(EndedEventsMixin):
    """
    A class for the visualization of user paths in stepwise manner using Sankey diagram.
    Parameters
    ----------
    eventstream : EventstreamType
    See Also
    --------
    .Eventstream.step_sankey : Call StepSankey tool as an eventstream method.
    .CollapseLoops : Find loops and create new synthetic events in the paths of all users having such sequences.
    .StepMatrix : This class provides methods for step matrix calculation and visualization.
    Notes
    -----
    See :doc:`StepSankey user guide</user_guides/step_sankey>` for the details.
    """
    max_steps: int
    threshold: int | float
    sorting: list | None
    targets: list[str] | str | None
    data_grp_nodes: pd.DataFrame
    data: pd.DataFrame
    data_grp_links: pd.DataFrame
    data_for_plot: dict
    @time_performance(
        scope="step_sankey",
        event_name="init",
    )
    def __init__(self, eventstream: EventstreamType) -> None:
        self.__eventstream = eventstream
        self.user_col = self.__eventstream.schema.user_id
        self.event_col = self.__eventstream.schema.event_name
        self.time_col = self.__eventstream.schema.event_timestamp
        self.event_index_col = self.__eventstream.schema.event_index
        self.data_grp_nodes = pd.DataFrame()
        self.data = pd.DataFrame()
        self.data_grp_links = pd.DataFrame()
        self.data_for_plot = {}
    @staticmethod
    def _make_color(
        event: str,
        all_events: list,
        palette: list,
    ) -> tuple[int, int, int]:
        """
        It is a color picking function
        Parameters
        ----------
        event : str
            An event for color setting
        all_events : list
            A list of all events
        palette : list
            A list of colors
        Returns
        -------
        str
            A picked color for certain event
        """
        return palette[list(all_events).index(event)]
    @staticmethod
    def _round_up(
        n: float,
        dec: float,
    ) -> float:
        """
        Rounds the value up to the nearest value assuming a grid with ``dec`` step.
        E.g. ``_round_up(0.51, 0.05) = 0.55``, ``_round_up(0.55, 0.05) = 0.6``
        Parameters
        ----------
        n : float
            A number to round up
        dec : float
            A decimal for correct rounding up
        Returns
        -------
        float
            Rounded value
        """
        return round(n - n % dec + dec, 2)
    def _get_nodes_positions(self, df: pd.DataFrame) -> tuple[list[float], list[float]]:
        """
        It is a function for placing nodes at the x and y coordinates of plotly lib plot canvas.
        Parameters
        ----------
        df : pandas Dataframe
            A dataframe that contains aggregated information about the nodes.
        Returns
        -------
        tuple[list[float], list[float]]
            Two lists with the corresponding coordinates x and y.
        """
        # NOTE get x axis length
        x_len = len(df["step"].unique())
        # NOTE declare positions
        x_positions = []
        y_positions = []
        # NOTE get maximum range for placing middle points
        y_range = 0.95 - 0.05
        # NOTE going inside ranked events
        for step in sorted(df["step"].unique()):
            # NOTE placing x-axis points as well
            for _ in df[df["step"] == step][self.event_col]:
                x_positions.append([round(x, 2) for x in np.linspace(0.05, 0.95, x_len)][step - 1])
            # NOTE it always works very well if you have less than 4 values at current rank
            y_len = len(df[df["step"] == step][self.event_col])
            # NOTE at this case using came positions as x-axis because we don't need to calculate something more
            if y_len < 4:
                for p in [round(y, 2) for y in np.linspace(0.05, 0.95, y_len)]:
                    y_positions.append(p)
            # NOTE jumping in to complex part
            else:
                # NOTE total sum for understanding do we need extra step size or not
                total_sum = df[df["step"] == step]["usr_cnt"].sum()
                # NOTE step size for middle points
                y_step = round(y_range / total_sum, 2)
                # NOTE cumulative sum for understanding do we need use default step size or not
                cumulative_sum = 0
                # NOTE ENDED action
                ended_sum = df[(df["step"] == step) & (df[self.event_col] == "ENDED")]["usr_cnt"].sum()
                last_point = self._round_up(ended_sum / total_sum, 0.05)
                iterate_sum = 0
                # NOTE going deeper inside each event
                for n, event in enumerate(df[df["step"] == step][self.event_col]):
                    # NOTE placing first event at first possible position
                    if n == 0:
                        y_positions.append(0.05)
                    # NOTE placing last event at last possible position
                    elif n + 1 == y_len:
                        y_positions.append(0.95)
                    # NOTE placing middle points
                    else:
                        # NOTE we found out that 70% of total sum is the best cap for doing this case
                        if iterate_sum / total_sum > 0.2 and event != "ENDED":
                            # NOTE placing first point after the biggest one at the next position
                            # but inside [.1; .3] range
                            y_positions.append(
                                round(y_positions[-1] + np.minimum(np.maximum(y_step * iterate_sum, 0.1), 0.3), 2)
                            )
                        # NOTE placing points after the biggest
                        else:
                            # NOTE placing little points at the all available space
                            y_positions.append(
                                round(y_positions[-1] + (0.95 - last_point - y_positions[-1]) / (y_len - n), 2)
                            )
                    # NOTE set sum for next step
                    iterate_sum = df[(df["step"] == step) & (df[self.event_col] == event)]["usr_cnt"].to_numpy()[0]
                    # NOTE update cumulative sum
                    cumulative_sum += iterate_sum
        return x_positions, y_positions
    def _pad_end_events(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        If the number of events in a user's path is less than self.max_steps, then the function pads the path with
        multiply ENDED events. It is required for correct visualization of the trajectories which are
        shorter than self.max_steps.
        """
        pad = (
            data.groupby(self.user_col, as_index=False)[self.event_col]
            .count()
            .loc[lambda df_: df_[self.event_col] < self.max_steps]  # type: ignore
            .assign(repeat_number=lambda df_: self.max_steps - df_[self.event_col])
        )
        repeats = pd.DataFrame({self.user_col: np.repeat(pad[self.user_col], pad["repeat_number"])})
        padded_end_events = pd.merge(repeats, data[data[self.event_col] == "ENDED"], on=self.user_col)
        result = pd.concat([data, padded_end_events]).sort_values([self.user_col, self.event_index_col])
        return result
    def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
        data = self._add_ended_events(
            data=data, schema=self.__eventstream.schema, weight_col=self.__eventstream.schema.user_id
        )
        data = self._pad_end_events(data)
        # NOTE set new columns using declared functions
        data[self.time_col] = pd.to_datetime(data[self.time_col])
        data["step"] = data.groupby(self.user_col)[self.event_index_col].rank(method="first").astype(int)
        data = data.sort_values(by=["step", self.time_col]).reset_index(drop=True)
        data = self._get_next_event_and_timedelta(data)
        # NOTE threshold
        data["event_users"] = data.groupby(by=["step", self.event_col])[self.user_col].transform("nunique")
        data["total_users"] = data.loc[data["step"] == 1, self.user_col].nunique()
        data["perc"] = data["event_users"] / data["total_users"]
        if isinstance(self.threshold, float):
            column_to_compare = "perc"
        else:
            # assume that self.threshold must be of int type here
            column_to_compare = "event_users"
        events_to_keep = ["ENDED"]
        if self.targets is not None:
            events_to_keep += self.targets
        threshold_events = (
            data.loc[data["step"] <= self.max_steps, :]
            .groupby(by=self.event_col, as_index=False)[column_to_compare]
            .max()
            .loc[
                lambda df_: (df_[column_to_compare] <= self.threshold) & (~df_[self.event_col].isin(events_to_keep))
            ]  # type: ignore
            .loc[:, self.event_col]
        )
        data.loc[data[self.event_col].isin(threshold_events), self.event_col] = f"thresholded_{len(threshold_events)}"
        # NOTE rearrange the data taking into account recently added thresholded events
        data["step"] = data.groupby(self.user_col)[self.event_index_col].rank(method="first").astype(int)
        data = self._get_next_event_and_timedelta(data)
        # NOTE use max_steps for filtering data
        data = data.loc[data["step"] <= self.max_steps, :]
        # TODO: Do we really need to replace NA values?
        # NOTE skip mean calculating error
        data["time_to_next"] = data["time_to_next"].fillna(data["time_to_next"].min())
        return data
    def _render_plot(
        self,
        data_for_plot: dict,
        data_grp_nodes: pd.DataFrame,
        autosize: bool = True,
        width: int | None = None,
        height: int | None = None,
    ) -> go.Figure:
        # NOTE fill lists for plot
        targets = []
        sources = []
        values = []
        time_to_next = []
        for source_key in data_for_plot["links_dict"].keys():
            for target_key, target_value in data_for_plot["links_dict"][source_key].items():
                sources.append(source_key)
                targets.append(target_key)
                values.append(target_value["unique_users"])
                time_to_next.append(
                    str(pd.to_timedelta(target_value["avg_time_to_next"] / target_value["unique_users"])).split(".")[0]
                )
        # NOTE fill another lists for plot
        labels = []
        colors = []
        percs = []
        for key in data_for_plot["nodes_dict"].keys():
            labels += list(data_for_plot["nodes_dict"][key]["sources"])
            colors += list(data_for_plot["nodes_dict"][key]["color"])
            percs += list(data_for_plot["nodes_dict"][key]["percs"])
        # NOTE get colors for plot
        for idx, color in enumerate(colors):
            colors[idx] = "rgb" + str(color) + ""
        # NOTE get positions for plot
        x, y = self._get_nodes_positions(df=data_grp_nodes)
        # NOTE make plot
        fig = go.Figure(
            data=[
                go.Sankey(
                    arrangement="snap",
                    node=dict(
                        thickness=15,
                        line=dict(color="black", width=0.5),
                        label=labels,
                        color=colors,
                        customdata=percs,
                        hovertemplate="Total unique users: %{value} (%{customdata}% of total)<extra></extra>",
                        x=x,
                        y=y,
                        pad=20,
                    ),
                    link=dict(
                        source=sources,
                        target=targets,
                        value=values,
                        label=time_to_next,
                        hovertemplate="%{value} unique users went from %{source.label} to %{target.label}.<br />"
                        + "<br />It took them %{label} in average.<extra></extra>",
                    ),
                )
            ]
        )
        fig.update_layout(font=dict(size=15), plot_bgcolor="white", autosize=autosize, width=width, height=height)
        return fig
    def _get_links(
        self, data: pd.DataFrame, data_for_plot: dict, data_grp_nodes: pd.DataFrame
    ) -> tuple[dict, pd.DataFrame]:
        # NOTE create links aggregated dataframe
        data_grp_links = (
            data[data["step"] <= self.max_steps - 1]
            .groupby(by=["step", self.event_col, "next_event"])[[self.user_col, "time_to_next"]]
            .agg({self.user_col: ["count"], "time_to_next": ["sum"]})
            .reset_index()
            .rename(columns={self.user_col: "usr_cnt", "time_to_next": "time_to_next_sum"})
        )
        data_grp_links.columns = data_grp_links.columns.droplevel(1)
        data_grp_links = data_grp_links.merge(
            data_grp_nodes[["step", self.event_col, "index"]],
            how="inner",
            on=["step", self.event_col],
        )
        data_grp_links.loc[:, "next_step"] = data_grp_links["step"] + 1
        data_grp_links = data_grp_links.merge(
            data_grp_nodes[["step", self.event_col, "index"]].rename(
                columns={"step": "next_step", self.event_col: "next_event", "index": "next_index"}
            ),
            how="inner",
            on=["next_step", "next_event"],
        )
        data_grp_links.sort_values(by=["index", "usr_cnt"], ascending=[True, False], inplace=True)
        data_grp_links.reset_index(drop=True, inplace=True)
        # NOTE generating links plot dict
        data_for_plot.update({"links_dict": dict()})
        for index in data_grp_links["index"].unique():
            for next_index in data_grp_links[data_grp_links["index"] == index]["next_index"].unique():
                _unique_users, _avg_time_to_next = (
                    data_grp_links.loc[
                        (data_grp_links["index"] == index) & (data_grp_links["next_index"] == next_index),
                        ["usr_cnt", "time_to_next_sum"],
                    ]
                    .to_numpy()
                    .T
                )
                if index in data_for_plot["links_dict"]:
                    if next_index in data_for_plot["links_dict"][index]:
                        data_for_plot["links_dict"][index][next_index]["unique_users"] = _unique_users[0]
                        data_for_plot["links_dict"][index][next_index]["avg_time_to_next"] = np.timedelta64(
                            _avg_time_to_next[0]
                        )
                    else:
                        data_for_plot["links_dict"][index].update(
                            {
                                next_index: {
                                    "unique_users": _unique_users[0],
                                    "avg_time_to_next": np.timedelta64(_avg_time_to_next[0]),
                                }
                            }
                        )
                else:
                    data_for_plot["links_dict"].update(
                        {
                            index: {
                                next_index: {
                                    "unique_users": _unique_users[0],
                                    "avg_time_to_next": np.timedelta64(_avg_time_to_next[0]),
                                }
                            }
                        }
                    )
        return data_for_plot, data_grp_links
    def _get_nodes(self, data: pd.DataFrame) -> tuple[dict, pd.DataFrame]:
        all_events = list(data[self.event_col].unique())
        palette = self._prepare_palette(all_events)
        # NOTE create nodes aggregate dataframe
        data_grp_nodes = (
            data.groupby(by=["step", self.event_col])[self.user_col]
            .nunique()
            .reset_index()
            .rename(columns={self.user_col: "usr_cnt"})
        )
        data_grp_nodes.loc[:, "usr_cnt_total"] = data_grp_nodes.groupby(by=["step"])["usr_cnt"].transform("sum")
        data_grp_nodes.loc[:, "perc"] = np.round(
            (data_grp_nodes.loc[:, "usr_cnt"] / data_grp_nodes.loc[:, "usr_cnt_total"]) * 100, 2
        )
        data_grp_nodes.sort_values(
            by=["step", "usr_cnt", self.event_col],
            ascending=[True, False, True],
            inplace=True,
        )
        data_grp_nodes.reset_index(
            drop=True,
            inplace=True,
        )
        data_grp_nodes.loc[:, "color"] = data_grp_nodes[self.event_col].apply(
            lambda x: self._make_color(x, all_events, palette)
        )
        data_grp_nodes.loc[:, "index"] = data_grp_nodes.index  # type: ignore
        # NOTE doing right ranking
        if self.sorting is None:
            data_grp_nodes.loc[:, "sorting"] = 100
        else:
            for n, s in enumerate(self.sorting):
                data_grp_nodes.loc[data_grp_nodes[self.event_col] == s, "sorting"] = n
            data_grp_nodes.loc[:, "sorting"].fillna(100, inplace=True)
        # NOTE placing ENDED at the end
        data_grp_nodes.loc[data_grp_nodes[self.event_col] == "ENDED", "sorting"] = 101
        # NOTE using custom ordering
        data_grp_nodes.loc[:, "sorting"] = data_grp_nodes.loc[:, "sorting"].astype(int)
        # @TODO: step variable is not used inside the loop. The loop might be invalid. Vladimir Kukushkin
        # NOTE doing loop for valid ranking
        for step in data_grp_nodes["step"].unique():
            # NOTE saving last level order
            data_grp_nodes.loc[:, "order_by"] = (
                data_grp_nodes.groupby(by=[self.event_col])["index"].transform("shift").fillna(100).astype(int)
            )
            # NOTE placing ENDED events at the end
            data_grp_nodes.loc[data_grp_nodes[self.event_col] == "ENDED", "sorting"] = 101
            # NOTE creating new indexes
            data_grp_nodes.sort_values(
                by=["step", "sorting", "order_by", "usr_cnt", self.event_col],
                ascending=[True, True, True, False, True],
                inplace=True,
            )
            data_grp_nodes.reset_index(
                drop=True,
                inplace=True,
            )
            data_grp_nodes.loc[:, "index"] = data_grp_nodes.index  # type: ignore
        # NOTE generating nodes plot dict
        data_for_plot: Dict[str, Any] = dict()
        data_for_plot.update({"nodes_dict": dict()})
        for step in data_grp_nodes["step"].unique():
            data_for_plot["nodes_dict"].update({step: dict()})
            _sources, _color, _sources_index, _percs = (
                data_grp_nodes.loc[data_grp_nodes["step"] == step, [self.event_col, "color", "index", "perc"]]
                .to_numpy()
                .T
            )
            data_for_plot["nodes_dict"][step].update(
                {
                    "sources": list(_sources),
                    "color": list(_color),
                    "sources_index": list(_sources_index),
                    "percs": list(_percs),
                }
            )
        return data_for_plot, data_grp_nodes
    @staticmethod
    def _prepare_palette(all_events: list) -> list[tuple]:
        # NOTE default color palette
        palette_hex = ["50BE97", "E4655C", "FCC865", "BFD6DE", "3E5066", "353A3E", "E6E6E6"]
        # NOTE convert HEX to RGB
        palette = []
        for color in palette_hex:
            rgb_color = tuple(int(color[i : i + 2], 16) for i in (0, 2, 4))
            palette.append(rgb_color)
        # NOTE extend color palette if number of events more than default colors list
        complementary_palette = sns.color_palette("deep", len(all_events) - len(palette))
        if len(complementary_palette) > 0:
            colors = complementary_palette.as_hex()
            for c in colors:
                col = c[1:]
                palette.append(tuple(int(col[i : i + 2], 16) for i in (0, 2, 4)))
        return palette
    def _get_next_event_and_timedelta(self, data: pd.DataFrame) -> pd.DataFrame:
        grouped = data.groupby(self.user_col)
        data["next_event"] = grouped[self.event_col].shift(-1)
        data["next_timestamp"] = grouped[self.time_col].shift(-1)
        data["time_to_next"] = data["next_timestamp"] - data[self.time_col]
        data = data.drop("next_timestamp", axis=1)
        return data
[docs]    @time_performance(scope="step_sankey", event_name="fit")
    def fit(
        self,
        max_steps: int = 10,
        threshold: int | float = 0.05,
        sorting: list | None = None,
        targets: list[str] | str | None = None,
    ) -> None:
        """
        Calculate the sankey diagram internal values with the defined parameters.
        Applying ``fit`` method is necessary for the following usage
        of any visualization or descriptive ``StepSankey`` methods.
        Parameters
        ----------
        max_steps : int, default 10
            Maximum number of steps in trajectories to include. Should be > 1.
        threshold : float | int, default 0.05
            Used to remove rare events from the plot. An event is collapsed to ``thresholded_N`` artificial event if
            its maximum frequency across all the steps is less than or equal to ``threshold``. The frequency is set
            with respect to ``threshold`` type:
            - If ``int`` - the frequency is the number of unique users who had given event at given step.
            - If ``float`` - percentage of users: the same as for ``int``, but divided by the number of unique users.
            The events which are prohibited for collapsing could be enlisted in ``target`` parameter.
        sorting : list of str, optional
            Define the order of the events visualized at each step. The events that are not represented in the list
            will follow after the events from the list.
        targets : str or list of str, optional
            Contain events that are prohibited for collapsing with ``threshold`` parameter.
        Raises
        ------
        ValueError
            If ``max_steps`` parameter is <= 1.
        """
        data = self.__eventstream.to_dataframe(copy=True)[
            [self.user_col, self.event_col, self.time_col, self.event_index_col]
        ]
        if max_steps <= 1:
            raise ValueError("max_steps parameter must be > 1!")
        called_params = {
            "max_steps": max_steps,
            "threshold": threshold,
            "sorting": sorting,
            "targets": targets,
        }
        self.max_steps = max_steps
        self.threshold = threshold
        self.sorting = sorting
        self.targets = targets
        self.data = self._prepare_data(data)
        data_for_plot, self.data_grp_nodes = self._get_nodes(self.data)
        self.data_for_plot, self.data_grp_links = self._get_links(self.data, data_for_plot, self.data_grp_nodes)
        performance_info = {
            "nodes_count": len(self.data_grp_nodes),
            "links_count": len(self.data_grp_links),
        }
        collect_data_performance(
            scope="step_sankey",
            event_name="metadata",
            called_params=called_params,
            performance_data=performance_info,
            eventstream_index=self.__eventstream._eventstream_index,
        ) 
[docs]    @time_performance(
        scope="step_sankey",
        event_name="plot",
    )
    def plot(self, autosize: bool = True, width: int | None = None, height: int | None = None) -> go.Figure:
        """
        Create a Sankey interactive plot based on the calculated values.
        Should be used after :py:func:`fit`.
        Parameters
        ----------
        autosize : bool, default True
            Plotly autosize parameter. See :plotly_autosize:`plotly documentation<>`.
        width : int, optional
            Plot's width (in px). See :plotly_width:`plotly documentation<>`.
        height : int, optional
            Plot's height (in px). See :plotly_height:`plotly documentation<>`.
        Returns
        -------
        plotly.graph_objects.Figure
        """
        called_params = {
            "autosize": autosize,
            "width": width,
            "height": height,
        }
        figure = self._render_plot(
            data_for_plot=self.data_for_plot,
            data_grp_nodes=self.data_grp_nodes,
            autosize=autosize,
            width=width,
            height=height,
        )
        performance_info = {
            "nodes_count": len(self.data_grp_nodes),
            "links_count": len(self.data_grp_links),
        }
        collect_data_performance(
            scope="step_sankey",
            event_name="metadata",
            called_params=called_params,
            performance_data=performance_info,
            eventstream_index=self.__eventstream._eventstream_index,
        )
        return figure 
    @property
    @time_performance(
        scope="step_sankey",
        event_name="values",
    )
    def values(self) -> tuple[pd.DataFrame, pd.DataFrame]:
        """
        Returns two pd.DataFrames which the Sankey diagram is based on.
        Should be used after :py:func:`fit`.
        Returns
        -------
        tuple[pd.DataFrame, pd.DataFrame]
            1. Contains the nodes of the diagram.
            2. Contains the edges of the diagram.
        """
        return self.data_grp_nodes, self.data_grp_links
    @property
    @time_performance(
        scope="step_sankey",
        event_name="params",
    )
    def params(self) -> dict:
        """
        Returns the parameters used for the last fitting.
        Should be used after :py:func:`fit`.
        """
        return {
            "max_steps": self.max_steps,
            "threshold": self.threshold,
            "sorting": self.sorting,
            "targets": self.targets,
        }