# Source code for retentioneering.eventstream.helpers.group_events_bulk_helper
from __future__ import annotations
from typing import Callable, Dict, List, Optional
from pandas import DataFrame, Series
from typing_extensions import (  # required for pydantic and python < 3.9.2
    NotRequired,
    Required,
    TypedDict,
)
from retentioneering.backend.tracker import (
    collect_data_performance,
    time_performance,
    track,
)
from retentioneering.utils.doc_substitution import docstrings
from ..types import EventstreamSchemaType, EventstreamType
# A filter callable: receives a DataFrame and an optional eventstream schema
# and returns a DataFrame.
# NOTE(review): presumably the returned frame is the subset of rows to be
# grouped -- confirm against GroupEventsBulk.
EventstreamFilter = Callable[[DataFrame, Optional[EventstreamSchemaType]], DataFrame]
# Dict form of the grouping rules: string keys mapped to filter callables.
# NOTE(review): keys look like the resulting event names -- confirm.
GroupingRulesDict = Dict[str, EventstreamFilter]
class GroupEventsRule(TypedDict, total=False):
    """
    Typed dictionary describing one entry of the list form of ``grouping_rules``.

    ``event_name`` and ``func`` are marked ``Required``; ``event_type`` is
    marked ``NotRequired``.
    """

    # Required string key.
    # NOTE(review): presumably the name assigned to the grouped event -- confirm.
    event_name: Required[str]
    # Required filter callable, see ``EventstreamFilter`` for its signature.
    func: Required[EventstreamFilter]
    # Optional string key.
    # NOTE(review): presumably the event type of the grouped event -- confirm.
    event_type: NotRequired[str]
class GroupEventsBulkHelperMixin:
    """Mixin providing the ``group_events_bulk`` helper method."""

    @docstrings.with_indent(12)
    @time_performance(  # type: ignore
        scope="group_events_bulk",
        event_name="helper",
        event_value="combine",
    )
    def group_events_bulk(
        self: EventstreamType,
        grouping_rules: List[GroupEventsRule] | GroupingRulesDict,
        ignore_intersections: bool = False,
    ) -> EventstreamType:
        """
        Apply multiple grouping rules simultaneously.
        See also :py:meth:`GroupEvents<retentioneering.data_processors_lib.group_events.GroupEvents>`

        Parameters
        ----------
            %(GroupEventsBulk.parameters)s

        Returns
        -------
        Eventstream
            ``Eventstream`` with the grouped events according to the given grouping rules.
        """
        # avoid circular import
        from retentioneering.data_processors_lib import (
            GroupEventsBulk,
            GroupEventsBulkParams,
        )
        from retentioneering.preprocessing_graph import PreprocessingGraph
        from retentioneering.preprocessing_graph.nodes import EventsNode

        params = GroupEventsBulkParams(grouping_rules=grouping_rules, ignore_intersections=ignore_intersections)  # type: ignore
        # Snapshot of the parameters, reported to the tracker below.
        calling_params = params.dict()

        # Build a one-node preprocessing graph rooted at this eventstream and
        # combine it to obtain the resulting eventstream.
        graph = PreprocessingGraph(source_stream=self)  # type: ignore
        node = EventsNode(processor=GroupEventsBulk(params=params))
        graph.add_node(node=node, parents=[graph.root])
        result = graph.combine(node)
        # Drop the local reference to the temporary graph; only ``result`` is needed.
        del graph

        # The scope must match the ``time_performance`` decorator above so the
        # metadata is attributed to this helper and not to ``map_events``.
        collect_data_performance(
            scope="group_events_bulk",
            event_name="metadata",
            called_params=calling_params,
            performance_data={},
            eventstream_index=self._eventstream_index,
            parent_eventstream_index=self._eventstream_index,
            child_eventstream_index=result._eventstream_index,
        )

        return result