Source code for retentioneering.data_processors_lib.filter_events

from inspect import signature
from typing import Callable, Optional

import pandas as pd
from pandas import DataFrame, Series

from retentioneering.backend.tracker import (
    collect_data_performance,
    time_performance,
    track,
)
from retentioneering.data_processor import DataProcessor
from retentioneering.eventstream.schema import EventstreamSchema
from retentioneering.eventstream.types import EventstreamSchemaType
from retentioneering.params_model import ParamsModel
from retentioneering.utils.doc_substitution import docstrings
from retentioneering.utils.hash_object import hash_dataframe
from retentioneering.widget.widgets import ReteFunction


class FilterEventsParams(ParamsModel):
    """
    Parameter model for the :py:class:`.FilterEvents` data processor.

    Attributes
    ----------
    func : Callable[[DataFrame, Optional[EventstreamSchema]], Series]
        Filtering function applied to the events dataframe. It is expected
        to return a boolean mask (``Series``) selecting the rows to keep.
    """

    # The user-supplied filtering predicate.
    func: Callable[[DataFrame, Optional[EventstreamSchema]], Series]

    # Maps each parameter name to the widget class used for it
    # (``ReteFunction`` from ``retentioneering.widget.widgets``).
    _widgets = dict(func=ReteFunction())
@docstrings.get_sections(base="FilterEvents")  # type: ignore
class FilterEvents(DataProcessor):
    """
    Filters input ``eventstream`` on the basis of custom conditions.

    Parameters
    ----------
    func : Callable[[DataFrame, Optional[EventstreamSchema]], Series]
        Custom function that returns boolean mask the same length as input ``eventstream``.
        If ``func`` declares exactly one parameter it is called as ``func(df)``,
        otherwise it is called as ``func(df, schema)``.

        - If ``True`` - the row will be left in the eventstream.
        - If ``False`` - the row will be deleted from the eventstream.

    Returns
    -------
    Eventstream
        ``Eventstream`` containing only the events for which ``func`` returned ``True``.

    Notes
    -----
    See :doc:`Data processors user guide</user_guides/dataprocessors>` for the details.
    """

    params: FilterEventsParams

    @time_performance(
        scope="filter_events",
        event_name="init",
    )
    def __init__(self, params: FilterEventsParams):
        super().__init__(params=params)

    @time_performance(
        scope="filter_events",
        event_name="apply",
    )
    def apply(self, df: DataFrame, schema: EventstreamSchemaType) -> DataFrame:
        """
        Keep only the rows of ``df`` selected by ``self.params.func``.

        Parameters
        ----------
        df : DataFrame
            Events dataframe to filter.
        schema : EventstreamSchemaType
            Eventstream schema; passed to ``func`` only when ``func`` declares
            more than one parameter.

        Returns
        -------
        DataFrame
            ``df[mask]``, where ``mask`` is the value returned by ``func``.
            Rows where the mask is ``True`` are kept.
        """
        func: Callable[[DataFrame, Optional[EventstreamSchemaType]], Series] = self.params.func  # type: ignore

        # Support both ``func(df)`` and ``func(df, schema)`` call shapes,
        # chosen by the number of declared parameters.
        expected_args_count = len(signature(func).parameters)
        if expected_args_count == 1:
            mask = func(df)  # type: ignore
        else:
            mask = func(df, schema)

        result = df[mask]

        # Report input/output shapes and hashes for performance tracking.
        collect_data_performance(
            scope="filter_events",
            event_name="metadata",
            called_params=self.to_dict()["values"],
            performance_data={
                "parent": {
                    "shape": df.shape,
                    "hash": hash_dataframe(df),
                },
                "child": {
                    "shape": result.shape,
                    "hash": hash_dataframe(result),
                },
            },
        )
        return result