Skip to content

Splitting - API Reference

warprec.data.splitting.splitter.Splitter

Splitter class will handle the splitting of the data.

Source code in warprec/data/splitting/splitter.py
class Splitter:
    """Split transaction data into train, validation and test sets.

    The class is stateless: every method receives the data and the
    splitting parameters explicitly. ``split_transaction`` is the entry
    point, ``process_split`` dispatches to the strategy registered for a
    given ``SplittingStrategies`` value, and ``filter_sets`` aligns an
    evaluation set with the users and items of a training set.
    """

    def split_transaction(
        self,
        data: FrameT,
        user_id_label: str = "user_id",
        item_id_label: str = "item_id",
        rating_label: str = "rating",
        timestamp_label: str = "timestamp",
        test_strategy: Optional[SplittingStrategies | str] = None,
        test_ratio: Optional[float] = None,
        test_k: Optional[int] = None,
        test_folds: Optional[int] = None,
        test_timestamp: Optional[Union[int, str]] = None,
        test_seed: int = 42,
        val_strategy: Optional[SplittingStrategies | str] = None,
        val_ratio: Optional[float] = None,
        val_k: Optional[int] = None,
        val_folds: Optional[int] = None,
        val_timestamp: Optional[Union[int, str]] = None,
        val_seed: int = 42,
    ) -> Tuple[
        DataFrame[Any],
        Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]],
        DataFrame[Any],
    ]:
        """Split transaction data into train, optional validation and test sets.

        The test split is computed first on the full data; only the first
        split returned by the test strategy is used. When ``val_strategy``
        is given, the resulting train partition is split again to obtain
        one or more train/validation folds. Every evaluation set that is
        returned is filtered so that it only contains users and items that
        appear in the train set it is paired with.

        A transaction is defined by at least a user_id and an item_id.

        Args:
            data (FrameT): The DataFrame to be split.
            user_id_label (str): The user_id label.
            item_id_label (str): The item_id label.
            rating_label (str): The rating label.
            timestamp_label (str): The timestamp label.
            test_strategy (Optional[SplittingStrategies | str]): The splitting strategy to use for test set.
                Required, despite the ``None`` default.
            test_ratio (Optional[float]): The ratio value for test set.
            test_k (Optional[int]): The k value for test set.
            test_folds (Optional[int]): The folds value for test set.
            test_timestamp (Optional[Union[int, str]]): The timestamp to be used for the test set.
                Either an integer or 'best'.
            test_seed (int): The seed value for test set. Defaults to 42.
            val_strategy (Optional[SplittingStrategies | str]): The splitting strategy to use for validation set.
                When None, no validation split is performed.
            val_ratio (Optional[float]): The ratio value for validation set.
            val_k (Optional[int]): The k value for validation set.
            val_folds (Optional[int]): The folds value for validation set.
            val_timestamp (Optional[Union[int, str]]): The timestamp to be used for the validation set.
                Either an integer or 'best'.
            val_seed (int): The seed value for validation set. Defaults to 42.

        Returns:
            Tuple[DataFrame[Any], Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]], DataFrame[Any]]:
                - DataFrame[Any]: The train data. With no validation split or
                    with several folds this is the train partition of the
                    test split; with exactly one validation fold it is the
                    train partition of that fold.
                - Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]]:
                    None when no validation split was requested, a single
                    validation DataFrame when exactly one fold was produced,
                    otherwise a list of (train, validation) tuples.
                - DataFrame[Any]: The unique test data, used at the end of
                    the experiment to evaluate the model.

        Raises:
            ValueError: If ``test_strategy`` is None.
        """
        data = nw.from_native(data, pass_through=True)

        # Parse strings
        if isinstance(test_strategy, str):
            test_strategy = SplittingStrategies(test_strategy)

        if isinstance(val_strategy, str):
            val_strategy = SplittingStrategies(val_strategy)

        # The test split is mandatory; fail with an explicit message instead
        # of an AttributeError on ``None.value`` further down.
        if test_strategy is None:
            raise ValueError(
                "A test splitting strategy is required: 'test_strategy' cannot be None."
            )

        # Test set
        split_process_start_time = time.time()
        logger.msg(
            f"Starting test splitting process with {test_strategy.value} splitting strategy."
        )
        test_split_time_start = time.time()
        # Only the first (train, test) pair of the strategy output is used.
        original_train_set, test_set = self.process_split(
            data,
            test_strategy,
            user_id_label=user_id_label,
            item_id_label=item_id_label,
            rating_label=rating_label,
            timestamp_label=timestamp_label,
            ratio=test_ratio,
            k=test_k,
            folds=test_folds,
            timestamp=test_timestamp,
            seed=test_seed,
        )[0]
        test_split_time = time.time() - test_split_time_start
        logger.msg(f"Test splitting completed in : {test_split_time:.2f}s")

        # Optional validation folding
        validation_folds: List[Tuple[DataFrame[Any], Optional[DataFrame[Any]]]] = []
        if val_strategy is not None:
            logger.msg(
                f"Starting validation splitting process with {val_strategy.value} splitting strategy."
            )
            validation_split_time_start = time.time()
            folds = self.process_split(
                original_train_set,
                val_strategy,
                user_id_label=user_id_label,
                item_id_label=item_id_label,
                rating_label=rating_label,
                timestamp_label=timestamp_label,
                ratio=val_ratio,
                k=val_k,
                folds=val_folds,
                timestamp=val_timestamp,
                seed=val_seed,
            )
            validation_folds.extend((train, validation) for train, validation in folds)
            validation_split_time = time.time() - validation_split_time_start
            logger.msg(
                f"Validation splitting completed in : {validation_split_time:.2f}s"
            )

        # Logging of splitting process
        split_process_time = time.time() - split_process_start_time
        logger.positive(f"Splitting process over in {split_process_time:.2f}s.")

        # Filter out the test set
        test_set = self.filter_sets(
            original_train_set, test_set, user_id_label, item_id_label, "Test"
        )

        if not validation_folds:
            # CASE 1: Only train and test set
            return (original_train_set, None, test_set)

        if len(validation_folds) == 1:
            # CASE 2: Train/Validation/Test
            train_set, validation_set = validation_folds[0]
            # The returned train set is the reduced one, so both evaluation
            # sets must be aligned with it.
            if validation_set is not None:
                validation_set = self.filter_sets(
                    train_set,
                    validation_set,
                    user_id_label,
                    item_id_label,
                    "Validation",
                )
            test_set = self.filter_sets(
                train_set, test_set, user_id_label, item_id_label, "Test"
            )
            return (train_set, validation_set, test_set)

        # Filter out each validation set based on the corresponding train
        # set. A new list is built because rebinding the loop variable
        # would leave the stored folds unfiltered.
        filtered_folds: List[Tuple[DataFrame[Any], Optional[DataFrame[Any]]]] = []
        for train, validation in validation_folds:
            if validation is not None:
                validation = self.filter_sets(
                    train, validation, user_id_label, item_id_label, "Validation"
                )
            filtered_folds.append((train, validation))

        # CASE 3: N folds of train and validation + the test set
        return (original_train_set, filtered_folds, test_set)

    def process_split(
        self,
        data: FrameT,
        strategy: SplittingStrategies,
        user_id_label: str = "user_id",
        item_id_label: str = "item_id",
        rating_label: str = "rating",
        timestamp_label: str = "timestamp",
        ratio: Optional[float] = None,
        k: Optional[int] = None,
        folds: Optional[int] = None,
        timestamp: Optional[Union[int, str]] = None,
        seed: int = 42,
    ) -> List[Tuple[DataFrame[Any], DataFrame[Any]]]:
        """Process the splitting based on the selected strategy.

        The strategy is looked up in ``splitting_registry`` by
        ``strategy.value`` and called with the data and every splitting
        parameter as keyword arguments.

        Args:
            data (FrameT): The DataFrame to be split.
            strategy (SplittingStrategies): The splitting strategy to use.
            user_id_label (str): The user_id label.
            item_id_label (str): The item_id label.
            rating_label (str): The rating label.
            timestamp_label (str): The timestamp label.
            ratio (Optional[float]): The ratio value.
            k (Optional[int]): The k value.
            folds (Optional[int]): The folds value.
            timestamp (Optional[Union[int, str]]): The timestamp to be used for the splitting.
                Either an integer or 'best'.
            seed (int): The seed value. Defaults to 42.

        Returns:
            List[Tuple[DataFrame[Any], DataFrame[Any]]]: A list of tuples containing the train and evaluation sets.
        """
        splitting_strategy = splitting_registry.get(strategy.value)
        split = splitting_strategy(
            data,
            user_id_label=user_id_label,
            item_id_label=item_id_label,
            rating_label=rating_label,
            timestamp_label=timestamp_label,
            ratio=ratio,
            k=k,
            folds=folds,
            timestamp=timestamp,
            seed=seed,
        )
        return split

    def filter_sets(
        self,
        train_set: DataFrame[Any],
        evaluation_set: DataFrame[Any],
        user_id_label: str = "user_id",
        item_id_label: str = "item_id",
        eval_set_name: Optional[str] = None,
    ) -> DataFrame[Any]:
        """Filter the evaluation set based on the train set.

        Only the evaluation transactions whose user and item both appear
        in the train set are kept. A message is logged when at least one
        transaction is removed.

        Args:
            train_set (DataFrame[Any]): The training set.
            evaluation_set (DataFrame[Any]): The evaluation set to be filtered.
            user_id_label (str): The user ID label.
            item_id_label (str): The item ID label.
            eval_set_name (Optional[str]): The name of the evaluation set.
                Used for logging purposes; "Eval" is used when None.

        Returns:
            DataFrame[Any]: The filtered evaluation set.
        """
        train_users = train_set.select(user_id_label).unique()
        train_items = train_set.select(item_id_label).unique()

        # Save the evaluation transaction count before filtering
        eval_transaction_count = len(evaluation_set)

        # Inner joins against the unique ids act as membership filters.
        filtered_by_users = evaluation_set.join(
            train_users, on=user_id_label, how="inner"
        )

        filtered_final = filtered_by_users.join(
            train_items, on=item_id_label, how="inner"
        )

        # Log any filtering that happened
        if len(filtered_final) < eval_transaction_count:
            eval_set_name = (
                eval_set_name.capitalize() if eval_set_name is not None else "Eval"
            )
            logger.attention(
                f"{eval_set_name} set was not aligned with the training set. "
                f"Filtered out {eval_transaction_count - len(filtered_final)} transactions."
            )

        return filtered_final

filter_sets(train_set, evaluation_set, user_id_label='user_id', item_id_label='item_id', eval_set_name=None)

Filter the evaluation set based on the train set.

Parameters:

Name Type Description Default
train_set DataFrame[Any]

The training set.

required
evaluation_set DataFrame[Any]

The evaluation set to be filtered.

required
user_id_label str

The user ID label.

'user_id'
item_id_label str

The item ID label.

'item_id'
eval_set_name Optional[str]

The name of the evaluation set. Used for logging purposes.

None

Returns:

Type Description
DataFrame[Any]

DataFrame[Any]: The filtered evaluation set.

Source code in warprec/data/splitting/splitter.py
def filter_sets(
    self,
    train_set: DataFrame[Any],
    evaluation_set: DataFrame[Any],
    user_id_label: str = "user_id",
    item_id_label: str = "item_id",
    eval_set_name: Optional[str] = None,
) -> DataFrame[Any]:
    """Keep only the evaluation transactions known to the train set.

    A transaction survives when both its user and its item appear in
    ``train_set``. A message is logged when at least one transaction is
    removed.

    Args:
        train_set (DataFrame[Any]): The training set.
        evaluation_set (DataFrame[Any]): The evaluation set to be filtered.
        user_id_label (str): The user ID label.
        item_id_label (str): The item ID label.
        eval_set_name (Optional[str]): The name of the evaluation set.
            Used for logging purposes; "Eval" is used when None.

    Returns:
        DataFrame[Any]: The filtered evaluation set.
    """
    # Size of the evaluation set before any row is dropped.
    rows_before = len(evaluation_set)

    known_users = train_set.select(user_id_label).unique()
    known_items = train_set.select(item_id_label).unique()

    # Two chained inner joins: first on users, then on items.
    aligned = evaluation_set.join(known_users, on=user_id_label, how="inner").join(
        known_items, on=item_id_label, how="inner"
    )

    dropped = rows_before - len(aligned)
    if dropped > 0:
        # Report how many transactions were removed.
        if eval_set_name is None:
            eval_set_name = "Eval"
        else:
            eval_set_name = eval_set_name.capitalize()
        logger.attention(
            f"{eval_set_name} set was not aligned with the training set. "
            f"Filtered out {dropped} transactions."
        )

    return aligned

process_split(data, strategy, user_id_label='user_id', item_id_label='item_id', rating_label='rating', timestamp_label='timestamp', ratio=None, k=None, folds=None, timestamp=None, seed=42)

Process the splitting based on the selected strategy.

Parameters:

Name Type Description Default
data FrameT

The DataFrame to be split.

required
strategy SplittingStrategies

The splitting strategy to use.

required
user_id_label str

The user_id label.

'user_id'
item_id_label str

The item_id label.

'item_id'
rating_label str

The rating label.

'rating'
timestamp_label str

The timestamp label.

'timestamp'
ratio Optional[float]

The ratio value.

None
k Optional[int]

The k value.

None
folds Optional[int]

The folds value.

None
timestamp Optional[Union[int, str]]

The timestamp to be used for the splitting. Either an integer or 'best'.

None
seed int

The seed value. Defaults to 42.

42

Returns:

Type Description
List[Tuple[DataFrame[Any], DataFrame[Any]]]

List[Tuple[DataFrame[Any], DataFrame[Any]]]: A list of tuples containing the train and evaluation sets.

Source code in warprec/data/splitting/splitter.py
def process_split(
    self,
    data: FrameT,
    strategy: SplittingStrategies,
    user_id_label: str = "user_id",
    item_id_label: str = "item_id",
    rating_label: str = "rating",
    timestamp_label: str = "timestamp",
    ratio: Optional[float] = None,
    k: Optional[int] = None,
    folds: Optional[int] = None,
    timestamp: Optional[Union[int, str]] = None,
    seed: int = 42,
) -> List[Tuple[DataFrame[Any], DataFrame[Any]]]:
    """Run the splitting strategy registered for ``strategy``.

    The strategy is fetched from ``splitting_registry`` using
    ``strategy.value`` and invoked with the data plus every splitting
    parameter passed as a keyword argument.

    Args:
        data (FrameT): The DataFrame to be split.
        strategy (SplittingStrategies): The splitting strategy to use.
        user_id_label (str): The user_id label.
        item_id_label (str): The item_id label.
        rating_label (str): The rating label.
        timestamp_label (str): The timestamp label.
        ratio (Optional[float]): The ratio value.
        k (Optional[int]): The k value.
        folds (Optional[int]): The folds value.
        timestamp (Optional[Union[int, str]]): The timestamp to be used for the splitting.
            Either an integer or 'best'.
        seed (int): The seed value. Defaults to 42.

    Returns:
        List[Tuple[DataFrame[Any], DataFrame[Any]]]: A list of tuples containing the train and evaluation sets.
    """
    # Column labels and strategy parameters, all forwarded as keyword arguments.
    split_kwargs = {
        "user_id_label": user_id_label,
        "item_id_label": item_id_label,
        "rating_label": rating_label,
        "timestamp_label": timestamp_label,
        "ratio": ratio,
        "k": k,
        "folds": folds,
        "timestamp": timestamp,
        "seed": seed,
    }
    run_strategy = splitting_registry.get(strategy.value)
    return run_strategy(data, **split_kwargs)

split_transaction(data, user_id_label='user_id', item_id_label='item_id', rating_label='rating', timestamp_label='timestamp', test_strategy=None, test_ratio=None, test_k=None, test_folds=None, test_timestamp=None, test_seed=42, val_strategy=None, val_ratio=None, val_k=None, val_folds=None, val_timestamp=None, val_seed=42)

The main method of the class. This method must be called to split the data.

When called, this method will return the splitting calculated by the splitting method selected in the configuration file.

This method accepts transaction data, and will return the DataFrames of split data.

A transaction is defined by at least a user_id and an item_id.

Parameters:

Name Type Description Default
data FrameT

The DataFrame to be split.

required
user_id_label str

The user_id label.

'user_id'
item_id_label str

The item_id label.

'item_id'
rating_label str

The rating label.

'rating'
timestamp_label str

The timestamp label.

'timestamp'
test_strategy Optional[SplittingStrategies | str]

The splitting strategy to use for test set.

None
test_ratio Optional[float]

The ratio value for test set.

None
test_k Optional[int]

The k value for test set.

None
test_folds Optional[int]

The folds value for test set.

None
test_timestamp Optional[Union[int, str]]

The timestamp to be used for the test set. Either an integer or 'best'.

None
test_seed int

The seed value for test set. Defaults to 42.

42
val_strategy Optional[SplittingStrategies | str]

The splitting strategy to use for validation set.

None
val_ratio Optional[float]

The ratio value for validation set.

None
val_k Optional[int]

The k value for validation set.

None
val_folds Optional[int]

The folds value for validation set.

None
val_timestamp Optional[Union[int, str]]

The timestamp to be used for the validation set. Either an integer or 'best'.

None
val_seed int

The seed value for validation set. Defaults to 42.

42

Returns:

Type Description
Tuple[DataFrame[Any], Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]], DataFrame[Any]]

Tuple[DataFrame[Any], Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]], DataFrame[Any]]: A tuple of three elements. (1) DataFrame[Any]: The original train data, used to train the final model of the experiment. (2) Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]]: Either a list of tuples, each holding a DataFrame[Any] with the train data used to train the model and a DataFrame[Any] with the validation data used to evaluate the model during training, or just a single DataFrame representing the validation set. (3) DataFrame[Any]: The unique test data, used at the end of the experiment to evaluate the model.

Source code in warprec/data/splitting/splitter.py
def split_transaction(
    self,
    data: FrameT,
    user_id_label: str = "user_id",
    item_id_label: str = "item_id",
    rating_label: str = "rating",
    timestamp_label: str = "timestamp",
    test_strategy: Optional[SplittingStrategies | str] = None,
    test_ratio: Optional[float] = None,
    test_k: Optional[int] = None,
    test_folds: Optional[int] = None,
    test_timestamp: Optional[Union[int, str]] = None,
    test_seed: int = 42,
    val_strategy: Optional[SplittingStrategies | str] = None,
    val_ratio: Optional[float] = None,
    val_k: Optional[int] = None,
    val_folds: Optional[int] = None,
    val_timestamp: Optional[Union[int, str]] = None,
    val_seed: int = 42,
) -> Tuple[
    DataFrame[Any],
    Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]],
    DataFrame[Any],
]:
    """Split transaction data into train, optional validation and test sets.

    The test split is computed first on the full data; only the first
    split returned by the test strategy is used. When ``val_strategy``
    is given, the resulting train partition is split again to obtain
    one or more train/validation folds. Every evaluation set that is
    returned is filtered so that it only contains users and items that
    appear in the train set it is paired with.

    A transaction is defined by at least a user_id and an item_id.

    Args:
        data (FrameT): The DataFrame to be split.
        user_id_label (str): The user_id label.
        item_id_label (str): The item_id label.
        rating_label (str): The rating label.
        timestamp_label (str): The timestamp label.
        test_strategy (Optional[SplittingStrategies | str]): The splitting strategy to use for test set.
            Required, despite the ``None`` default.
        test_ratio (Optional[float]): The ratio value for test set.
        test_k (Optional[int]): The k value for test set.
        test_folds (Optional[int]): The folds value for test set.
        test_timestamp (Optional[Union[int, str]]): The timestamp to be used for the test set.
            Either an integer or 'best'.
        test_seed (int): The seed value for test set. Defaults to 42.
        val_strategy (Optional[SplittingStrategies | str]): The splitting strategy to use for validation set.
            When None, no validation split is performed.
        val_ratio (Optional[float]): The ratio value for validation set.
        val_k (Optional[int]): The k value for validation set.
        val_folds (Optional[int]): The folds value for validation set.
        val_timestamp (Optional[Union[int, str]]): The timestamp to be used for the validation set.
            Either an integer or 'best'.
        val_seed (int): The seed value for validation set. Defaults to 42.

    Returns:
        Tuple[DataFrame[Any], Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]], DataFrame[Any]]:
            - DataFrame[Any]: The train data. With no validation split or
                with several folds this is the train partition of the
                test split; with exactly one validation fold it is the
                train partition of that fold.
            - Optional[List[Tuple[DataFrame[Any], DataFrame[Any]]] | DataFrame[Any]]:
                None when no validation split was requested, a single
                validation DataFrame when exactly one fold was produced,
                otherwise a list of (train, validation) tuples.
            - DataFrame[Any]: The unique test data, used at the end of
                the experiment to evaluate the model.

    Raises:
        ValueError: If ``test_strategy`` is None.
    """
    data = nw.from_native(data, pass_through=True)

    # Parse strings
    if isinstance(test_strategy, str):
        test_strategy = SplittingStrategies(test_strategy)

    if isinstance(val_strategy, str):
        val_strategy = SplittingStrategies(val_strategy)

    # The test split is mandatory; fail with an explicit message instead
    # of an AttributeError on ``None.value`` further down.
    if test_strategy is None:
        raise ValueError(
            "A test splitting strategy is required: 'test_strategy' cannot be None."
        )

    # Test set
    split_process_start_time = time.time()
    logger.msg(
        f"Starting test splitting process with {test_strategy.value} splitting strategy."
    )
    test_split_time_start = time.time()
    # Only the first (train, test) pair of the strategy output is used.
    original_train_set, test_set = self.process_split(
        data,
        test_strategy,
        user_id_label=user_id_label,
        item_id_label=item_id_label,
        rating_label=rating_label,
        timestamp_label=timestamp_label,
        ratio=test_ratio,
        k=test_k,
        folds=test_folds,
        timestamp=test_timestamp,
        seed=test_seed,
    )[0]
    test_split_time = time.time() - test_split_time_start
    logger.msg(f"Test splitting completed in : {test_split_time:.2f}s")

    # Optional validation folding
    validation_folds: List[Tuple[DataFrame[Any], Optional[DataFrame[Any]]]] = []
    if val_strategy is not None:
        logger.msg(
            f"Starting validation splitting process with {val_strategy.value} splitting strategy."
        )
        validation_split_time_start = time.time()
        folds = self.process_split(
            original_train_set,
            val_strategy,
            user_id_label=user_id_label,
            item_id_label=item_id_label,
            rating_label=rating_label,
            timestamp_label=timestamp_label,
            ratio=val_ratio,
            k=val_k,
            folds=val_folds,
            timestamp=val_timestamp,
            seed=val_seed,
        )
        validation_folds.extend((train, validation) for train, validation in folds)
        validation_split_time = time.time() - validation_split_time_start
        logger.msg(
            f"Validation splitting completed in : {validation_split_time:.2f}s"
        )

    # Logging of splitting process
    split_process_time = time.time() - split_process_start_time
    logger.positive(f"Splitting process over in {split_process_time:.2f}s.")

    # Filter out the test set
    test_set = self.filter_sets(
        original_train_set, test_set, user_id_label, item_id_label, "Test"
    )

    if not validation_folds:
        # CASE 1: Only train and test set
        return (original_train_set, None, test_set)

    if len(validation_folds) == 1:
        # CASE 2: Train/Validation/Test
        train_set, validation_set = validation_folds[0]
        # The returned train set is the reduced one, so both evaluation
        # sets must be aligned with it.
        if validation_set is not None:
            validation_set = self.filter_sets(
                train_set,
                validation_set,
                user_id_label,
                item_id_label,
                "Validation",
            )
        test_set = self.filter_sets(
            train_set, test_set, user_id_label, item_id_label, "Test"
        )
        return (train_set, validation_set, test_set)

    # Filter out each validation set based on the corresponding train
    # set. A new list is built because rebinding the loop variable
    # would leave the stored folds unfiltered.
    filtered_folds: List[Tuple[DataFrame[Any], Optional[DataFrame[Any]]]] = []
    for train, validation in validation_folds:
        if validation is not None:
            validation = self.filter_sets(
                train, validation, user_id_label, item_id_label, "Validation"
            )
        filtered_folds.append((train, validation))

    # CASE 3: N folds of train and validation + the test set
    return (original_train_set, filtered_folds, test_set)

warprec.data.splitting.strategies.SplittingStrategy

Bases: ABC

Abstract definition of a splitting strategy.

Source code in warprec/data/splitting/strategies.py
class SplittingStrategy(ABC):
    """Base class for splitting strategies.

    A strategy is invoked like a function on a frame and yields a list of
    (train, evaluation) partitions.
    """

    def _prepare_stable_data(self, data: FrameT) -> Tuple[DataFrame[Any], str]:
        """Attach a row-index column to the data for stable tie-breaking.

        LazyFrames are collected first; any other input is used as is.

        Returns the data with the extra index column and the name of that
        column.
        """
        index_col = "__original_row_index__"

        # Make the DataFrame eager in case of LazyFrames
        eager_data = data.collect() if isinstance(data, nw.LazyFrame) else data

        indexed_data = eager_data.with_row_index(name=index_col)
        return indexed_data, index_col

    def __call__(
        self, data: FrameT, **kwargs: Any
    ) -> List[Tuple[DataFrame[Any], DataFrame[Any]]]:
        """Split the data into train/eval partitions.

        The base implementation has no body beyond this docstring.

        Args:
            data (FrameT): The FrameT to be split.
            **kwargs (Any): The additional keyword arguments.

        Returns:
            List[Tuple[DataFrame[Any], DataFrame[Any]]]:
                - DataFrame[Any]: First partition of split data.
                - DataFrame[Any]: Second partition of split data.
        """

__call__(data, **kwargs)

This method will split the data in train/eval split.

Parameters:

Name Type Description Default
data FrameT

The FrameT to be split.

required
**kwargs Any

The additional keyword arguments.

{}

Returns:

Type Description
List[Tuple[DataFrame[Any], DataFrame[Any]]]

List[Tuple[DataFrame[Any], DataFrame[Any]]]: - DataFrame[Any]: First partition of split data. - DataFrame[Any]: Second partition of split data.

Source code in warprec/data/splitting/strategies.py
def __call__(
    self, data: FrameT, **kwargs: Any
) -> List[Tuple[DataFrame[Any], DataFrame[Any]]]:
    """Split the data into train/eval partitions.

    This listing contains no body beyond the docstring.
    NOTE(review): concrete strategies presumably override this method;
    confirm against the subclasses.

    Args:
        data (FrameT): The FrameT to be split.
        **kwargs (Any): The additional keyword arguments.

    Returns:
        List[Tuple[DataFrame[Any], DataFrame[Any]]]:
            - DataFrame[Any]: First partition of split data.
            - DataFrame[Any]: Second partition of split data.
    """