Source code for retentioneering.preprocessing_graph.nodes

from __future__ import annotations

import uuid
from typing import Any, Optional, Type, Union

from retentioneering.data_processor import DataProcessor
from retentioneering.data_processor.registry import dataprocessor_registry
from retentioneering.eventstream.types import EventstreamType
from retentioneering.params_model.registry import params_model_registry


class BaseNode:
    processor: Optional[DataProcessor]
    events: Optional[EventstreamType]
    pk: str
    description: Optional[str]

    def __init__(self, **kwargs: Any) -> None:
        self.pk = str(uuid.uuid4())

    def __str__(self) -> str:
        data = {"name": self.__class__.__name__, "pk": self.pk}
        return str(data)

    __repr__ = __str__

    def export(self) -> dict:
        data: dict[str, Any] = {"name": self.__class__.__name__, "pk": self.pk}
        if self.description:
            data["description"] = self.description

        if processor := getattr(self, "processor", None):
            data["processor"] = processor.to_dict()
        return data

    def copy(self) -> BaseNode:
        return BaseNode()


class SourceNode(BaseNode):
    events: EventstreamType
    description: Optional[str]

    def __init__(self, source: EventstreamType, description: Optional[str] = None) -> None:
        super().__init__()
        self.events = source
        self.description = description

    def __copy__(self) -> SourceNode:
        return SourceNode(
            source=self.events.copy(),
            description=self.description,
        )

    def copy(self) -> SourceNode:
        return self.__copy__()


[docs]class EventsNode(BaseNode): """ Class for regular nodes of a PreprocessingGraph. Notes ----- See :doc:`Preprocessing user guide</user_guides/preprocessing>` for the details. See Also -------- .PreprocessingGraph.add_node : Add a node to PreprocessingGraph. .PreprocessingGraph.combine : Run calculations of PreprocessingGraph. .MergeNode : Merging nodes of a PreprocessingGraph. """ processor: DataProcessor events: Optional[EventstreamType] description: Optional[str] def __init__(self, processor: DataProcessor, description: Optional[str] = None) -> None: super().__init__() self.processor = processor self.events = None self.description = description def __copy__(self) -> EventsNode: return EventsNode( processor=self.processor.copy(), description=self.description, ) def copy(self) -> EventsNode: return self.__copy__()
[docs]class MergeNode(BaseNode): """ Class for merging nodes of a PreprocessingGraph. Notes ----- See :doc:`Preprocessing user guide</user_guides/preprocessing>` for the details. See Also -------- .PreprocessingGraph.add_node : Add a node to PreprocessingGraph. .PreprocessingGraph.combine : Run calculations of PreprocessingGraph. .EventsNode : Regular nodes of a PreprocessingGraph. """ events: Optional[EventstreamType] description: Optional[str] def __init__(self, description: Optional[str] = None) -> None: super().__init__() self.events = None self.description = description def __copy__(self) -> MergeNode: return MergeNode( description=self.description, ) def copy(self) -> MergeNode: return self.__copy__()
Node = Union[SourceNode, EventsNode, MergeNode] nodes = { "MergeNode": MergeNode, "EventsNode": EventsNode, "SourceNode": SourceNode, } class NotFoundDataprocessor(Exception): pass def build_node( source_stream: EventstreamType, pk: str, node_name: str, processor_name: str | None = None, processor_params: dict[str, Any] | None = None, descriptionn: Optional[str] = None, ) -> Node: _node = nodes[node_name] node_kwargs = {} if node_name == "SourceNode": node_kwargs["source"] = source_stream if not processor_params: processor_params = {} if processor_name and node_name == "EventsNode": _params_model_registry = params_model_registry.get_registry() _dataprocessor_registry = dataprocessor_registry.get_registry() _processor: Type[DataProcessor] = _dataprocessor_registry[processor_name] # type: ignore params_name = _processor.__annotations__["params"] _params_model = _params_model_registry[params_name] if type(params_name) is str else params_name params_model = _params_model(**processor_params) processor: DataProcessor = _processor(params=params_model) node_kwargs["processor"] = processor # type: ignore node = _node(**node_kwargs) node.pk = pk node.description = descriptionn return node