"""Module retentioneering.data_processors_lib.group_events_bulk: the GroupEventsBulk data processor."""
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import pandas as pd
from pydantic.dataclasses import dataclass
from retentioneering.backend.tracker import collect_data_performance, time_performance
from retentioneering.data_processor import DataProcessor
from retentioneering.eventstream.types import EventstreamSchemaType
from retentioneering.params_model import ParamsModel
from retentioneering.utils.doc_substitution import docstrings
from retentioneering.utils.hash_object import hash_dataframe
from retentioneering.widget.widgets import EnumWidget
EventstreamFilter = Callable[[pd.DataFrame, Optional[EventstreamSchemaType]], Any]
GroupingRulesDict = Dict[str, EventstreamFilter]
@dataclass
class GroupEventsRule:
    """
    A single grouping rule consumed by :py:class:`GroupEventsBulk`.

    Attributes
    ----------
    event_name : str
        Event name written to every row selected by ``func``.
    func : EventstreamFilter
        Filter called with the eventstream dataframe, or with ``(dataframe, schema)`` when it accepts
        more than one argument. Its result is used as a row mask.
    event_type : str, optional
        Event type written to the selected rows. When it is ``None`` or empty, ``GroupEventsBulk``
        writes ``"group_alias"`` instead.
    """

    event_name: str
    func: EventstreamFilter
    event_type: Optional[str] = None
class GroupEventsBulkParams(ParamsModel):
    """
    Parameter model for :py:class:`GroupEventsBulk`.

    ``grouping_rules`` accepts either a list of :py:class:`GroupEventsRule` or a mapping
    ``{event_name: func}``. ``ignore_intersections`` switches off the check that rejects rows
    matched by more than one rule.
    """

    grouping_rules: Union[List[GroupEventsRule], GroupingRulesDict]
    ignore_intersections: bool = False

    # TODO: EnumWidget is only a stub editor for grouping_rules; replace it later.
    _widgets = dict(
        grouping_rules=EnumWidget(),
    )
def combine_masks(masks: List[pd.Series]) -> pd.Series:
    """
    Flag the rows selected by more than one mask.

    Parameters
    ----------
    masks : list of pd.Series or array-like
        Boolean row masks of equal length. NumPy boolean arrays are accepted as well as Series.

    Returns
    -------
    pd.Series
        Boolean Series that is ``True`` where at least two masks are ``True``. It uses the index of the
        first mask when that mask has one, and a default index otherwise. An empty list yields an empty
        boolean Series, whose ``.any()`` is ``False``.
    """
    if not masks:
        # Nothing to intersect; avoid indexing into an empty list.
        return pd.Series(dtype=bool)

    # Count, per row, how many masks select it; np.asarray handles Series and plain arrays alike.
    hit_counts = np.sum([np.asarray(mask) for mask in masks], axis=0)
    index = getattr(masks[0], "index", None)
    return pd.Series(hit_counts > 1, index=index)
@docstrings.get_sections(base="GroupEventsBulk")  # type: ignore
class GroupEventsBulk(DataProcessor):
    """
    Apply multiple grouping rules simultaneously.
    See also :py:meth:`GroupEvents<retentioneering.data_processors_lib.group_events.GroupEvents>`

    Parameters
    ----------
    grouping_rules : list or dict
        - If list, each list element is a dictionary with mandatory keys ``event_name`` and ``func`` and an
          optional key ``event_type``. Their meaning is the same as for
          :py:meth:`GroupEvents<retentioneering.data_processors_lib.group_events.GroupEvents>`.
        - If dict, the keys are considered as ``event_name``, values are considered as ``func``.
          Setting ``event_type`` is not supported in this case.
    ignore_intersections : bool, default False
        If ``False``, a ``ValueError`` is raised in case any event from the input eventstream matches
        more than one grouping rule. Otherwise, the first appropriate rule from ``grouping_rules`` is applied.

    Returns
    -------
    Eventstream
        ``Eventstream`` with the grouped events according to the given grouping rules.
    """

    params: GroupEventsBulkParams

    @time_performance(
        scope="group_events_bulk",
        event_name="init",
    )
    def __init__(self, params: GroupEventsBulkParams) -> None:
        super().__init__(params=params)

    @staticmethod
    def _evaluate_mask(func: Callable, df: pd.DataFrame, schema: EventstreamSchemaType) -> Any:
        """Call a grouping ``func`` as ``func(df)`` if it takes one parameter, else as ``func(df, schema)``."""
        if len(signature(func).parameters) == 1:
            return func(df)  # type: ignore
        return func(df, schema)

    @time_performance(
        scope="group_events_bulk",
        event_name="apply",
    )
    def apply(self, df: pd.DataFrame, schema: EventstreamSchemaType) -> pd.DataFrame:
        """
        Rename the events selected by each grouping rule.

        Every rule's mask is computed on the not-yet-modified ``df``, so no rule sees another rule's
        renames. With ``ignore_intersections=False`` the masks are checked for overlaps before ``df``
        is touched. With ``ignore_intersections=True`` a row selected by several rules ends up with the
        name and type of the first such rule.

        Parameters
        ----------
        df : pd.DataFrame
            Eventstream data. It is modified in place.
        schema : EventstreamSchemaType
            Supplies the ``event_name`` and ``event_type`` column names.

        Returns
        -------
        pd.DataFrame
            The same ``df`` object, with grouped events renamed.

        Raises
        ------
        ValueError
            If ``ignore_intersections`` is ``False`` and a row is selected by more than one rule.
        """
        rules = self.params.grouping_rules
        ignore_intersections = self.params.ignore_intersections

        if isinstance(rules, dict):
            # Dict form: keys are event names, values are filters; event_type cannot be set here.
            rules = [GroupEventsRule(event_name=key, func=val) for key, val in rules.items()]  # type: ignore

        parent_info = {
            "shape": df.shape,
            "hash": hash_dataframe(df),
        }

        # Evaluate every filter once, on the untouched data. Renaming while later masks are still being
        # computed would let later rules match (and overwrite) rows already renamed by earlier rules.
        evaluated = []
        for rule in rules:
            event_type = rule.event_type if rule.event_type else "group_alias"
            mask = self._evaluate_mask(rule.func, df, schema)
            evaluated.append((rule.event_name, event_type, mask))

        # Validate before mutating so a failing check leaves df unchanged.
        # A single mask can never intersect with itself, hence the length guard.
        if not ignore_intersections and len(evaluated) > 1:
            intersection_mask = combine_masks([mask for _, _, mask in evaluated])
            if intersection_mask.any():
                raise ValueError(
                    "GroupEventsBulk Dataprocessor error. Mapping rules are intersected. Use ignore_intersections=True or fix the intersections"
                )

        # Apply rules last-to-first: for rows selected by several rules, the first rule is written last
        # and wins. When masks are disjoint (strict mode) the order has no effect.
        with pd.option_context("mode.chained_assignment", None):
            for event_name, event_type, mask in reversed(evaluated):
                df.loc[mask, schema.event_type] = event_type
                df.loc[mask, schema.event_name] = event_name

        collect_data_performance(
            scope="group_events_bulk",
            event_name="metadata",
            called_params=self.to_dict()["values"],
            performance_data={
                "parent": parent_info,
                "child": {
                    "shape": df.shape,
                    "hash": hash_dataframe(df),
                },
            },
        )
        return df