Skip to content

Sequential - API Reference

Auto-generated documentation for sequential recommender model classes.

warprec.recommenders.sequential_recommender.caser.Caser

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of Caser algorithm from "Personalized Top-N Sequential Recommendation via Convolutional Sequence Embedding" in WSDM 2018.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item and user embeddings.

n_h int

The number of horizontal filters.

n_v int

The number of vertical filters.

dropout_prob float

The probability of dropout for the fully connected layer.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/caser.py
@model_registry.register(name="Caser")
class Caser(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of Caser algorithm from
    "Personalized Top-N Sequential Recommendation via Convolutional Sequence Embedding"
    in WSDM 2018.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item and user embeddings.
        n_h (int): The number of horizontal filters.
        n_v (int): The number of vertical filters.
        dropout_prob (float): The probability of dropout for the fully connected layer.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER_WITH_USER_ID

    # Model hyperparameters
    embedding_size: int
    n_h: int
    n_v: int
    dropout_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Layers
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
        # One extra row (index n_items) is reserved for the padding item.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )

        # Vertical conv layer: one kernel spanning the whole sequence axis
        self.conv_v = nn.Conv2d(
            in_channels=1, out_channels=self.n_v, kernel_size=(self.max_seq_len, 1)
        )

        # Horizontal conv layers: one per window height 1..max_seq_len,
        # each spanning the full embedding width
        lengths = range(1, self.max_seq_len + 1)
        self.conv_h = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=1,
                    out_channels=self.n_h,
                    kernel_size=(i, self.embedding_size),
                )
                for i in lengths
            ]
        )

        # Fully-connected layers
        self.fc1_dim_v = self.n_v * self.embedding_size
        self.fc1_dim_h = self.n_h * len(lengths)
        fc1_dim_in = self.fc1_dim_v + self.fc1_dim_h
        self.fc1 = nn.Linear(fc1_dim_in, self.embedding_size)

        # The second FC layer takes the concatenated output of the first FC layer and the user embedding
        self.fc2 = nn.Linear(
            self.embedding_size + self.embedding_size, self.embedding_size
        )

        self.dropout = nn.Dropout(self.dropout_prob)
        self.ac_conv = nn.ReLU()
        self.ac_fc = nn.ReLU()

        # Initialize weights
        self.apply(self._init_weights)

        # Loss function: pairwise BPR when negatives are sampled,
        # cross-entropy over the item table otherwise
        self.main_loss: nn.Module
        if self.neg_samples > 0:
            self.main_loss = BPRLoss()
        else:
            self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Return the sequential dataloader built from ``sessions``.

        Args:
            interactions (Interactions): Not used by this model.
            sessions (Sessions): The object providing ``get_sequential_dataloader``.
            **kwargs (Any): Extra keyword arguments forwarded to
                ``get_sequential_dataloader``.

        Returns:
            The value returned by ``sessions.get_sequential_dataloader``,
            requested with user IDs included.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            include_user_id=True,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Args:
            batch (Any): ``(user, item_seq, _, pos_item, neg_item)`` when
                ``neg_samples > 0``, otherwise ``(user, item_seq, _, pos_item)``.
                The third element is ignored by this model.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The main loss plus the weighted embedding regularization loss.
        """
        if self.neg_samples > 0:
            user, item_seq, _, pos_item, neg_item = batch
        else:
            user, item_seq, _, pos_item = batch
            neg_item = None

        seq_output = self.forward(user, item_seq)
        pos_items_emb = self.item_embedding(pos_item)  # [batch_size, embedding_size]

        # Calculate main loss and L2 regularization
        if self.neg_samples > 0:
            neg_items_emb = self.item_embedding(
                neg_item
            )  # [batch_size, neg_samples, embedding_size]

            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [batch_size]
            neg_score = torch.sum(
                seq_output.unsqueeze(1) * neg_items_emb, dim=-1
            )  # [batch_size, neg_samples]
            main_loss = self.main_loss(pos_score, neg_score)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                self.user_embedding(user),
                pos_items_emb,
                neg_items_emb,
            )
        else:
            # Logits are computed against every row of the embedding table
            # (n_items + 1 rows, i.e. including the padding row).
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            main_loss = self.main_loss(logits, pos_item)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                self.user_embedding(user),
                pos_items_emb,
            )

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def forward(self, user: Tensor, item_seq: Tensor) -> Tensor:
        """Forward pass of the Caser model.

        Args:
            user (Tensor): The user ID for each sequence [batch_size,].
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].

        Returns:
            Tensor: The final sequence output embedding [batch_size, embedding_size].

        Raises:
            ValueError: If both ``n_v`` and ``n_h`` are non-positive, so that
                no convolutional feature can be computed.
        """
        # --- Embedding Look-up ---
        # Unsqueeze to get a 4-D input for convolution layers:
        # (batch_size, 1, max_seq_len, embedding_size)
        item_seq_emb = self.item_embedding(item_seq).unsqueeze(1)
        user_emb = self.user_embedding(user)  # [batch_size, embedding_size]

        # --- Convolutional Layers ---
        # Only the enabled branches contribute features; a disabled branch
        # must not be passed to torch.cat (it does not accept None).
        conv_features = []

        # Vertical convolution
        if self.n_v > 0:
            out_v = self.conv_v(item_seq_emb)
            conv_features.append(
                out_v.view(-1, self.fc1_dim_v)
            )  # Reshape for FC layer

        # Horizontal convolution
        if self.n_h > 0:
            for conv in self.conv_h:
                conv_out = self.ac_conv(conv(item_seq_emb).squeeze(3))
                # Max-pool over the whole remaining sequence axis
                pool_out = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
                conv_features.append(pool_out)

        if not conv_features:
            raise ValueError(
                "Caser requires at least one of n_v or n_h to be greater than 0."
            )

        # Concatenate vertical and horizontal outputs
        conv_out = torch.cat(conv_features, 1)

        # --- Fully-connected Layers ---
        # Apply dropout
        conv_out = self.dropout(conv_out)

        # First FC layer
        z = self.ac_fc(self.fc1(conv_out))

        # Concatenate with user embedding
        x = torch.cat([z, user_emb], 1)

        # Second FC layer
        seq_output = self.fc2(x)
        seq_output = self.ac_fc(seq_output)

        return seq_output

    def predict(
        self,
        user_indices: Tensor,
        user_seq: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_indices (Tensor): The batch of user indices.
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get sequence output embeddings
        seq_output = self.forward(
            user_indices, user_seq
        )  # [batch_size, embedding_size]

        if item_indices is None:
            # Case 'full': prediction on all items (the padding row is dropped)
            item_embeddings = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = self.item_embedding(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(
            einsum_string, seq_output, item_embeddings
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(user, item_seq)

Forward pass of the Caser model.

Parameters:

Name Type Description Default
user Tensor

The user ID for each sequence [batch_size,].

required
item_seq Tensor

Padded sequences of item IDs [batch_size, max_seq_len].

required

Returns:

Name Type Description
Tensor Tensor

The final sequence output embedding [batch_size, embedding_size].

Source code in warprec/recommenders/sequential_recommender/caser.py
def forward(self, user: Tensor, item_seq: Tensor) -> Tensor:
    """Forward pass of the Caser model.

    Args:
        user (Tensor): The user ID for each sequence [batch_size,].
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].

    Returns:
        Tensor: The final sequence output embedding [batch_size, embedding_size].

    Raises:
        ValueError: If both ``n_v`` and ``n_h`` are non-positive, so that
            no convolutional feature can be computed.
    """
    # --- Embedding Look-up ---
    # Unsqueeze to get a 4-D input for convolution layers:
    # (batch_size, 1, max_seq_len, embedding_size)
    item_seq_emb = self.item_embedding(item_seq).unsqueeze(1)
    user_emb = self.user_embedding(user)  # [batch_size, embedding_size]

    # --- Convolutional Layers ---
    # Only the enabled branches contribute features; a disabled branch
    # must not be passed to torch.cat (it does not accept None).
    conv_features = []

    # Vertical convolution
    if self.n_v > 0:
        out_v = self.conv_v(item_seq_emb)
        conv_features.append(out_v.view(-1, self.fc1_dim_v))  # Reshape for FC layer

    # Horizontal convolution
    if self.n_h > 0:
        for conv in self.conv_h:
            conv_out = self.ac_conv(conv(item_seq_emb).squeeze(3))
            # Max-pool over the whole remaining sequence axis
            pool_out = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
            conv_features.append(pool_out)

    if not conv_features:
        raise ValueError(
            "Caser requires at least one of n_v or n_h to be greater than 0."
        )

    # Concatenate vertical and horizontal outputs
    conv_out = torch.cat(conv_features, 1)

    # --- Fully-connected Layers ---
    # Apply dropout
    conv_out = self.dropout(conv_out)

    # First FC layer
    z = self.ac_fc(self.fc1(conv_out))

    # Concatenate with user embedding
    x = torch.cat([z, user_emb], 1)

    # Second FC layer
    seq_output = self.fc2(x)
    seq_output = self.ac_fc(seq_output)

    return seq_output

predict(user_indices, user_seq, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_indices Tensor

The batch of user indices.

required
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/caser.py
def predict(
    self,
    user_indices: Tensor,
    user_seq: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Score items for a batch of users from their encoded sequences.

    Args:
        user_indices (Tensor): The batch of user indices.
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # Encode each (user, sequence) pair: [batch_size, embedding_size]
    session_repr = self.forward(user_indices, user_seq)

    if item_indices is not None:
        # Sampled prediction: score only the candidates given per row
        candidate_emb = self.item_embedding(
            item_indices
        )  # [batch_size, pad_seq, embedding_size]
        # b: batch, e: embedding, s: sample -> [batch_size, pad_seq]
        return torch.einsum("be,bse->bs", session_repr, candidate_emb)

    # Full prediction: score every item row except the last (padding) one
    catalogue_emb = self.item_embedding.weight[:-1, :]  # [n_items, embedding_size]
    # b: batch, e: embedding, i: item -> [batch_size, n_items]
    return torch.einsum("be,ie->bi", session_repr, catalogue_emb)

warprec.recommenders.sequential_recommender.fossil.FOSSIL

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of FOSSIL algorithm from "Fusing Similarity Models with Markov Chains for Sparse Sequential Recommendation." in ICDM 2016.

FOSSIL uses item similarity as its main signal and adds high-order Markov chains to model sequential preferences, improving sequential recommendation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings.

order_len int

The number of last items to consider for high-order Markov chains.

alpha float

The parameter for calculating similarity.

reg_weight float

The L2 regularization weight.

batch_size int

The batch size used for training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/fossil.py
@model_registry.register(name="FOSSIL")
class FOSSIL(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of FOSSIL algorithm from
    "Fusing Similarity Models with Markov Chains for Sparse Sequential Recommendation." in ICDM 2016.

    FOSSIL uses item similarity as its main signal and adds high-order Markov
    chains to model sequential preferences, improving sequential recommendation.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings.
        order_len (int): The number of last items to consider for high-order Markov chains.
        alpha (float): The parameter for calculating similarity.
        reg_weight (float): The L2 regularization weight.
        batch_size (int): The batch size used for training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER_WITH_USER_ID

    # Model hyperparameters
    embedding_size: int
    order_len: int
    alpha: float
    reg_weight: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Define the layers
        # One extra row (index n_items) is reserved for the padding item.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        self.user_lambda = nn.Embedding(
            self.n_users, self.order_len
        )  # User specific weights for Markov chains
        self.lambda_ = nn.Parameter(
            torch.zeros(self.order_len)
        )  # Global weights for Markov chains

        # Initialize weights
        self.apply(self._init_weights)

        # Loss function: pairwise BPR when negatives are sampled,
        # cross-entropy over the item table otherwise
        self.main_loss: nn.Module
        if self.neg_samples > 0:
            self.main_loss = BPRLoss()
        else:
            self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Return the sequential dataloader built from ``sessions``.

        Args:
            interactions (Interactions): Not used by this model.
            sessions (Sessions): The object providing ``get_sequential_dataloader``.
            **kwargs (Any): Extra keyword arguments forwarded to
                ``get_sequential_dataloader``.

        Returns:
            The value returned by ``sessions.get_sequential_dataloader``,
            requested with user IDs included.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            include_user_id=True,
            **kwargs,
        )

    def _inverse_seq_item_embedding(
        self, seq_item_embedding: Tensor, seq_item_len: Tensor
    ) -> Tensor:
        """Extract the embeddings of the last ``order_len`` items of each sequence.

        The sequence embeddings are prefixed with an equally long block of zeros,
        so that sequences shorter than ``order_len`` read zeros for the missing
        positions instead of indexing out of range.

        Args:
            seq_item_embedding (Tensor): A tensor representing the embeddings of items in a sequence.
                                         Expected shape: (batch_size, sequence_length, embedding_dim).
            seq_item_len (Tensor): A tensor representing the actual lengths of the sequences.
                                   Expected shape: (batch_size,).

        Returns:
            Tensor: The "short" item embeddings taken from the tail of each
                    sequence, zero-filled at the front when the sequence is
                    shorter than ``order_len``.
                    Expected shape: (batch_size, order_len, embedding_dim).
        """
        # Zero block of the same shape, prepended as padding. zeros_like already
        # allocates on the input's device.
        zeros = torch.zeros_like(
            seq_item_embedding, dtype=torch.float
        )  # (batch_size, sequence_length, embedding_dim)

        item_embedding_zeros = torch.cat(
            [zeros, seq_item_embedding], dim=1
        )  # (batch_size, 2 * sequence_length, embedding_dim)

        # Original position p now lives at index `pad_len + p`. Use the actual
        # length of the zero prefix rather than the configured max_seq_len so
        # the offset stays right for inputs padded to a different length.
        pad_len = seq_item_embedding.size(1)
        tail_start = pad_len + seq_item_len - self.order_len

        # Gather positions len - order_len, ..., len - 1 of every sequence.
        embedding_list = [
            self._gather_indexes(item_embedding_zeros, tail_start + i)
            for i in range(self.order_len)
        ]  # order_len tensors, each presumably (batch_size, embedding_dim)

        # Stack the gathered embeddings along a new dimension 1.
        short_item_embedding = torch.stack(
            embedding_list, dim=1
        )  # (batch_size, order_len, embedding_dim)
        return short_item_embedding

    def _get_high_order_Markov(
        self, high_order_item_embedding: Tensor, user: Tensor
    ) -> Tensor:
        """Calculates a weighted high-order Markov embedding based
            on user and item interactions.

        Each of the last ``order_len`` item embeddings is scaled by the sum of a
        user-specific weight and a global weight for that position, and the
        scaled embeddings are summed over the positions.

        Args:
            high_order_item_embedding (Tensor): Embeddings of the last ``order_len``
                items of each sequence.
                Expected shape: (batch_size, order_len, embedding_dim).
            user (Tensor): User IDs used to look up the user-specific weights.
                Expected shape: (batch_size,).

        Returns:
            Tensor: The weighted sum of the high-order item embeddings.
                Expected shape: (batch_size, embedding_dim).
        """
        # User-specific weights, one per Markov order, with a trailing axis
        # for broadcasting over the embedding dimension
        user_lambda = self.user_lambda(user).unsqueeze(
            dim=2
        )  # (batch_size, order_len, 1)

        # Global weights, broadcastable over batch and embedding dimensions
        lambda_ = self.lambda_.unsqueeze(dim=0).unsqueeze(
            dim=2
        )  # (1, order_len, 1)

        # Add user-specific and general lambda values
        lambda_ = torch.add(user_lambda, lambda_)  # (batch_size, order_len, 1)

        # Apply the combined lambda weights to the high-order item embeddings
        high_order_item_embedding = torch.mul(
            high_order_item_embedding, lambda_
        )  # (batch_size, order_len, embedding_dim)

        # Sum the weighted embeddings along the order dimension
        high_order_item_embedding = high_order_item_embedding.sum(
            dim=1
        )  # (batch_size, embedding_dim)

        return high_order_item_embedding

    def _get_similarity(
        self, seq_item_embedding: Tensor, seq_item_len: Tensor
    ) -> Tensor:
        """Calculates a length-normalized sum of the sequence item embeddings.

        The sum of the item embeddings of each sequence is multiplied by
        ``len ** -alpha``, where ``len`` is the actual sequence length.

        Args:
            seq_item_embedding (Tensor): A tensor representing the embeddings of items in a sequence.
                                         Expected shape: (batch_size, sequence_length, embedding_dim).
            seq_item_len (Tensor): A tensor representing the actual lengths of the sequences.
                                   Expected shape: (batch_size,).

        Returns:
            Tensor: The length-weighted sum of item embeddings for each sequence.
                    Expected shape: (batch_size, embedding_dim).
        """
        # Clamp the length to at least 1: a zero length would give
        # 0 ** -alpha = inf, and inf times a zero embedding sum is NaN.
        # Lengths >= 1 are unaffected.
        safe_len = seq_item_len.unsqueeze(1).float().clamp(min=1.0)
        coeff = torch.pow(safe_len, -self.alpha)  # (batch_size,  1)

        # Multiply the coefficient with the summed embeddings to compute similarity
        similarity = torch.mul(
            coeff, seq_item_embedding.sum(dim=1)
        )  # (batch_size, embedding_size)
        return similarity

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Args:
            batch (Any): ``(user_id, item_seq, item_seq_len, pos_item, neg_item)``
                when ``neg_samples > 0``, otherwise
                ``(user_id, item_seq, item_seq_len, pos_item)``.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The main loss plus the weighted embedding regularization loss.
        """
        if self.neg_samples > 0:
            user_id, item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            user_id, item_seq, item_seq_len, pos_item = batch
            neg_item = None

        seq_output = self.forward(user_id, item_seq, item_seq_len)

        pos_items_emb = self.item_embedding(pos_item)

        # Calculate main loss and L2 regularization
        if self.neg_samples > 0:
            neg_items_emb = self.item_embedding(neg_item)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
            neg_score = torch.sum(seq_output.unsqueeze(1) * neg_items_emb, dim=-1)
            main_loss = self.main_loss(pos_score, neg_score)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                self.user_lambda(user_id),
                pos_items_emb,
                neg_items_emb,
            )
        else:
            # Logits are computed against every row of the embedding table
            # (n_items + 1 rows, i.e. including the padding row).
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            main_loss = self.main_loss(logits, pos_item)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                self.user_lambda(user_id),
                pos_items_emb,
            )

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def forward(
        self,
        user_id: Tensor,
        item_seq: Tensor,
        item_seq_len: Tensor,
    ) -> Tensor:
        """Forward pass of the FOSSIL model.

        Args:
            user_id (Tensor): User IDs for each sequence [batch_size,].
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
            item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

        Returns:
            Tensor: The combined embedding for prediction [batch_size, embedding_size].
        """
        seq_item_embedding = self.item_embedding(item_seq)

        high_order_seq_item_embedding = self._inverse_seq_item_embedding(
            seq_item_embedding, item_seq_len
        )
        # batch_size * order_len * embedding

        high_order = self._get_high_order_Markov(high_order_seq_item_embedding, user_id)
        similarity = self._get_similarity(seq_item_embedding, item_seq_len)

        return high_order + similarity

    def predict(
        self,
        user_indices: Tensor,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_indices (Tensor): The batch of user indices.
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get sequence output embeddings
        seq_output = self.forward(
            user_indices, user_seq, seq_len
        )  # [batch_size, embedding_size]

        if item_indices is None:
            # Case 'full': prediction on all items (the padding row is dropped)
            item_embeddings = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = self.item_embedding(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(
            einsum_string, seq_output, item_embeddings
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(user_id, item_seq, item_seq_len)

Forward pass of the FOSSIL model.

Parameters:

Name Type Description Default
user_id Tensor

User IDs for each sequence [batch_size,].

required
item_seq Tensor

Padded sequences of item IDs [batch_size, max_seq_len].

required
item_seq_len Tensor

Actual lengths of sequences [batch_size,].

required

Returns:

Name Type Description
Tensor Tensor

The combined embedding for prediction [batch_size, embedding_size].

Source code in warprec/recommenders/sequential_recommender/fossil.py
def forward(
    self,
    user_id: Tensor,
    item_seq: Tensor,
    item_seq_len: Tensor,
) -> Tensor:
    """Encode a batch of sequences with the FOSSIL model.

    Args:
        user_id (Tensor): User IDs for each sequence [batch_size,].
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
        item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

    Returns:
        Tensor: The combined embedding for prediction [batch_size, embedding_size].
    """
    # Look up the embedding of every item in the padded sequences
    embedded_seq = self.item_embedding(item_seq)

    # Tail of each sequence: batch_size * order_len * embedding
    recent_items = self._inverse_seq_item_embedding(embedded_seq, item_seq_len)

    # Markov term from the tail items, weighted per user
    markov_term = self._get_high_order_Markov(recent_items, user_id)

    # Similarity term from the whole sequence, weighted by its length
    similarity_term = self._get_similarity(embedded_seq, item_seq_len)

    return markov_term + similarity_term

predict(user_indices, user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_indices Tensor

The batch of user indices.

required
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/fossil.py
def predict(
    self,
    user_indices: Tensor,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Score items for a batch of users from their session embeddings.

    Args:
        user_indices (Tensor): The batch of user indices.
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # One representation per sequence: [batch_size, embedding_size]
    session_repr = self.forward(user_indices, user_seq, seq_len)

    if item_indices is not None:
        # Sampled mode: score only the candidates supplied for each row.
        candidate_emb = self.item_embedding(
            item_indices
        )  # [batch_size, pad_seq, embedding_size]
        # b: batch, e: embedding, s: sample -> [batch_size, pad_seq]
        return torch.einsum("be,bse->bs", session_repr, candidate_emb)

    # Full mode: score every item; the last embedding row is dropped.
    catalog_emb = self.item_embedding.weight[:-1, :]  # [n_items, embedding_size]
    # b: batch, e: embedding, i: item -> [batch_size, n_items]
    return torch.einsum("be,ie->bi", session_repr, catalog_emb)

warprec.recommenders.sequential_recommender.gru4rec.GRU4Rec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of GRU4Rec algorithm from "Improved Recurrent Neural Networks for Session-based Recommendations." in DLRS 2016.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings.

hidden_size int

The number of features in the hidden state of the GRU.

num_layers int

The number of recurrent layers.

dropout_prob float

The probability of dropout for the embeddings.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used for training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/gru4rec.py
@model_registry.register(name="GRU4Rec")
class GRU4Rec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of GRU4Rec algorithm from
    "Improved Recurrent Neural Networks for Session-based Recommendations." in DLRS 2016.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings.
        hidden_size (int): The number of features in the hidden state of the GRU.
        num_layers (int): The number of recurrent layers.
        dropout_prob (float): The probability of dropout for the embeddings.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used for training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Model hyperparameters
    embedding_size: int
    hidden_size: int
    num_layers: int
    dropout_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        """Build the embedding table, GRU encoder, projection layer and losses."""
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # One extra row (index n_items) is reserved for the padding token.
        self.item_embedding = nn.Embedding(
            self.n_items + 1,
            self.embedding_size,
            padding_idx=self.n_items,
        )
        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.gru_layers = nn.GRU(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            bias=False,
            batch_first=True,  # Input tensors are (batch, seq_len, features)
        )

        # Dense layer to project GRU output back to embedding_size
        # Used for prediction
        self.dense = nn.Linear(self.hidden_size, self.embedding_size)

        # Initialize weights
        self.apply(self._init_weights)

        # Loss function will be based on number of
        # negative samples
        self.main_loss: nn.Module
        if self.neg_samples > 0:
            self.main_loss = BPRLoss()
        else:
            self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Return the sequential dataloader built from ``sessions``.

        Args:
            interactions (Interactions): Not used by this model.
            sessions (Sessions): Object providing ``get_sequential_dataloader``.
            **kwargs (Any): Extra arguments forwarded to the dataloader factory.

        Returns:
            The dataloader returned by ``sessions.get_sequential_dataloader``.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Args:
            batch (Any): ``(item_seq, item_seq_len, pos_item, neg_item)`` when
                ``neg_samples > 0``, otherwise ``(item_seq, item_seq_len, pos_item)``.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The main loss plus the ``reg_weight``-scaled embedding regularization.
        """
        if self.neg_samples > 0:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch
            neg_item = None

        seq_output = self.forward(item_seq, item_seq_len)
        pos_items_emb = self.item_embedding(pos_item)  # [batch_size, embedding_size]

        # Calculate main loss and L2 regularization
        if self.neg_samples > 0:
            # Presumably [batch_size, neg_samples, embedding_size], given the
            # unsqueeze(1) broadcast below - TODO confirm against the dataloader.
            neg_items_emb = self.item_embedding(neg_item)

            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [batch_size]
            # One score per negative sample and sequence.
            neg_score = torch.sum(seq_output.unsqueeze(1) * neg_items_emb, dim=-1)
            main_loss = self.main_loss(pos_score, neg_score)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                pos_items_emb,
                neg_items_emb,
            )
        else:
            # Drop the trailing padding row (index n_items): logits must cover
            # real items only, and multiplying by the raw weight would also
            # route gradient into the padding embedding.
            test_item_emb = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            logits = torch.matmul(
                seq_output, test_item_emb.transpose(0, 1)
            )  # [batch_size, n_items]
            main_loss = self.main_loss(logits, pos_item)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                pos_items_emb,
            )

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Forward pass of the GRU4Rec model.

        Args:
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
            item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

        Returns:
            Tensor: The embedding of the predicted item (last session state)
                    [batch_size, embedding_size].
        """
        item_seq_emb = self.item_embedding(
            item_seq
        )  # [batch_size, max_seq_len, embedding_size]
        item_seq_emb_dropout = self.emb_dropout(item_seq_emb)

        # GRU layers
        # NOTE: Only the output sequence is used in the forward pass
        gru_output, _ = self.gru_layers(item_seq_emb_dropout)
        gru_output = self.dense(gru_output)  # [batch_size, max_seq_len, embedding_size]

        # Use the utility method to gather the output at the last valid
        # position of each sequence (index item_seq_len - 1)
        seq_output = self._gather_indexes(
            gru_output, item_seq_len - 1
        )  # [batch_size, embedding_size]
        return seq_output

    def predict(
        self,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get sequence output embeddings
        seq_output = self.forward(user_seq, seq_len)  # [batch_size, embedding_size]

        if item_indices is None:
            # Case 'full': prediction on all items (padding row excluded)
            item_embeddings = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = self.item_embedding(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(
            einsum_string, seq_output, item_embeddings
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(item_seq, item_seq_len)

Forward pass of the GRU4Rec model.

Parameters:

Name Type Description Default
item_seq Tensor

Padded sequences of item IDs [batch_size, max_seq_len].

required
item_seq_len Tensor

Actual lengths of sequences [batch_size,].

required

Returns:

Name Type Description
Tensor Tensor

The embedding of the predicted item (last session state) [batch_size, embedding_size].

Source code in warprec/recommenders/sequential_recommender/gru4rec.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Encode each padded session with the GRU and return its last valid state.

    Args:
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
        item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

    Returns:
        Tensor: The embedding of the predicted item (last session state)
                [batch_size, embedding_size].
    """
    # Embedding lookup followed by dropout:
    # [batch_size, max_seq_len, embedding_size]
    embedded = self.emb_dropout(self.item_embedding(item_seq))

    # Only the per-step outputs are kept; the final hidden state is discarded.
    hidden_states, _ = self.gru_layers(embedded)

    # Project every step back to the embedding dimension.
    projected = self.dense(hidden_states)  # [batch_size, max_seq_len, embedding_size]

    # Pick, for each row, the step at index (length - 1).
    last_step = item_seq_len - 1
    return self._gather_indexes(projected, last_step)  # [batch_size, embedding_size]

predict(user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/gru4rec.py
def predict(
    self,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Score items for a batch of sessions.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    session_repr = self.forward(user_seq, seq_len)  # [batch_size, embedding_size]

    # Choose the einsum equation together with the item vectors it applies to.
    if item_indices is None:
        # Full mode: all item rows except the last one of the table.
        # b: batch, e: embedding, i: item
        equation = "be,ie->bi"
        item_vectors = self.item_embedding.weight[:-1, :]  # [n_items, embedding_size]
    else:
        # Sampled mode: per-row candidate items.
        # b: batch, e: embedding, s: sample
        equation = "be,bse->bs"
        item_vectors = self.item_embedding(
            item_indices
        )  # [batch_size, pad_seq, embedding_size]

    # [batch_size, n_items] or [batch_size, pad_seq]
    return torch.einsum(equation, session_repr, item_vectors)

warprec.recommenders.sequential_recommender.narm.NARM

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of NARM algorithm from "Neural Attentive Session-based Recommendation." in CIKM 2017.

NARM explores a hybrid encoder with an attention mechanism to model the user’s sequential behavior (Global Encoder) and capture the user’s main purpose in the current session (Local Encoder).

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings.

hidden_size int

The number of features in the hidden state of the GRU.

n_layers int

The number of recurrent layers in the GRU.

hidden_dropout_prob float

Dropout probability for the item embeddings.

attn_dropout_prob float

Dropout probability for the hybrid session representation.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used for training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/narm.py
@model_registry.register(name="NARM")
class NARM(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of NARM algorithm from
    "Neural Attentive Session-based Recommendation." in CIKM 2017.

    NARM explores a hybrid encoder with an attention mechanism to model the
    user’s sequential behavior (Global Encoder) and capture the user’s
    main purpose in the current session (Local Encoder).

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings.
        hidden_size (int): The number of features in the hidden state of the GRU.
        n_layers (int): The number of recurrent layers in the GRU.
        hidden_dropout_prob (float): Dropout probability for the item embeddings.
        attn_dropout_prob (float): Dropout probability for the hybrid session representation.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used for training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Model hyperparameters
    embedding_size: int
    hidden_size: int
    n_layers: int
    hidden_dropout_prob: float
    attn_dropout_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        """Build the embedding table, GRU encoder, attention layers and losses."""
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Item embedding; the extra row at index n_items is the padding token
        self.item_embedding = nn.Embedding(
            self.n_items + 1,
            self.embedding_size,
            padding_idx=self.n_items,
        )
        self.emb_dropout = nn.Dropout(self.hidden_dropout_prob)

        # Sequential Encoder (GRU)
        self.gru = nn.GRU(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size,
            num_layers=self.n_layers,
            bias=False,
            batch_first=True,
        )

        # Attention layers for Local Encoder
        self.a_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.a_2 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_t = nn.Linear(self.hidden_size, 1, bias=False)
        self.ct_dropout = nn.Dropout(self.attn_dropout_prob)

        # Final projection to align hybrid representation with item embedding space
        self.b = nn.Linear(2 * self.hidden_size, self.embedding_size, bias=False)

        # Initialize weights through the shared _init_weights hook.
        # NOTE(review): the original comment says this is Xavier Normal as
        # recommended in the NARM paper; the hook is defined outside this
        # class - confirm.
        self.apply(self._init_weights)

        # Loss function setup
        self.main_loss: nn.Module
        if self.neg_samples > 0:
            self.main_loss = BPRLoss()
        else:
            self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Return the sequential dataloader built from ``sessions``.

        Args:
            interactions (Interactions): Not used by this model.
            sessions (Sessions): Object providing ``get_sequential_dataloader``.
            **kwargs (Any): Extra arguments forwarded to the dataloader factory.

        Returns:
            The dataloader returned by ``sessions.get_sequential_dataloader``.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Args:
            batch (Any): ``(item_seq, item_seq_len, pos_item, neg_item)`` when
                ``neg_samples > 0``, otherwise ``(item_seq, item_seq_len, pos_item)``.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The main loss plus the ``reg_weight``-scaled embedding regularization.
        """
        if self.neg_samples > 0:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch
            neg_item = None

        seq_output = self.forward(item_seq, item_seq_len)
        pos_items_emb = self.item_embedding(pos_item)

        if self.neg_samples > 0:
            # Pairwise BPR Loss
            neg_items_emb = self.item_embedding(neg_item)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
            neg_score = torch.sum(seq_output.unsqueeze(1) * neg_items_emb, dim=-1)
            main_loss = self.main_loss(pos_score, neg_score)

            # L2 Regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                pos_items_emb,
                neg_items_emb,
            )
        else:
            # Cross Entropy Loss over all real items (padding row excluded)
            test_item_emb = self.item_embedding.weight[:-1, :]
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            main_loss = self.main_loss(logits, pos_item)

            # L2 Regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                pos_items_emb,
            )

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Forward pass of the NARM model.

        Args:
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
            item_seq_len (Tensor): Actual lengths of sequences.

        Returns:
            Tensor: The hybrid session representation [batch_size, embedding_size].
        """
        item_seq_emb = self.item_embedding(item_seq)
        item_seq_emb_dropout = self.emb_dropout(item_seq_emb)

        # GRU encoding
        gru_out, _ = self.gru(item_seq_emb_dropout)

        # Global Encoder (c_global): GRU output at the last valid position
        c_global = self._gather_indexes(gru_out, item_seq_len - 1)

        # Local Encoder (c_local): Attention mechanism over all hidden states
        q1 = self.a_1(gru_out)  # [B, S, H]
        q2 = self.a_2(c_global).unsqueeze(1)  # [B, 1, H]

        # Non-normalized attention logits: [B, S, 1]
        alpha_logits = self.v_t(torch.sigmoid(q1 + q2))

        # Padding positions (item id == n_items) get a very negative logit so
        # that they receive no weight after the softmax. Mask: [B, S, 1]
        mask = (item_seq != self.n_items).unsqueeze(2)
        alpha_logits = alpha_logits.masked_fill(~mask, -1e9)

        # Apply softmax over temporal dimension
        alpha = F.softmax(alpha_logits, dim=1)
        c_local = torch.sum(alpha * gru_out, dim=1)  # Weighted sum

        # Hybrid Representation: Concatenate Local and Global vectors
        c_t = torch.cat([c_local, c_global], dim=1)
        c_t = self.ct_dropout(c_t)

        # Final projection to the item embedding space
        seq_output = self.b(c_t)

        return seq_output

    def predict(
        self,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get sequence output embeddings
        seq_output = self.forward(user_seq, seq_len)  # [batch_size, embedding_size]

        if item_indices is None:
            # Case 'full': prediction on all items (padding row excluded)
            item_embeddings = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = self.item_embedding(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(
            einsum_string, seq_output, item_embeddings
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(item_seq, item_seq_len)

Forward pass of the NARM model.

Parameters:

Name Type Description Default
item_seq Tensor

Padded sequences of item IDs [batch_size, max_seq_len].

required
item_seq_len Tensor

Actual lengths of sequences.

required

Returns:

Name Type Description
Tensor Tensor

The hybrid session representation [batch_size, embedding_size].

Source code in warprec/recommenders/sequential_recommender/narm.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Forward pass of the NARM model.

    Combines a global vector (GRU output at the last valid position) with a
    local vector (attention-weighted sum of all GRU outputs, padding excluded)
    and projects the concatenation to the item embedding space.

    Args:
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
        item_seq_len (Tensor): Actual lengths of sequences.

    Returns:
        Tensor: The hybrid session representation [batch_size, embedding_size].
    """
    item_seq_emb = self.item_embedding(item_seq)
    item_seq_emb_dropout = self.emb_dropout(item_seq_emb)

    # GRU encoding
    gru_out, _ = self.gru(item_seq_emb_dropout)

    # Global Encoder (c_global): GRU output at the last valid position
    c_global = self._gather_indexes(gru_out, item_seq_len - 1)

    # Local Encoder (c_local): Attention mechanism over all hidden states
    q1 = self.a_1(gru_out)  # [B, S, H]
    q2 = self.a_2(c_global).unsqueeze(1)  # [B, 1, H]

    # Non-normalized attention logits: [B, S, 1]
    alpha_logits = self.v_t(torch.sigmoid(q1 + q2))

    # Padding positions (item id == n_items) get a very negative logit so
    # that they receive no weight after the softmax. Mask: [B, S, 1]
    mask = (item_seq != self.n_items).unsqueeze(2)
    alpha_logits = alpha_logits.masked_fill(~mask, -1e9)

    # Apply softmax over temporal dimension
    alpha = F.softmax(alpha_logits, dim=1)
    c_local = torch.sum(alpha * gru_out, dim=1)  # Weighted sum

    # Hybrid Representation: Concatenate Local and Global vectors
    c_t = torch.cat([c_local, c_global], dim=1)
    c_t = self.ct_dropout(c_t)

    # Final projection to the item embedding space
    seq_output = self.b(c_t)

    return seq_output

predict(user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/narm.py
def predict(
    self,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Score items against the hybrid session representation.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # [batch_size, embedding_size]
    hybrid_repr = self.forward(user_seq, seq_len)

    full_catalog = item_indices is None
    if full_catalog:
        # Every item row except the last one of the embedding table.
        targets = self.item_embedding.weight[:-1, :]  # [n_items, embedding_size]
        # b: batch, e: embedding, i: item -> [batch_size, n_items]
        scores = torch.einsum("be,ie->bi", hybrid_repr, targets)
    else:
        # Only the candidate items supplied per row.
        targets = self.item_embedding(
            item_indices
        )  # [batch_size, pad_seq, embedding_size]
        # b: batch, e: embedding, s: sample -> [batch_size, pad_seq]
        scores = torch.einsum("be,bse->bs", hybrid_repr, targets)
    return scores

warprec.recommenders.sequential_recommender.bert4rec.BERT4Rec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of BERT4Rec algorithm from "BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer." in CIKM 2019.

This model uses a bidirectional Transformer to learn item representations based on a masked item prediction task (cloze task). For next-item prediction, a special [MASK] token is appended to the sequence.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings (hidden_size).

n_layers int

The number of transformer encoder layers.

n_heads int

The number of attention heads in the transformer.

inner_size int

The dimensionality of the feed-forward layer in the transformer.

dropout_prob float

The probability of dropout for embeddings and other layers.

attn_dropout_prob float

The probability of dropout for the attention weights.

mask_prob float

The probability of an item being masked during training.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples for BPR loss.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/bert4rec.py
@model_registry.register(name="BERT4Rec")
class BERT4Rec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of BERT4Rec algorithm from
    "BERT4Rec: Sequential Recommendation with Bidirectional Encoder Representations from Transformer."
    in CIKM 2019.

    This model uses a bidirectional Transformer to learn item representations based on a
    masked item prediction task (cloze task). For next-item prediction, a special [MASK]
    token is appended to the sequence.

    Token id layout used by this class: real items occupy ``0 .. n_items - 1``,
    ``n_items`` is the padding token and ``n_items + 1`` is the [MASK] token.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings (hidden_size).
        n_layers (int): The number of transformer encoder layers.
        n_heads (int): The number of attention heads in the transformer.
        inner_size (int): The dimensionality of the feed-forward layer in the transformer.
        dropout_prob (float): The probability of dropout for embeddings and other layers.
        attn_dropout_prob (float): The probability of dropout for the attention weights.
        mask_prob (float): The probability of an item being masked during training.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples for BPR loss.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.CLOZE_MASK_LOADER

    # Model hyperparameters (declared here, presumably populated from ``params``
    # by the base-class constructor -- confirm in IterativeRecommender)
    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    mask_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Define special token IDs: both sit right after the real item ids
        self.padding_token_id = self.n_items
        self.mask_token_id = self.n_items + 1

        # Item embedding needs to accommodate items, padding token, and mask token
        self.item_embedding = nn.Embedding(
            self.n_items + 2, self.embedding_size, padding_idx=self.padding_token_id
        )

        # Take into account the extra [MASK] token in position embeddings:
        # predict() feeds sequences that are one position longer than the input.
        self.position_embedding = nn.Embedding(
            self.max_seq_len + 1, self.embedding_size
        )
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)
        self.dropout = nn.Dropout(self.dropout_prob)

        # NOTE(review): nn.TransformerEncoderLayer exposes a single ``dropout``
        # argument that is used by the layer's dropout modules in general, not
        # only by the attention weights -- confirm this matches the intent of
        # ``attn_dropout_prob``.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_size,
            nhead=self.n_heads,
            dim_feedforward=self.inner_size,
            dropout=self.attn_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=self.n_layers
        )

        # Learnable output bias added to the dot-product scores. It has
        # n_items + 1 entries, so ids 0 .. n_items (real items plus the padding
        # id) can index it; the [MASK] id (n_items + 1) cannot.
        self.out_bias = nn.Parameter(torch.zeros(self.n_items + 1))

        # Weight initialisation is applied to every sub-module created above
        self.apply(self._init_weights)
        self.bpr_loss = BPRLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Build the training dataloader for the cloze (masked item) task.

        Args:
            interactions (Interactions): Interaction data (not used by this model).
            sessions (Sessions): Session data the dataloader is requested from.
            **kwargs (Any): Extra keyword arguments forwarded to the dataloader factory.

        Returns:
            Whatever ``sessions.get_cloze_mask_dataloader`` returns.
        """
        return sessions.get_cloze_mask_dataloader(
            max_seq_len=self.max_seq_len,
            mask_prob=self.mask_prob,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            mask_token_id=self.mask_token_id,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the training loss (BPR + L2 embedding regularization) for one batch.

        Args:
            batch (Any): A 4-tuple ``(masked_seq, pos_items, neg_items, masked_indices)``,
                presumably produced by the dataloader from ``get_dataloader`` --
                TODO confirm the exact shapes there.
            batch_idx (int): Index of the batch (unused here).

        Returns:
            The loss ``bpr_loss + reg_loss``, which is also logged under "loss".
        """
        masked_seq, pos_items, neg_items, masked_indices = batch

        # Get the output of the bidirectional transformer
        transformer_output = self.forward(masked_seq)

        # Gather the hidden states at the masked positions
        # -> [batch_size, num_masked, embedding_size]
        seq_output = self._multi_hot_gather(transformer_output, masked_indices)

        # Get embeddings for positive and negative items
        pos_items_emb = self.item_embedding(pos_items)
        neg_items_emb = self.item_embedding(neg_items)

        # Get the output bias for positive and negative items
        pos_bias = self.out_bias[pos_items]
        neg_bias = self.out_bias[neg_items]

        # Calculate BPR loss. The unsqueeze(2) broadcasts each masked-position
        # state against its negatives, so neg_items is presumably
        # [batch_size, num_masked, neg_samples] -- TODO confirm.
        pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) + pos_bias
        neg_score = (
            torch.sum(seq_output.unsqueeze(2) * neg_items_emb, dim=-1) + neg_bias
        )
        # NOTE(review): entries with masked index 0 are excluded from the loss.
        # This presumably treats 0 as the padding value of ``masked_indices``;
        # if so, a genuine mask at sequence position 0 is dropped as well --
        # confirm against the dataloader.
        loss_mask = masked_indices > 0
        bpr_loss = self.bpr_loss(pos_score[loss_mask], neg_score[loss_mask])

        # Calculate L2 regularization
        reg_loss = self.reg_weight * self.reg_loss(
            self.item_embedding(masked_seq),
            pos_items_emb,
            neg_items_emb,
        )

        # Loss logging
        loss = bpr_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def forward(self, item_seq: Tensor) -> Tensor:
        """
        Forward pass of BERT4Rec. Uses bidirectional attention.

        Args:
            item_seq (Tensor): Sequence of items, potentially with [MASK] tokens.

        Returns:
            Tensor: Output of the Transformer for each token [batch_size, seq_len, embedding_size].
        """
        # Padding mask to ignore padding tokens (True where the token is padding)
        padding_mask = item_seq == self.padding_token_id

        # Absolute positions 0 .. seq_len - 1, repeated for every row of the batch
        position_ids = torch.arange(
            item_seq.size(1), dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

        item_emb = self.item_embedding(item_seq)
        pos_emb = self.position_embedding(position_ids)

        input_emb = self.layernorm(item_emb + pos_emb)
        input_emb = self.dropout(input_emb)

        # For bidirectional attention, the causal mask is None
        transformer_output = self.transformer_encoder(
            src=input_emb, mask=None, src_key_padding_mask=padding_mask
        )
        return transformer_output

    def _multi_hot_gather(self, source: Tensor, indices: Tensor) -> Tensor:
        """Gathers specific vectors from a source tensor based on indices.
        This is an efficient way to select the transformer outputs at masked positions.

        Args:
            source (Tensor): The source tensor [batch_size, seq_len, embedding_size].
            indices (Tensor): The indices to gather [batch_size, num_masked].

        Returns:
            Tensor: The gathered vectors [batch_size, num_masked, embedding_size].
        """
        # Add a dimension for the embedding size so that every index selects a
        # whole embedding vector along dim 1
        indices_expanded = indices.unsqueeze(-1).expand(-1, -1, source.size(-1))
        return torch.gather(source, 1, indices_expanded)

    def _prepare_for_prediction(self, user_seq: Tensor, seq_len: Tensor) -> Tensor:
        """Appends a [MASK] token at the end of each sequence for next-item prediction.

        Args:
            user_seq (Tensor): Padded item-id sequences [batch_size, padded_len].
            seq_len (Tensor): Per-row column index at which the [MASK] token is
                written. Must not exceed ``padded_len``, since the output has
                only one extra column.

        Returns:
            Tensor: A new tensor [batch_size, padded_len + 1] holding ``user_seq``,
                one trailing padding column, and the [MASK] id at column
                ``seq_len`` of every row.
        """
        # Create a new sequence with one extra spot for the mask token
        pred_seq = torch.full(
            (user_seq.size(0), user_seq.size(1) + 1),
            self.padding_token_id,
            dtype=torch.long,
            device=user_seq.device,
        )
        pred_seq[:, : user_seq.size(1)] = user_seq

        # Place the mask token at the end of the actual sequence length.
        # Presumably the real items are left-aligned (padding on the right), so
        # column ``seq_len`` is the first free slot -- TODO confirm.
        batch_indices = torch.arange(user_seq.size(0), device=user_seq.device)
        pred_seq[batch_indices, seq_len] = self.mask_token_id

        return pred_seq

    def predict(
        self,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned bidirectional embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Prepare the sequence by appending a [MASK] token
        pred_seq = self._prepare_for_prediction(user_seq, seq_len)

        # Get the output of the bidirectional transformer
        transformer_output = self.forward(pred_seq)

        # Gather the output embedding at the position of the [MASK] token
        # (``_gather_indexes`` is not defined in this class; presumably it comes
        # from SequentialRecommenderUtils)
        seq_output = self._gather_indexes(
            transformer_output, seq_len
        )  # [batch_size, embedding_size]

        if item_indices is None:
            # Case 'full': use all item embeddings (excluding padding and mask)
            item_embeddings = self.item_embedding.weight[
                : self.n_items, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
            bias = self.out_bias[: self.n_items]
        else:
            # Case 'sampled': use only the provided item indices
            item_embeddings = self.item_embedding(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample
            bias = self.out_bias[item_indices]

        # Dot product between the [MASK] state and each candidate, plus item bias
        predictions = (
            torch.einsum(einsum_string, seq_output, item_embeddings) + bias
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(item_seq)

Forward pass of BERT4Rec. Uses bidirectional attention.

Parameters:

Name Type Description Default
item_seq Tensor

Sequence of items, potentially with [MASK] tokens.

required

Returns:

Name Type Description
Tensor Tensor

Output of the Transformer for each token [batch_size, seq_len, embedding_size].

Source code in warprec/recommenders/sequential_recommender/bert4rec.py
def forward(self, item_seq: Tensor) -> Tensor:
    """
    Encode a batch of item sequences with the bidirectional Transformer.

    Args:
        item_seq (Tensor): Item-id sequences; may contain padding and [MASK] ids.

    Returns:
        Tensor: One hidden state per token [batch_size, seq_len, embedding_size].
    """
    # Absolute positions 0 .. seq_len - 1, shared by every row of the batch
    n_positions = item_seq.size(1)
    positions = torch.arange(
        n_positions, dtype=torch.long, device=item_seq.device
    ).unsqueeze(0)
    positions = positions.expand_as(item_seq)

    # Token representation = item embedding + position embedding,
    # normalised and then regularised with dropout
    token_repr = self.item_embedding(item_seq) + self.position_embedding(positions)
    hidden = self.dropout(self.layernorm(token_repr))

    # True wherever the token is padding, so attention ignores those keys
    is_padding = item_seq == self.padding_token_id

    # No causal mask: every position may attend to both directions
    return self.transformer_encoder(
        src=hidden, mask=None, src_key_padding_mask=is_padding
    )

predict(user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned bidirectional embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/bert4rec.py
def predict(
    self,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Score candidate items for the next position of each user sequence.

    A [MASK] token is appended to every sequence, the sequence is encoded, and
    the hidden state at the [MASK] position is compared (dot product plus item
    bias) with the candidate item embeddings.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            every item is scored.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}: [batch_size, n_items] when
            ``item_indices`` is None, otherwise [batch_size, pad_seq].
    """
    # Encode the sequence extended with the [MASK] token
    masked_input = self._prepare_for_prediction(user_seq, seq_len)
    hidden_states = self.forward(masked_input)

    # Hidden state at the [MASK] position -> [batch_size, embedding_size]
    mask_state = self._gather_indexes(hidden_states, seq_len)

    if item_indices is not None:
        # Sampled mode: score only the candidates supplied for each row
        candidates = self.item_embedding(
            item_indices
        )  # [batch_size, pad_seq, embedding_size]
        sampled_scores = torch.einsum("be,bse->bs", mask_state, candidates)
        return sampled_scores + self.out_bias[item_indices]

    # Full mode: score every real item (padding and [MASK] rows are left out)
    catalogue = self.item_embedding.weight[
        : self.n_items, :
    ]  # [n_items, embedding_size]
    full_scores = torch.einsum("be,ie->bi", mask_state, catalogue)
    return full_scores + self.out_bias[: self.n_items]

warprec.recommenders.sequential_recommender.bsarec.BSARec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of BSARec model from "BSARec: Bandlimited Self-Attention for Sequential Recommendation." in AAAI 2024.

This model combines frequency-based filtering with self-attention to capture both periodic patterns and sequential dependencies in user behavior.

Architecture: 1. Domain-Specific Patterns (DSP): FFT-based low/high-pass filtering 2. Graph-Space Patterns (GSP): Multi-head self-attention 3. Adaptive Combination: Learnable weighted sum (alpha parameter)

The frequency filtering helps capture cyclical patterns (e.g., weekly habits), while attention captures complex sequential dependencies.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings (hidden_size).

n_layers int

The number of transformer encoder layers.

n_heads int

The number of attention heads in the transformer.

inner_size int

The dimensionality of the feed-forward layer.

dropout_prob float

The probability of dropout for embeddings.

attn_dropout_prob float

The probability of dropout for attention weights.

alpha float

Balance parameter between DSP and GSP (0.0-1.0).

c int

Cutoff frequency for low-pass filtering.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/bsarec.py
@model_registry.register(name="BSARec")
class BSARec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of BSARec model from
    "BSARec: Bandlimited Self-Attention for Sequential Recommendation." in AAAI 2024.

    This model combines frequency-based filtering with self-attention to capture
    both periodic patterns and sequential dependencies in user behavior.

    Architecture:
    1. Domain-Specific Patterns (DSP): FFT-based low/high-pass filtering
    2. Graph-Space Patterns (GSP): Multi-head self-attention
    3. Adaptive Combination: Learnable weighted sum (alpha parameter)

    The frequency filtering helps capture cyclical patterns (e.g., weekly habits),
    while attention captures complex sequential dependencies.

    Token id layout used by this class: real items occupy ``0 .. n_items - 1``
    and ``n_items`` is the padding token.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings (hidden_size).
        n_layers (int): The number of transformer encoder layers.
        n_heads (int): The number of attention heads in the transformer.
        inner_size (int): The dimensionality of the feed-forward layer.
        dropout_prob (float): The probability of dropout for embeddings.
        attn_dropout_prob (float): The probability of dropout for attention weights.
        alpha (float): Balance parameter between DSP and GSP (0.0-1.0).
        c (int): Cutoff frequency for low-pass filtering.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Model hyperparameters
    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    alpha: float
    c: int
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Item and position embeddings (one extra item row for the padding id)
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)

        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)

        # BSARec layers: frequency filtering + attention
        self.bsarec_encoder = BSARecEncoder(
            embedding_size=self.embedding_size,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            inner_size=self.inner_size,
            attn_dropout_prob=self.attn_dropout_prob,
            dropout_prob=self.dropout_prob,
            max_seq_len=self.max_seq_len,
            alpha=self.alpha,
            c=self.c,
        )

        # Precompute causal mask once for the longest sequence; forward() slices
        # it down to the actual input length. Registered as a buffer so it
        # follows the module across devices.
        causal_mask = self._generate_square_subsequent_mask(self.max_seq_len)
        self.register_buffer("causal_mask", causal_mask)

        # The paper optimizes next-item prediction with full softmax CE loss.
        self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Build the sequential training dataloader.

        Args:
            interactions (Interactions): Interaction data (not used by this model).
            sessions (Sessions): Session data the dataloader is requested from.
            **kwargs (Any): Extra keyword arguments forwarded to the dataloader factory.

        Returns:
            Whatever ``sessions.get_sequential_dataloader`` returns.
        """
        # No negatives are requested: training_step scores every item and uses
        # cross-entropy, so sampled negatives are not needed for the main loss.
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=0,
            batch_size=self.batch_size,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the training loss (cross-entropy + L2 embedding regularization).

        Args:
            batch (Any): Either ``(item_seq, item_seq_len, pos_item, neg_item)`` or
                ``(item_seq, item_seq_len, pos_item)``.
            batch_idx (int): Index of the batch (unused here).

        Returns:
            The loss ``main_loss + reg_loss``.
        """
        # Negatives are optional; when present they only feed the regularizer.
        if len(batch) == 4:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch
            neg_item = None

        seq_output = self.forward(item_seq, item_seq_len)

        # Scores against every real item; [:-1] drops the padding row
        logits = torch.matmul(
            seq_output, self.item_embedding.weight[:-1].transpose(0, 1)
        )
        main_loss = self.main_loss(logits, pos_item)

        reg_terms = [
            self.item_embedding(item_seq),
            self.item_embedding(pos_item),
        ]
        if neg_item is not None:
            reg_terms.append(self.item_embedding(neg_item))
        reg_loss = self.reg_weight * self.reg_loss(*reg_terms)

        return main_loss + reg_loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Forward pass of the BSARec model.

        Combines frequency-based patterns (DSP) with attention patterns (GSP)
        through an adaptive weighted combination.

        Args:
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
            item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

        Returns:
            Tensor: The embedding of the predicted item (last session state)
                    [batch_size, embedding_size].
        """
        seq_len = item_seq.size(1)

        # Padding mask to ignore padding tokens
        padding_mask = item_seq == self.n_items

        # Create position IDs directly on the input's device, which avoids a
        # CPU allocation followed by a host-to-device copy on every call
        position_ids = torch.arange(
            seq_len, dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

        # Get embeddings
        item_emb = self.item_embedding(item_seq)
        pos_emb = self.position_embedding(position_ids)

        # Combine embeddings and apply LayerNorm + Dropout
        seq_emb = self.layernorm(item_emb + pos_emb)
        seq_emb = self.emb_dropout(seq_emb)

        # Pass through BSARec encoder (frequency + attention layers)
        transformer_output = self.bsarec_encoder(
            seq_emb,
            padding_mask,
            self.causal_mask[:seq_len, :seq_len],  # type: ignore[index]
        )  # [batch_size, max_seq_len, embedding_size]

        # Gather the output of the last relevant item in each sequence
        seq_output = self._gather_indexes(
            transformer_output, item_seq_len - 1
        )  # [batch_size, embedding_size]

        return seq_output

    @torch.no_grad()
    def predict(
        self,
        user_indices: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        user_seq: Optional[Tensor] = None,
        seq_len: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_indices (Tensor): The batch of user indices.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            user_seq (Optional[Tensor]): Padded sequences of item IDs for users to predict for.
                Required despite the ``None`` default.
            seq_len (Optional[Tensor]): Actual lengths of these sequences, before padding.
                Required despite the ``None`` default.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.

        Raises:
            ValueError: If ``user_seq`` or ``seq_len`` is not provided.
        """
        # The sequence inputs are keyword arguments with None defaults, but the
        # model cannot score without them: fail with a clear message instead of
        # an AttributeError from inside forward().
        if user_seq is None or seq_len is None:
            raise ValueError(
                "BSARec.predict requires both 'user_seq' and 'seq_len' to be provided."
            )

        # Get sequence output embeddings
        seq_output = self.forward(user_seq, seq_len)

        if item_indices is None:
            # Case 'full': prediction on all items (padding row excluded)
            item_embeddings = self.item_embedding.weight[:-1, :]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = self.item_embedding(item_indices)
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(einsum_string, seq_output, item_embeddings)
        return predictions

forward(item_seq, item_seq_len)

Forward pass of the BSARec model.

Combines frequency-based patterns (DSP) with attention patterns (GSP) through an adaptive weighted combination.

Parameters:

Name Type Description Default
item_seq Tensor

Padded sequences of item IDs [batch_size, max_seq_len].

required
item_seq_len Tensor

Actual lengths of sequences [batch_size,].

required

Returns:

Name Type Description
Tensor Tensor

The embedding of the predicted item (last session state) [batch_size, embedding_size].

Source code in warprec/recommenders/sequential_recommender/bsarec.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Forward pass of the BSARec model.

    Combines frequency-based patterns (DSP) with attention patterns (GSP)
    through an adaptive weighted combination.

    Args:
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
        item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

    Returns:
        Tensor: The embedding of the predicted item (last session state)
                [batch_size, embedding_size].
    """
    seq_len = item_seq.size(1)

    # Padding mask to ignore padding tokens (the padding id is n_items)
    padding_mask = item_seq == self.n_items

    # Create position IDs directly on the input's device, which avoids a
    # CPU allocation followed by a host-to-device copy on every call
    position_ids = torch.arange(seq_len, dtype=torch.long, device=item_seq.device)
    position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

    # Get embeddings
    item_emb = self.item_embedding(item_seq)
    pos_emb = self.position_embedding(position_ids)

    # Combine embeddings and apply LayerNorm + Dropout
    seq_emb = self.layernorm(item_emb + pos_emb)
    seq_emb = self.emb_dropout(seq_emb)

    # Pass through BSARec encoder (frequency + attention layers); the
    # precomputed causal mask is cut down to the actual input length
    transformer_output = self.bsarec_encoder(
        seq_emb,
        padding_mask,
        self.causal_mask[:seq_len, :seq_len],  # type: ignore[index]
    )  # [batch_size, max_seq_len, embedding_size]

    # Gather the output of the last relevant item in each sequence
    seq_output = self._gather_indexes(
        transformer_output, item_seq_len - 1
    )  # [batch_size, embedding_size]

    return seq_output

predict(user_indices, *args, item_indices=None, user_seq=None, seq_len=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_indices Tensor

The batch of user indices.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
user_seq Optional[Tensor]

Padded sequences of item IDs for users to predict for.

None
seq_len Optional[Tensor]

Actual lengths of these sequences, before padding.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/bsarec.py
@torch.no_grad()
def predict(
    self,
    user_indices: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    user_seq: Optional[Tensor] = None,
    seq_len: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Prediction using the learned session embeddings.

    Args:
        user_indices (Tensor): The batch of user indices.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        user_seq (Optional[Tensor]): Padded sequences of item IDs for users to predict for.
            Required despite the ``None`` default.
        seq_len (Optional[Tensor]): Actual lengths of these sequences, before padding.
            Required despite the ``None`` default.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.

    Raises:
        ValueError: If ``user_seq`` or ``seq_len`` is not provided.
    """
    # The sequence inputs are keyword arguments with None defaults, but the
    # model cannot score without them: fail with a clear message instead of
    # an AttributeError from inside forward().
    if user_seq is None or seq_len is None:
        raise ValueError(
            "BSARec.predict requires both 'user_seq' and 'seq_len' to be provided."
        )

    # Get sequence output embeddings
    seq_output = self.forward(user_seq, seq_len)

    if item_indices is None:
        # Case 'full': prediction on all items (padding row excluded)
        item_embeddings = self.item_embedding.weight[:-1, :]
        einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
    else:
        # Case 'sampled': prediction on a sampled set of items
        item_embeddings = self.item_embedding(item_indices)
        einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

    predictions = torch.einsum(einsum_string, seq_output, item_embeddings)
    return predictions

warprec.recommenders.sequential_recommender.cl4srec.CL4SRec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of CL4SRec model from "Contrastive learning for sequential recommendation" in SIGIR 2021.

This implementation follows the original paper: 1. A SASRec-style unidirectional Transformer encoder. 2. Two random augmentations sampled from crop/mask/reorder. 3. A multi-task objective with sampled-softmax next-item prediction and InfoNCE contrastive learning.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings (hidden_size).

n_layers int

The number of transformer encoder layers.

n_heads int

The number of attention heads in the transformer.

inner_size int

The dimensionality of the feed-forward layer.

dropout_prob float

The probability of dropout for embeddings.

attn_dropout_prob float

The probability of dropout for attention weights.

ssl_lambda float

The weight for the unsupervised CL loss.

tau float

The temperature parameter for contrastive loss.

sim_type str

The similarity metric for contrastive loss ("dot" or "cos").

crop_eta float

The probability of cropping items in the augmentation.

mask_gamma float

The probability of masking items in the augmentation.

reorder_beta float

The probability of reordering items in the augmentation.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/cl4srec.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
@model_registry.register(name="CL4SRec")
class CL4SRec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of CL4SRec model from
    "Contrastive learning for sequential recommendation" in SIGIR 2021.

    This implementation follows the original paper:
    1. A SASRec-style unidirectional Transformer encoder.
    2. Two random augmentations sampled from crop/mask/reorder.
    3. A multi-task objective with sampled-softmax next-item prediction
       and InfoNCE contrastive learning.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings (hidden_size).
        n_layers (int): The number of transformer encoder layers.
        n_heads (int): The number of attention heads in the transformer.
        inner_size (int): The dimensionality of the feed-forward layer.
        dropout_prob (float): The probability of dropout for embeddings.
        attn_dropout_prob (float): The dropout probability passed to every
            Transformer encoder layer.
        ssl_lambda (float): The weight for the unsupervised CL loss.
        tau (float): The temperature parameter for contrastive loss.
        sim_type (str): The similarity metric for contrastive loss ("dot" or "cos").
        crop_eta (float): The fraction of the sequence length kept by the
            crop augmentation.
        mask_gamma (float): The fraction of the sequence length replaced by
            the [mask] token in the mask augmentation.
        reorder_beta (float): The fraction of the sequence length shuffled
            by the reorder augmentation.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    ssl_lambda: float
    tau: float
    sim_type: str
    crop_eta: float
    mask_gamma: float
    reorder_beta: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Two extra ids are reserved after the real items:
        # n_items is the padding token, n_items + 1 is the [mask] token.
        self.padding_token_id = self.n_items
        self.mask_token_id = self.n_items + 1

        self.item_embedding = nn.Embedding(
            self.n_items + 2,
            self.embedding_size,
            padding_idx=self.padding_token_id,
        )
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)
        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_size,
            nhead=self.n_heads,
            dim_feedforward=self.inner_size,
            dropout=self.attn_dropout_prob,
            activation="relu",
            batch_first=True,
            norm_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_layers,
        )

        # Registered as a buffer so the mask follows the module across devices.
        causal_mask = self._generate_square_subsequent_mask(self.max_seq_len)
        self.register_buffer("causal_mask", causal_mask)

        self.apply(self._init_weights)

        self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()
        self.contrastive_loss = nn.CrossEntropyLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs,
    ):
        """Build the sequential training dataloader from the session data.

        Args:
            interactions (Interactions): Unused by this model.
            sessions (Sessions): The object providing the sequential dataloader.
            **kwargs: Extra arguments forwarded to the dataloader factory.

        Returns:
            The object returned by ``sessions.get_sequential_dataloader``.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def _item_crop_batch(
        self, item_seq: Tensor, item_seq_len: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Crop a valid continuous subsequence of length floor(eta * |s|).

        Assumes valid items occupy the first ``item_seq_len`` positions of
        each row (right padding).

        Args:
            item_seq (Tensor): Padded item ids [batch_size, max_seq].
            item_seq_len (Tensor): Number of valid items per row [batch_size].

        Returns:
            Tuple[Tensor, Tensor]: The cropped sequences, left-aligned and
            padded with the padding token, and their new lengths.
        """
        batch_size, max_seq = item_seq.shape
        # At least one item always survives the crop.
        num_left = torch.floor(item_seq_len * self.crop_eta).long().clamp(min=1)

        # Draw a start offset uniformly from [0, len - num_left].
        max_start = (item_seq_len - num_left).clamp(min=0)
        random_offsets = torch.rand(batch_size, device=item_seq.device)
        crop_begin = torch.floor(random_offsets * (max_start + 1).float()).long()

        # Shift every row left by its offset; out-of-range reads are clamped
        # and overwritten with padding below.
        seq_indices = torch.arange(max_seq, device=item_seq.device).unsqueeze(0)
        gather_idx = crop_begin.unsqueeze(1) + seq_indices
        gathered = torch.gather(item_seq, 1, gather_idx.clamp(max=max_seq - 1))

        keep_mask = seq_indices < num_left.unsqueeze(1)
        cropped = torch.where(
            keep_mask,
            gathered,
            torch.full_like(gathered, self.padding_token_id),
        )
        return cropped, num_left

    def _item_mask_batch(
        self, item_seq: Tensor, item_seq_len: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Mask valid items with a dedicated [mask] token.

        Args:
            item_seq (Tensor): Padded item ids [batch_size, max_seq].
            item_seq_len (Tensor): Number of valid items per row [batch_size].

        Returns:
            Tuple[Tensor, Tensor]: The masked sequences and the unchanged
            lengths.
        """
        batch_size, max_seq = item_seq.shape
        num_mask = torch.floor(item_seq_len * self.mask_gamma).long().clamp(min=1)

        # Assign random float values to all items and assign 1.0 to padding items so they are sorted last
        rand_vals = torch.rand(batch_size, max_seq, device=item_seq.device)
        pos_indices = torch.arange(max_seq, device=item_seq.device).unsqueeze(0)
        valid_mask = pos_indices < item_seq_len.unsqueeze(1)
        rand_vals = torch.where(valid_mask, rand_vals, torch.ones_like(rand_vals))

        # Sort the random values to get candidate positions for masking
        _, sorted_indices = torch.sort(rand_vals, dim=1, descending=False)

        # Determine the rank of each position in the original sequence (vectorized double argsort)
        ranks = torch.argsort(sorted_indices, dim=1)

        # Identify mask positions where rank is less than the calculated masking budget
        mask_positions = ranks < num_mask.unsqueeze(1)
        # Prevent masking of padding elements
        mask_positions = mask_positions & valid_mask

        # Replace selected elements with the mask token
        masked = torch.where(
            mask_positions,
            torch.full_like(item_seq, self.mask_token_id),
            item_seq,
        )
        return masked, item_seq_len

    def _item_reorder_batch(
        self, item_seq: Tensor, item_seq_len: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Shuffle a valid continuous subsequence of length floor(beta * |s|).

        Args:
            item_seq (Tensor): Padded item ids [batch_size, max_seq].
            item_seq_len (Tensor): Number of valid items per row [batch_size].

        Returns:
            Tuple[Tensor, Tensor]: The reordered sequences and the unchanged
            lengths.
        """
        batch_size, max_seq = item_seq.shape
        num_reorder = torch.floor(item_seq_len * self.reorder_beta).long().clamp(min=1)

        # Constrain the reorder length to avoid shuffling beyond sequence boundaries
        max_reorder = (item_seq_len - 1).clamp(min=1)
        num_reorder = torch.minimum(num_reorder, max_reorder)

        # Calculate randomized start indices for the reorder window in a vectorized manner
        max_start = (item_seq_len - num_reorder).clamp(min=0)
        random_offsets = torch.rand(batch_size, device=item_seq.device)
        reorder_start = torch.floor(random_offsets * (max_start + 1).float()).long()

        # Create a boolean mask indicating the active reorder window for each sequence
        seq_indices = torch.arange(max_seq, device=item_seq.device).unsqueeze(0)
        reorder_mask = (seq_indices >= reorder_start.unsqueeze(1)) & (
            seq_indices < (reorder_start + num_reorder).unsqueeze(1)
        )
        # Ensure we do not shuffle padding elements by restricting to actual sequence lengths
        reorder_mask = reorder_mask & (seq_indices < item_seq_len.unsqueeze(1))

        # Initialize sorting keys with original indices to keep unselected items in place
        seq_indices_float = seq_indices.float()

        # Generate random sorting keys bounded strictly within the active reorder window
        rand_vals = torch.rand(batch_size, max_seq, device=item_seq.device)
        shuffled_keys = (
            reorder_start.unsqueeze(1).float()
            + rand_vals * num_reorder.unsqueeze(1).float()
        )

        # Merge the keys: random values for the shuffle window, original indices for the rest
        sort_keys = torch.where(reorder_mask, shuffled_keys, seq_indices_float)

        # Perform a single batched sort to find the new shuffled index mapping
        _, shuffle_idx = torch.sort(sort_keys, dim=1)
        reordered = torch.gather(item_seq, 1, shuffle_idx)

        return reordered, item_seq_len

    def augment(
        self, item_seq: Tensor, item_seq_len: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Generate two random augmented views for each sequence.

        Each view independently draws one operator per row: 0 = crop,
        1 = mask, 2 = reorder.

        Args:
            item_seq (Tensor): Padded item ids [batch_size, max_seq].
            item_seq_len (Tensor): Number of valid items per row [batch_size].

        Returns:
            Tuple[Tensor, Tensor, Tensor, Tensor]: Sequences and lengths of
            the first view, followed by sequences and lengths of the second.
        """
        batch_size = item_seq_len.shape[0]
        augment_choices1 = torch.randint(0, 3, (batch_size,), device=item_seq.device)
        augment_choices2 = torch.randint(0, 3, (batch_size,), device=item_seq.device)

        aug_seq1 = item_seq.clone()
        aug_len1 = item_seq_len.clone()
        aug_seq2 = item_seq.clone()
        aug_len2 = item_seq_len.clone()

        views = (
            (augment_choices1, aug_seq1, aug_len1),
            (augment_choices2, aug_seq2, aug_len2),
        )
        # The position of an operator in this tuple is the choice id that
        # selects it.
        operators = (
            self._item_crop_batch,
            self._item_mask_batch,
            self._item_reorder_batch,
        )

        # Operator-major, view-minor order keeps the sequence of random draws
        # stable for a fixed seed.
        for choice_id, operator in enumerate(operators):
            for choices, aug_seq, aug_len in views:
                selected = choices == choice_id
                if selected.any():
                    new_seq, new_len = operator(
                        item_seq[selected], item_seq_len[selected]
                    )
                    # In-place writes land in the cloned view tensors.
                    aug_seq[selected] = new_seq
                    aug_len[selected] = new_len

        return aug_seq1, aug_len1, aug_seq2, aug_len2

    def _sampled_softmax_loss(
        self, seq_output: Tensor, pos_item: Tensor, neg_item: Optional[Tensor]
    ) -> Tensor:
        """Main next-item objective used by CL4SRec.

        Args:
            seq_output (Tensor): Sequence representations [batch_size, embedding_size].
            pos_item (Tensor): Target item ids [batch_size].
            neg_item (Optional[Tensor]): Sampled negative item ids. If None,
                a full softmax over all real items is used instead.

        Returns:
            Tensor: The cross-entropy loss.
        """
        if neg_item is None:
            # Full softmax: padding and [mask] rows are excluded from the logits.
            logits = torch.matmul(
                seq_output, self.item_embedding.weight[: self.n_items].T
            )
            return self.main_loss(logits, pos_item)

        pos_logits = torch.sum(
            seq_output * self.item_embedding(pos_item),
            dim=-1,
            keepdim=True,
        )
        neg_logits = torch.sum(
            seq_output.unsqueeze(1) * self.item_embedding(neg_item),
            dim=-1,
        )
        # The positive logit is placed in column 0, so every label is 0.
        sampled_logits = torch.cat([pos_logits, neg_logits], dim=1)
        sampled_labels = torch.zeros(
            sampled_logits.size(0),
            dtype=torch.long,
            device=seq_output.device,
        )
        return self.main_loss(sampled_logits, sampled_labels)

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the multi-task loss for one batch.

        Args:
            batch (Any): Either (item_seq, item_seq_len, pos_item, neg_item)
                or (item_seq, item_seq_len, pos_item).
            batch_idx (int): Index of the batch (unused).

        Returns:
            Tensor: Next-item loss plus L2 regularization, plus the weighted
            contrastive loss when it is enabled.
        """
        if len(batch) == 4:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch
            neg_item = None

        seq_output = self.forward(item_seq, item_seq_len)
        main_loss = self._sampled_softmax_loss(seq_output, pos_item, neg_item)

        reg_terms = [self.item_embedding(item_seq), self.item_embedding(pos_item)]
        if neg_item is not None:
            reg_terms.append(self.item_embedding(neg_item))
        reg_loss = self.reg_weight * self.reg_loss(*reg_terms)
        total_loss = main_loss + reg_loss

        # The contrastive term needs at least two sequences to have negatives.
        if self.ssl_lambda > 0 and item_seq.size(0) > 1:
            aug_item_seq1, aug_len1, aug_item_seq2, aug_len2 = self.augment(
                item_seq, item_seq_len
            )
            seq_output1 = self.forward(aug_item_seq1, aug_len1)
            seq_output2 = self.forward(aug_item_seq2, aug_len2)

            nce_logits, nce_labels = self._info_nce(
                seq_output1,
                seq_output2,
                temp=self.tau,
                batch_size=item_seq.size(0),
                sim=self.sim_type,
            )
            nce_loss = self.contrastive_loss(nce_logits, nce_labels)
            total_loss = total_loss + self.ssl_lambda * nce_loss

        # Loss logging
        self.log("loss", total_loss, prog_bar=True, on_step=False, on_epoch=True)
        return total_loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Forward pass of the SASRec-style encoder used by CL4SRec.

        Args:
            item_seq (Tensor): Padded item ids [batch_size, seq_len].
            item_seq_len (Tensor): Number of valid items per row [batch_size].

        Returns:
            Tensor: The encoder output gathered at index ``item_seq_len - 1``
            of every row.
        """
        seq_len = item_seq.size(1)
        padding_mask = item_seq == self.padding_token_id

        position_ids = torch.arange(seq_len, dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

        item_emb = self.item_embedding(item_seq)
        pos_emb = self.position_embedding(position_ids)

        seq_emb = self.layernorm(item_emb + pos_emb)
        seq_emb = self.emb_dropout(seq_emb)

        transformer_output = self.transformer_encoder(
            src=seq_emb,
            mask=self.causal_mask[:seq_len, :seq_len],  # type: ignore[index]
            src_key_padding_mask=padding_mask,
        )

        return self._gather_indexes(transformer_output, item_seq_len - 1)

    def _prepare_contrastive_representations(
        self, z_i: Tensor, z_j: Tensor, sim: str
    ) -> Tensor:
        """Stack the two views, L2-normalizing them first for cosine similarity.

        Args:
            z_i (Tensor): Representations of the first view.
            z_j (Tensor): Representations of the second view.
            sim (str): Either "cos" or "dot".

        Returns:
            Tensor: The two views concatenated along the batch dimension.

        Raises:
            ValueError: If ``sim`` is neither "cos" nor "dot".
        """
        if sim == "cos":
            z_i = torch.nn.functional.normalize(z_i, dim=1)
            z_j = torch.nn.functional.normalize(z_j, dim=1)
        elif sim != "dot":
            raise ValueError(f"Unknown similarity metric: {sim}")
        return torch.cat([z_i, z_j], dim=0)

    def _info_nce(
        self, z_i: Tensor, z_j: Tensor, temp: float, batch_size: int, sim: str = "dot"
    ) -> Tuple[Tensor, Tensor]:
        """Compute InfoNCE with 2N views and 2(N-1) negatives per sample.

        Args:
            z_i (Tensor): Representations of the first view [batch_size, dim].
            z_j (Tensor): Representations of the second view [batch_size, dim].
            temp (float): Temperature dividing the similarities.
            batch_size (int): Number of sequences N in each view.
            sim (str): Similarity metric, "dot" or "cos".

        Returns:
            Tuple[Tensor, Tensor]: Logits with the positive pair in column 0,
            and the matching all-zero labels.
        """
        representations = self._prepare_contrastive_representations(z_i, z_j, sim)
        total_views = 2 * batch_size

        similarity_matrix = torch.matmul(representations, representations.T) / temp
        eye_mask = torch.eye(
            total_views, dtype=torch.bool, device=representations.device
        )
        similarity_matrix = similarity_matrix.masked_fill(eye_mask, float("-inf"))

        # The positive of view k is the other view of the same sequence,
        # which sits batch_size rows away (modulo 2N).
        positive_indices = torch.arange(total_views, device=representations.device)
        positive_indices = (positive_indices + batch_size) % total_views

        positive_samples = similarity_matrix[
            torch.arange(total_views, device=representations.device),
            positive_indices,
        ].unsqueeze(1)

        # Negatives are everything except the self-similarity and the positive.
        negative_mask = ~eye_mask
        negative_mask[
            torch.arange(total_views, device=representations.device),
            positive_indices,
        ] = False
        negative_samples = similarity_matrix[negative_mask].reshape(total_views, -1)

        logits = torch.cat([positive_samples, negative_samples], dim=1)
        labels = torch.zeros(
            total_views, dtype=torch.long, device=representations.device
        )
        return logits, labels

    def _decompose(
        self, z_i: Tensor, z_j: Tensor, origin_z: Tensor, batch_size: int
    ) -> Tuple[Tensor, Tensor]:
        """Decompose contrastive behavior into alignment and uniformity metrics.

        Args:
            z_i (Tensor): Representations of the first view [batch_size, dim].
            z_j (Tensor): Representations of the second view [batch_size, dim].
            origin_z (Tensor): Representations of the original sequences.
            batch_size (int): Number of sequences in each view.

        Returns:
            Tuple[Tensor, Tensor]: The mean Euclidean distance between the
            two views of each sequence (alignment), and the log of the mean
            of ``exp(-2 * distance)`` over distinct original pairs
            (uniformity).
        """
        total_views = 2 * batch_size
        z = torch.cat((z_i, z_j), dim=0)

        sim = torch.cdist(z, z, p=2)
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(total_views, 1)
        alignment = positive_samples.mean()

        sim = torch.cdist(origin_z, origin_z, p=2)
        mask = torch.ones(
            (batch_size, batch_size), dtype=torch.bool, device=origin_z.device
        )
        mask = mask.fill_diagonal_(0)
        negative_samples = sim[mask].reshape(batch_size, -1)
        uniformity = torch.log(torch.exp(-2 * negative_samples).mean())

        return alignment, uniformity

    @torch.no_grad()
    def predict(
        self,
        user_indices: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        user_seq: Optional[Tensor] = None,
        seq_len: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """Prediction using the learned sequence embeddings.

        Args:
            user_indices (Tensor): The batch of user indices (unused here).
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If
                None, scores for all real items are produced.
            user_seq (Optional[Tensor]): Padded item sequences of the users.
            seq_len (Optional[Tensor]): Number of valid items per sequence.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.

        Raises:
            ValueError: If ``user_seq`` or ``seq_len`` is not provided.
        """
        # Both are optional in the signature but required by the encoder.
        if user_seq is None or seq_len is None:
            raise ValueError(
                "CL4SRec.predict requires both 'user_seq' and 'seq_len'."
            )

        seq_output = self.forward(user_seq, seq_len)

        if item_indices is None:
            # Padding and [mask] rows are excluded from the candidate items.
            item_embeddings = self.item_embedding.weight[: self.n_items, :]
            einsum_string = "be,ie->bi"
        else:
            item_embeddings = self.item_embedding(item_indices)
            einsum_string = "be,bse->bs"

        return torch.einsum(einsum_string, seq_output, item_embeddings)

augment(item_seq, item_seq_len)

Generate two random augmented views for each sequence.

Source code in warprec/recommenders/sequential_recommender/cl4srec.py
def augment(
    self, item_seq: Tensor, item_seq_len: Tensor
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Build two independently augmented views of every sequence.

    For each view one operator id is drawn per row: 0 selects the crop
    operator, 1 the mask operator and 2 the reorder operator.

    Args:
        item_seq (Tensor): The batch of padded item sequences.
        item_seq_len (Tensor): The length of every sequence in the batch.

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor]: Sequences and lengths of the
        first view followed by sequences and lengths of the second view.
    """
    n_rows = item_seq_len.shape[0]
    device = item_seq.device
    first_choice = torch.randint(0, 3, (n_rows,), device=device)
    second_choice = torch.randint(0, 3, (n_rows,), device=device)

    # Rows keep their original content unless an operator overwrites them.
    first_seq, first_len = item_seq.clone(), item_seq_len.clone()
    second_seq, second_len = item_seq.clone(), item_seq_len.clone()

    views = (
        (first_choice, first_seq, first_len),
        (second_choice, second_seq, second_len),
    )
    # The position in this tuple is the operator id drawn above.
    operators = (
        self._item_crop_batch,
        self._item_mask_batch,
        self._item_reorder_batch,
    )

    # Operator-major, view-minor: the operators are invoked in the order
    # crop(view 1), crop(view 2), mask(view 1), mask(view 2), ...
    for op_id, operator in enumerate(operators):
        for choice, out_seq, out_len in views:
            rows = choice == op_id
            if not rows.any():
                continue
            new_seq, new_len = operator(item_seq[rows], item_seq_len[rows])
            out_seq[rows] = new_seq
            out_len[rows] = new_len

    return first_seq, first_len, second_seq, second_len

forward(item_seq, item_seq_len)

Forward pass of the SASRec-style encoder used by CL4SRec.

Source code in warprec/recommenders/sequential_recommender/cl4srec.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Encode item sequences with the SASRec-style Transformer.

    Args:
        item_seq (Tensor): Padded item ids [batch_size, seq_len].
        item_seq_len (Tensor): Number of valid items per row [batch_size].

    Returns:
        Tensor: The encoder output gathered at index ``item_seq_len - 1``
        of every row.
    """
    n_positions = item_seq.size(1)
    positions = (
        torch.arange(n_positions, dtype=torch.long, device=item_seq.device)
        .unsqueeze(0)
        .expand_as(item_seq)
    )

    # Item and position embeddings are summed, normalized, then dropped out.
    embedded = self.item_embedding(item_seq) + self.position_embedding(positions)
    hidden = self.emb_dropout(self.layernorm(embedded))

    is_padding = item_seq == self.padding_token_id
    encoded = self.transformer_encoder(
        src=hidden,
        mask=self.causal_mask[:n_positions, :n_positions],  # type: ignore[index]
        src_key_padding_mask=is_padding,
    )

    last_positions = item_seq_len - 1
    return self._gather_indexes(encoded, last_positions)

predict(user_indices, *args, item_indices=None, user_seq=None, seq_len=None, **kwargs)

Prediction using the learned sequence embeddings.

Source code in warprec/recommenders/sequential_recommender/cl4srec.py
@torch.no_grad()
def predict(
    self,
    user_indices: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    user_seq: Optional[Tensor] = None,
    seq_len: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Prediction using the learned sequence embeddings.

    Args:
        user_indices (Tensor): The batch of user indices (unused here).
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            scores for all real items are produced.
        user_seq (Optional[Tensor]): Padded item sequences of the users.
        seq_len (Optional[Tensor]): Number of valid items per sequence.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.

    Raises:
        ValueError: If ``user_seq`` or ``seq_len`` is not provided.
    """
    # Both are optional in the signature but required by the encoder; fail
    # with a clear message instead of an opaque error inside forward().
    if user_seq is None or seq_len is None:
        raise ValueError("CL4SRec.predict requires both 'user_seq' and 'seq_len'.")

    seq_output = self.forward(user_seq, seq_len)

    if item_indices is None:
        # Only the first n_items rows are scored; the extra rows are excluded.
        item_embeddings = self.item_embedding.weight[: self.n_items, :]
        einsum_string = "be,ie->bi"
    else:
        item_embeddings = self.item_embedding(item_indices)
        einsum_string = "be,bse->bs"

    return torch.einsum(einsum_string, seq_output, item_embeddings)

warprec.recommenders.sequential_recommender.core.CORE

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of the CORE algorithm from "CORE: Simple and Effective Session-based Recommendation within Consistent Representation Space" in SIGIR 2022.

CORE unifies the representation space for both encoding and decoding processes, using a Representation-Consistent Encoder (RCE) and Robust Distance Measuring (RDM).

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings.

dnn_type str

Type of encoder ('trm' for Transformer or 'ave' for Average).

n_layers int

Number of transformer layers.

n_heads int

Number of attention heads.

inner_size int

Inner size of the transformer feed-forward layer.

hidden_dropout_prob float

Dropout probability for hidden layers.

attn_dropout_prob float

Dropout probability for attention weights.

layer_norm_eps float

Epsilon for layer normalization.

session_dropout float

Dropout for the session embeddings.

item_dropout float

Dropout for item embeddings during training.

temperature float

Temperature scaling factor for RDM.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

Maximum sequence length.

Source code in warprec/recommenders/sequential_recommender/core.py
@model_registry.register(name="CORE")
class CORE(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of CORE algorithm from
    "CORE: Simple and Effective Session-based Recommendation within Consistent Representation Space" in SIGIR 2022.

    CORE unifies the representation space for both encoding and decoding processes,
    using a Representation-Consistent Encoder (RCE) and Robust Distance Measuring (RDM).

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings.
        dnn_type (str): Type of encoder ('trm' for Transformer or 'ave' for Average).
        n_layers (int): Number of transformer layers.
        n_heads (int): Number of attention heads.
        inner_size (int): Inner size of the transformer feed-forward layer.
        hidden_dropout_prob (float): Dropout probability for hidden layers.
        attn_dropout_prob (float): Dropout probability for attention weights.
        layer_norm_eps (float): Epsilon for layer normalization.
        session_dropout (float): Dropout for the session embeddings.
        item_dropout (float): Dropout for item embeddings during training.
        temperature (float): Temperature scaling factor for RDM.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): Maximum sequence length.
    """

    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Model hyperparameters
    embedding_size: int
    dnn_type: str
    n_layers: int
    n_heads: int
    inner_size: int
    hidden_dropout_prob: float
    attn_dropout_prob: float
    layer_norm_eps: float
    session_dropout: float
    item_dropout: float
    temperature: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        self.session_dropout_layer = nn.Dropout(self.session_dropout)
        self.item_dropout_layer = nn.Dropout(self.item_dropout)

        # One extra row (index n_items) is reserved for the padding token.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )

        # Initialize the chosen DNN encoder; any value other than "trm"
        # selects the average pooling encoder.
        if self.dnn_type == "trm":
            self.net = TransNet(params, self.max_seq_len, self.n_items)
        else:
            self.net = self.ave_net  # type: ignore[assignment]

        self.loss_fct = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()
        self.apply(self._init_weights)

    def ave_net(self, item_seq: Tensor, item_emb: Tensor) -> Tensor:
        """Simple average pooling encoder.

        Args:
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
            item_emb (Tensor): Item embeddings of the sequence. Unused; kept
                so the call matches the one made in ``forward``.

        Returns:
            Tensor: Pooling weights [batch_size, max_seq_len, 1]. Non-padding
            positions share a uniform weight, padding positions get 0, and a
            row made only of padding gets all-zero weights.
        """
        mask = (item_seq != self.n_items).to(torch.float)
        # Clamp the count so an all-padding row yields zeros instead of the
        # NaN produced by 0 / 0.
        valid_count = mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
        alpha = mask / valid_count
        return alpha.unsqueeze(-1)

    def forward(self, item_seq: Tensor) -> Tensor:
        """Forward pass of the CORE model.

        Args:
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].

        Returns:
            Tensor: The session representation [batch_size, embedding_size].
        """
        # Get item embeddings
        x = self.item_embedding(item_seq)
        x = self.session_dropout_layer(x)

        # Representation-Consistent Encoder (RCE):
        # Calculate weights alpha and perform weighted sum of item embeddings
        alpha = self.net(item_seq, x)
        seq_output = torch.sum(alpha * x, dim=1)

        # Normalize output for Robust Distance Measuring (RDM)
        return F.normalize(seq_output, dim=-1)

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Build the sequential training dataloader from the session data.

        Args:
            interactions (Interactions): Unused by this model.
            sessions (Sessions): The object providing the sequential dataloader.
            **kwargs (Any): Extra arguments forwarded to the dataloader factory.

        Returns:
            The object returned by ``sessions.get_sequential_dataloader``.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the full-softmax RDM loss for one batch.

        Args:
            batch (Any): A sequence whose first three entries are the item
                sequences, an ignored entry and the target items.
            batch_idx (int): Index of the batch (unused).

        Returns:
            Tensor: Cross-entropy loss plus the weighted L2 regularization.
        """
        item_seq, _, pos_item = batch[:3]

        # Generate session representation
        seq_output = self.forward(item_seq)

        # Robust Distance Measuring (RDM):
        # Calculate cosine similarity between session and all items
        all_item_emb = self.item_embedding.weight[: self.n_items]
        all_item_emb = self.item_dropout_layer(all_item_emb)
        all_item_emb = F.normalize(all_item_emb, dim=-1)

        # Logits are scaled by temperature
        logits = (
            torch.matmul(seq_output, all_item_emb.transpose(0, 1)) / self.temperature
        )
        main_loss = self.loss_fct(logits, pos_item)

        # L2 regularization
        reg_loss = self.reg_weight * self.reg_loss(
            self.item_embedding(item_seq),
            self.item_embedding(pos_item),
        )

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def predict(
        self,
        user_seq: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get session representation
        seq_output = self.forward(user_seq)

        if item_indices is None:
            # Predict scores for all items; the padding row is excluded with
            # the same slice used in training_step.
            item_embeddings = self.item_embedding.weight[: self.n_items]
            item_embeddings = F.normalize(item_embeddings, dim=-1)
            einsum_string = "be,ie->bi"
        else:
            # Predict scores for a specific subset of items
            item_embeddings = self.item_embedding(item_indices)
            item_embeddings = F.normalize(item_embeddings, dim=-1)
            einsum_string = "be,bse->bs"

        # Calculate similarity and apply temperature scaling
        predictions = torch.einsum(einsum_string, seq_output, item_embeddings)
        return predictions / self.temperature

ave_net(item_seq, item_emb)

Simple average pooling encoder.

Source code in warprec/recommenders/sequential_recommender/core.py
def ave_net(self, item_seq: Tensor, item_emb: Tensor) -> Tensor:
    """Simple average pooling encoder.

    Args:
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
        item_emb (Tensor): Item embeddings of the sequence. Unused by this
            encoder.

    Returns:
        Tensor: Pooling weights [batch_size, max_seq_len, 1]. Non-padding
        positions share a uniform weight, padding positions get 0, and a row
        made only of padding gets all-zero weights.
    """
    mask = (item_seq != self.n_items).to(torch.float)
    # Clamp the count so an all-padding row yields zeros instead of the NaN
    # produced by 0 / 0.
    valid_count = mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
    alpha = mask / valid_count
    return alpha.unsqueeze(-1)

forward(item_seq)

Forward pass of the CORE model. Takes `item_seq` (Tensor), the padded sequences of item IDs with shape [batch_size, max_seq_len], and returns a Tensor holding the session representation with shape [batch_size, embedding_size].

Source code in warprec/recommenders/sequential_recommender/core.py
def forward(self, item_seq: Tensor) -> Tensor:
    """Encode a batch of sessions with the CORE model.

    Args:
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].

    Returns:
        Tensor: The session representation [batch_size, embedding_size].
    """
    # Embed the items, then apply the session-level dropout
    item_emb = self.session_dropout_layer(self.item_embedding(item_seq))

    # Representation-Consistent Encoder (RCE):
    # the encoder yields one weight per position and the session vector is
    # the weighted sum of the item embeddings along the sequence axis
    weights = self.net(item_seq, item_emb)
    session_repr = (weights * item_emb).sum(dim=1)

    # Robust Distance Measuring (RDM) compares unit-length vectors
    return F.normalize(session_repr, dim=-1)

predict(user_seq, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/core.py
def predict(
    self,
    user_seq: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Score items against the learned session embeddings.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # Session representation produced by the encoder
    session_repr = self.forward(user_seq)

    if item_indices is not None:
        # Only the requested candidates of each row are scored
        candidates = self.item_embedding(item_indices)
        pattern = "be,bse->bs"
    else:
        # Every item is scored; the last embedding row is left out
        candidates = self.item_embedding.weight[:-1, :]
        pattern = "be,ie->bi"

    # Both branches compare against unit-length item vectors
    candidates = F.normalize(candidates, dim=-1)

    # Similarity followed by temperature scaling
    scores = torch.einsum(pattern, session_repr, candidates)
    return scores / self.temperature

warprec.recommenders.sequential_recommender.duorec.DuoRec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of the DuoRec model from "Contrastive Learning for Representation Degeneration Problem in Sequential Recommendation" (WSDM 2022).

DuoRec extends a SASRec-style backbone with two contrastive regularizers: 1. Unsupervised CL from two stochastic forward passes of the same sequence. 2. Supervised CL from another sequence sharing the same next-item target.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

Dimension of item and position embeddings.

n_layers int

Number of transformer encoder layers.

n_heads int

Number of attention heads in the transformer.

inner_size int

Dimension of the feedforward network in the transformer.

dropout_prob float

Dropout probability for embeddings.

attn_dropout_prob float

Dropout probability for attention weights.

ssl_type str

Type of self-supervised learning ("us", "su", "un", "us_x").

ssl_lambda float

Weight for the unsupervised CL loss.

ssl_lambda_sem float

Weight for the supervised CL loss.

tau float

Temperature parameter for contrastive loss.

sim_type str

Similarity metric for contrastive loss ("dot" or "cos").

reg_weight float

Weight for the embedding regularization loss.

weight_decay float

L2 regularization weight for optimizer.

batch_size int

Training batch size.

epochs int

Number of training epochs.

learning_rate float

Learning rate for optimizer.

neg_samples int

Number of negative samples for training.

max_seq_len int

Maximum length of input sequences.

Source code in warprec/recommenders/sequential_recommender/duorec.py
@model_registry.register(name="DuoRec")
class DuoRec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of the DuoRec model from
    "Contrastive Learning for Representation Degeneration Problem in Sequential Recommendation" in WSDM 2022.

    DuoRec extends a SASRec-style backbone with two contrastive regularizers:
    1. Unsupervised CL from two stochastic forward passes of the same sequence.
    2. Supervised CL from another sequence sharing the same next-item target.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): Dimension of item and position embeddings.
        n_layers (int): Number of transformer encoder layers.
        n_heads (int): Number of attention heads in the transformer.
        inner_size (int): Dimension of the feedforward network in the transformer.
        dropout_prob (float): Dropout probability for embeddings.
        attn_dropout_prob (float): Dropout probability for attention weights.
        ssl_type (str): Type of self-supervised learning (case-insensitive):
            "un" enables only the unsupervised CL term, "su" only the
            supervised CL term, "us" and "us_x" enable both.
        ssl_lambda (float): Weight for the unsupervised CL loss.
        ssl_lambda_sem (float): Weight for the supervised CL loss.
        tau (float): Temperature parameter for contrastive loss.
        sim_type (str): Similarity metric for contrastive loss ("dot" or "cos").
        reg_weight (float): Weight for the embedding regularization loss.
        weight_decay (float): L2 regularization weight for optimizer.
        batch_size (int): Training batch size.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for optimizer.
        neg_samples (int): Number of negative samples for training.
        max_seq_len (int): Maximum length of input sequences.
    """

    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Model hyperparameters
    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    ssl_type: str
    ssl_lambda: float
    ssl_lambda_sem: float
    tau: float
    sim_type: str
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        """Build the embeddings, the transformer encoder and the losses.

        Args:
            params (dict): Model parameters.
            info (dict): The dictionary containing dataset information.
            *args (Any): Variable length argument list.
            seed (int): The seed to use for reproducibility.
            **kwargs (Any): Arbitrary keyword arguments.
        """
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # One extra embedding row: index ``n_items`` is the padding token
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)
        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_size,
            nhead=self.n_heads,
            dim_feedforward=self.inner_size,
            dropout=self.attn_dropout_prob,
            activation="relu",
            batch_first=True,
            norm_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_layers,
        )

        # Mask and position ids are built once for the maximum length and
        # sliced to the actual sequence length in ``forward``
        causal_mask = self._generate_square_subsequent_mask(self.max_seq_len)
        self.register_buffer("causal_mask", causal_mask)
        self.register_buffer(
            "position_ids", torch.arange(self.max_seq_len, dtype=torch.long)
        )

        self.apply(self._init_weights)

        self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs,
    ):
        """Return the training dataloader built from ``sessions``.

        Args:
            interactions (Interactions): Not used by this model.
            sessions (Sessions): The object providing the same-target
                sequential dataloader.
            **kwargs: Extra keyword arguments forwarded to the dataloader factory.

        Returns:
            The dataloader produced by
            ``sessions.get_same_target_sequential_dataloader``.
        """
        return sessions.get_same_target_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            batch_size=self.batch_size,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the total training loss for one batch.

        The loss is the full-softmax cross entropy over all items plus the
        weighted embedding regularization, plus the contrastive terms
        selected by ``ssl_type``.

        Args:
            batch (Any): A 6-tuple ``(item_seq, item_seq_len, pos_item,
                sem_seq, sem_seq_len, has_sem_pos)``.
            batch_idx (int): Index of the batch. Not used.

        Returns:
            Tensor: The total loss.
        """
        item_seq, item_seq_len, pos_item, sem_seq, sem_seq_len, has_sem_pos = batch

        seq_output = self.forward(item_seq, item_seq_len)
        # Score against every real item (the padding row is excluded)
        logits = torch.matmul(
            seq_output, self.item_embedding.weight[:-1].transpose(0, 1)
        )
        main_loss = self.main_loss(logits, pos_item)
        reg_loss = self.reg_weight * self.reg_loss(
            self.item_embedding(item_seq),
            self.item_embedding(pos_item),
        )
        total_loss = main_loss + reg_loss

        ssl_mode = self.ssl_type.lower()

        # Unsupervised CL: a second forward pass of the same sequence
        if ssl_mode in {"us", "un", "us_x"} and self.ssl_lambda > 0:
            aug_seq_output = self.forward(item_seq, item_seq_len)
            total_loss += self.ssl_lambda * self._contrastive_loss(
                seq_output, aug_seq_output, pos_item
            )

        # Supervised CL: "us" combines the unsupervised and the supervised
        # regularizers, so it has to be enabled here as well; otherwise it
        # would behave exactly like "un".
        if ssl_mode in {"us", "su", "us_x"} and self.ssl_lambda_sem > 0:
            # Only rows flagged in ``has_sem_pos`` take part in this term
            valid_sem_mask = has_sem_pos.bool()
            if valid_sem_mask.any():
                sem_seq_output = self.forward(
                    sem_seq[valid_sem_mask], sem_seq_len[valid_sem_mask]
                )
                total_loss += self.ssl_lambda_sem * self._contrastive_loss(
                    seq_output[valid_sem_mask],
                    sem_seq_output,
                    pos_item[valid_sem_mask],
                )

        return total_loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Encode the sequence and return the final valid hidden state.

        Args:
            item_seq (Tensor): Padded sequences of item IDs; padding uses
                the ID ``self.n_items``.
            item_seq_len (Tensor): Number of valid items of each sequence.

        Returns:
            Tensor: The encoder output gathered at position
                ``item_seq_len - 1`` of each sequence.
        """
        seq_len = item_seq.size(1)
        position_ids = self.position_ids[:seq_len].unsqueeze(0).expand_as(item_seq)  # type: ignore[index]

        item_emb = self.item_embedding(item_seq)
        position_emb = self.position_embedding(position_ids)
        seq_emb = self.layernorm(item_emb + position_emb)
        seq_emb = self.emb_dropout(seq_emb)

        padding_mask = item_seq == self.n_items
        attention_mask = self.causal_mask[:seq_len, :seq_len]  # type: ignore[index]
        transformer_output = self.transformer_encoder(
            src=seq_emb,
            mask=attention_mask,
            src_key_padding_mask=padding_mask,
            is_causal=True,
        )

        return self._gather_indexes(transformer_output, item_seq_len - 1)

    def _contrastive_loss(
        self, z_i: Tensor, z_j: Tensor, target_items: Tensor
    ) -> Tensor:
        """InfoNCE with same-target samples removed from the negative pool.

        Args:
            z_i (Tensor): First view of the sequence representations.
            z_j (Tensor): Second view, aligned row by row with ``z_i``.
            target_items (Tensor): Next-item target of each row.

        Returns:
            Tensor: The cross-entropy loss over the masked similarity matrix.

        Raises:
            ValueError: If ``sim_type`` is neither "dot" nor "cos".
        """
        batch_size = z_i.size(0)
        representations = torch.cat((z_i, z_j), dim=0)

        if self.sim_type == "cos":
            representations = F.normalize(representations, dim=1)
        elif self.sim_type != "dot":
            raise ValueError(f"Unknown similarity metric: {self.sim_type}")

        sim_matrix = torch.matmul(representations, representations.transpose(0, 1))
        sim_matrix = sim_matrix / self.tau

        targets = torch.cat((target_items, target_items), dim=0)
        total = 2 * batch_size
        row_indices = torch.arange(total, device=sim_matrix.device)
        # Row r of one view is paired with the same row of the other view
        positive_indices = (row_indices + batch_size) % total

        # Start from "everything is a negative", then drop self-similarity
        # and the positive pair
        negative_mask = torch.ones(
            (total, total), dtype=torch.bool, device=sim_matrix.device
        )
        negative_mask.fill_diagonal_(False)
        negative_mask[row_indices, positive_indices] = False

        # Rows sharing the same target are not valid negatives
        same_target_mask = targets.unsqueeze(0).eq(targets.unsqueeze(1))
        negative_mask &= ~same_target_mask

        # Suppress every non-negative entry, then restore the positive logit
        # so cross entropy sees one positive plus the valid negatives
        logits = sim_matrix.masked_fill(~negative_mask, -1e9)
        logits[row_indices, positive_indices] = sim_matrix[
            row_indices, positive_indices
        ]

        return F.cross_entropy(logits, positive_indices)

    @torch.no_grad()
    def predict(
        self,
        user_indices: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        user_seq: Optional[Tensor] = None,
        seq_len: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """Prediction using the learned sequence embeddings.

        Args:
            user_indices (Tensor): The batch of user indices. Not used.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                scores for all items are produced.
            user_seq (Optional[Tensor]): Padded sequences of item IDs. Passed
                to ``forward`` unconditionally, so it must be provided.
            seq_len (Optional[Tensor]): Number of valid items of each sequence.
                Passed to ``forward`` unconditionally, so it must be provided.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        seq_output = self.forward(user_seq, seq_len)

        if item_indices is None:
            # Score every real item (the padding row is excluded)
            item_embeddings = self.item_embedding.weight[:-1, :]
            einsum_string = "be,ie->bi"
        else:
            # Score only the candidates listed for each row
            item_embeddings = self.item_embedding(item_indices)
            einsum_string = "be,bse->bs"

        return torch.einsum(einsum_string, seq_output, item_embeddings)

forward(item_seq, item_seq_len)

Encode the sequence and return the final valid hidden state.

Source code in warprec/recommenders/sequential_recommender/duorec.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Run the encoder and pick the hidden state of the last valid item.

    Args:
        item_seq (Tensor): Padded sequences of item IDs; padding uses the
            ID ``self.n_items``.
        item_seq_len (Tensor): Number of valid items of each sequence.

    Returns:
        Tensor: The encoder output gathered at position ``item_seq_len - 1``.
    """
    n_steps = item_seq.size(1)

    # Masks: padded positions, and the causal mask cut to the current length
    pad_mask = item_seq == self.n_items
    causal = self.causal_mask[:n_steps, :n_steps]  # type: ignore[index]

    # Item embedding + position embedding, then layer norm and dropout
    positions = self.position_ids[:n_steps].unsqueeze(0).expand_as(item_seq)  # type: ignore[index]
    hidden = self.item_embedding(item_seq) + self.position_embedding(positions)
    hidden = self.emb_dropout(self.layernorm(hidden))

    encoded = self.transformer_encoder(
        src=hidden,
        mask=causal,
        src_key_padding_mask=pad_mask,
        is_causal=True,
    )

    last_valid = item_seq_len - 1
    return self._gather_indexes(encoded, last_valid)

predict(user_indices, *args, item_indices=None, user_seq=None, seq_len=None, **kwargs)

Prediction using the learned sequence embeddings.

Source code in warprec/recommenders/sequential_recommender/duorec.py
@torch.no_grad()
def predict(
    self,
    user_indices: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    user_seq: Optional[Tensor] = None,
    seq_len: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Score items with the learned sequence embeddings.

    Args:
        user_indices (Tensor): The batch of user indices. Not used.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            scores for all items are produced.
        user_seq (Optional[Tensor]): Padded sequences of item IDs, handed
            to ``forward``.
        seq_len (Optional[Tensor]): Number of valid items of each sequence,
            handed to ``forward``.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    seq_repr = self.forward(user_seq, seq_len)

    if item_indices is not None:
        # Each row is scored against its own list of candidates
        candidate_emb = self.item_embedding(item_indices)
        return torch.einsum("be,bse->bs", seq_repr, candidate_emb)

    # Full ranking over the embedding table, minus its last row
    all_item_emb = self.item_embedding.weight[:-1, :]
    return torch.einsum("be,ie->bi", seq_repr, all_item_emb)

warprec.recommenders.sequential_recommender.esasrec.eSASRec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of eSASRec from "eSASRec: Enhancing Transformer-based Recommendations in a Modular Fashion."

The model is built around the winning combination described in the paper: shifted-sequence objective, LiGR Transformer blocks, and sampled softmax, with optional mixed negative sampling.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

Dimension of item and position embeddings.

n_layers int

Number of transformer encoder layers.

n_heads int

Number of attention heads in the transformer.

inner_size int

Dimension of the feedforward network in the transformer.

dropout_prob float

Dropout probability for embeddings.

attn_dropout_prob float

Dropout probability for attention weights.

use_relative_pos bool

Whether to use relative positional embeddings.

use_sampled_softmax bool

Whether to use sampled softmax loss.

use_ligr bool

Whether to use LiGR blocks instead of standard transformer layers.

mn_ratio float

Fraction of the neg_samples negatives that are drawn from in-batch items instead of uniform sampling when using mixed negative sampling.

reg_weight float

Weight for the embedding regularization loss.

weight_decay float

L2 regularization weight for optimizer.

batch_size int

Training batch size.

epochs int

Number of training epochs.

learning_rate float

Learning rate for optimizer.

neg_samples int

Number of negative samples for training.

max_seq_len int

Maximum length of input sequences.

Source code in warprec/recommenders/sequential_recommender/esasrec.py
@model_registry.register(name="eSASRec")
class eSASRec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of eSASRec from
    "eSASRec: Enhancing Transformer-based Recommendations in a Modular Fashion."

    The model is built around the winning combination described in the paper:
    shifted-sequence objective, LiGR Transformer blocks, and sampled softmax,
    with optional mixed negative sampling.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): Dimension of item and position embeddings.
        n_layers (int): Number of transformer encoder layers.
        n_heads (int): Number of attention heads in the transformer.
        inner_size (int): Dimension of the feedforward network in the transformer.
        dropout_prob (float): Dropout probability for embeddings.
        attn_dropout_prob (float): Dropout probability for attention weights.
        use_relative_pos (bool): Whether to use relative positional embeddings.
            Not read by the code of this class, which always adds absolute
            position embeddings.
        use_sampled_softmax (bool): Whether to use sampled softmax loss.
        use_ligr (bool): Whether to use LiGR blocks instead of standard transformer layers.
        mn_ratio (float): Fraction of the ``neg_samples`` negatives that are drawn
            from in-batch items instead of uniform sampling when using mixed
            negative sampling.
        reg_weight (float): Weight for the embedding regularization loss.
        weight_decay (float): L2 regularization weight for optimizer.
        batch_size (int): Training batch size.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for optimizer.
        neg_samples (int): Number of negative samples for training.
        max_seq_len (int): Maximum length of input sequences.
    """

    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    use_relative_pos: bool
    use_sampled_softmax: bool
    use_ligr: bool
    mn_ratio: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        """Build the embeddings, the encoder stack and the losses.

        Args:
            params (dict): Model parameters.
            info (dict): The dictionary containing dataset information.
            *args (Any): Variable length argument list.
            seed (int): The seed to use for reproducibility.
            **kwargs (Any): Arbitrary keyword arguments.
        """
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # One extra embedding row: index ``n_items`` is the padding token
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        # eSASRec keeps SASRec's absolute positions for the shifted-sequence objective.
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)

        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)

        # Exactly one of ``self.encoder`` (LiGR blocks) or
        # ``self.transformer_encoder`` (standard layers) is created
        if self.use_ligr:
            self.encoder = nn.ModuleList(
                [
                    LiGRBlock(
                        d_model=self.embedding_size,
                        n_heads=self.n_heads,
                        d_ff=self.inner_size,
                        attn_dropout=self.attn_dropout_prob,
                        dropout=self.dropout_prob,
                    )
                    for _ in range(self.n_layers)
                ]
            )
        else:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.embedding_size,
                nhead=self.n_heads,
                dim_feedforward=self.inner_size,
                dropout=self.attn_dropout_prob,
                activation="gelu",
                batch_first=True,
                norm_first=False,
            )
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer,
                num_layers=self.n_layers,
            )

        # Built once for the maximum length and sliced in ``forward``
        causal_mask = self._generate_square_subsequent_mask(self.max_seq_len)
        self.register_buffer("causal_mask", causal_mask)

        self.apply(self._init_weights)

        self.full_softmax_loss = nn.CrossEntropyLoss()
        self.sampled_softmax_loss = SampledSoftmaxLoss(temperature=1.0)
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs,
    ):
        """Return the training dataloader built from ``sessions``.

        Args:
            interactions (Interactions): Not used by this model.
            sessions (Sessions): The object providing the sequential dataloader.
            **kwargs: Extra keyword arguments forwarded to the dataloader factory.

        Returns:
            The dataloader produced by ``sessions.get_sequential_dataloader``.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def _mix_negative_items(
        self, pos_item: Tensor, neg_item: Optional[Tensor]
    ) -> Optional[Tensor]:
        """Mix uniform negatives with in-batch negatives when requested.

        A share ``mn_ratio`` of the ``neg_samples`` negatives is replaced by
        positives of other rows of the batch whose item differs from the
        row's own positive. The number of in-batch negatives is capped at
        ``batch_size - 1``; the remaining columns keep uniform negatives.

        Args:
            pos_item (Tensor): Positive item of each row, indexed as a 1-D
                tensor of length ``batch_size``.
            neg_item (Optional[Tensor]): Uniformly sampled negatives, one row
                per positive, or None.

        Returns:
            Optional[Tensor]: ``neg_item`` unchanged when mixing is disabled
                or not possible; otherwise the leading uniform columns of
                ``neg_item`` concatenated with the in-batch negatives.
        """
        if (
            neg_item is None
            or self.neg_samples <= 0
            or self.mn_ratio <= 0
            or pos_item.size(0) <= 1
        ):
            return neg_item

        batch_size = pos_item.size(0)

        num_in_batch = int(round(self.neg_samples * self.mn_ratio))
        # A row has at most ``batch_size - 1`` other rows to borrow from.
        # Without this cap ``torch.topk`` raises when asked for more entries
        # than the batch holds (e.g. on a small last batch), and the shift
        # fallback below would wrap around onto the row's own positive.
        num_in_batch = max(0, min(self.neg_samples, batch_size - 1, num_in_batch))
        if num_in_batch == 0:
            return neg_item

        num_uniform = self.neg_samples - num_in_batch
        uniform_part = (
            neg_item[:, :num_uniform]
            if num_uniform > 0
            else torch.empty(
                pos_item.size(0),
                0,
                dtype=pos_item.dtype,
                device=pos_item.device,
            )
        )

        # Create a boolean candidate mask of shape [B, B]
        # candidate_mask[i, j] is True if pos_item[j] != pos_item[i]
        # This naturally avoids self-negation (j == i) and duplicate items in the batch
        candidate_mask = pos_item.unsqueeze(1) != pos_item.unsqueeze(0)

        # Assign random weights to all batch elements to perform uniform sampling
        rand_weights = torch.rand(batch_size, batch_size, device=pos_item.device)

        # Zero out weights for invalid candidates
        rand_weights = torch.where(
            candidate_mask, rand_weights, torch.zeros_like(rand_weights)
        )

        # Retrieve the top-k indices with the largest weights
        sampled_vals, sampled_idx = torch.topk(rand_weights, num_in_batch, dim=1)

        # Handle potential edge cases where a row has fewer valid candidates than num_in_batch
        # Create a deterministic fallback index matrix (e.g., shifting indices safely)
        row_indices = (
            torch.arange(batch_size, device=pos_item.device)
            .unsqueeze(1)
            .expand(-1, num_in_batch)
        )
        shift_offsets = torch.arange(
            1, num_in_batch + 1, device=pos_item.device
        ).unsqueeze(0)
        fallback_idx = (row_indices + shift_offsets) % batch_size

        # Check if the sampled top-k items are actually valid (weight > 0)
        is_valid = sampled_vals > 0.0
        final_idx = torch.where(is_valid, sampled_idx, fallback_idx)

        # Gather the final mixed in-batch negatives
        in_batch_part = pos_item[final_idx]

        return torch.cat([uniform_part, in_batch_part], dim=1)

    def _compute_main_loss(
        self, seq_output: Tensor, pos_item: Tensor, neg_item: Optional[Tensor]
    ) -> Tensor:
        """Compute the recommendation loss.

        Uses the sampled softmax loss when it is enabled and at least one
        negative column is available; otherwise falls back to the full
        softmax cross entropy over all items.

        Args:
            seq_output (Tensor): The sequence representations.
            pos_item (Tensor): Positive item of each row.
            neg_item (Optional[Tensor]): Negative items of each row, or None.

        Returns:
            Tensor: The loss value.
        """
        if self.use_sampled_softmax and neg_item is not None and neg_item.size(1) > 0:
            pos_items_emb = self.item_embedding(pos_item)
            neg_items_emb = self.item_embedding(neg_item)
            return self.sampled_softmax_loss(seq_output, pos_items_emb, neg_items_emb)

        # Full softmax over every real item (the padding row is excluded)
        item_embeddings = self.item_embedding.weight[:-1, :]
        logits = torch.matmul(seq_output, item_embeddings.transpose(0, 1))
        return self.full_softmax_loss(logits, pos_item)

    def training_step(self, batch: Any, batch_idx: int):
        """Compute the total training loss for one batch.

        Args:
            batch (Any): Either ``(item_seq, item_seq_len, pos_item, neg_item)``
                or ``(item_seq, item_seq_len, pos_item)`` when no negatives
                are provided.
            batch_idx (int): Index of the batch. Not used.

        Returns:
            Tensor: The main loss plus the weighted embedding regularization.
        """
        if len(batch) == 4:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch
            neg_item = None

        neg_item = self._mix_negative_items(pos_item, neg_item)
        seq_output = self.forward(item_seq, item_seq_len)
        main_loss = self._compute_main_loss(seq_output, pos_item, neg_item)

        # Regularize every embedding that took part in this step
        reg_terms = [self.item_embedding(item_seq), self.item_embedding(pos_item)]
        if neg_item is not None:
            reg_terms.append(self.item_embedding(neg_item))
        reg_loss = self.reg_weight * self.reg_loss(*reg_terms)

        return main_loss + reg_loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Forward pass with shifted-sequence causal masking.

        Args:
            item_seq (Tensor): Padded sequences of item IDs; padding uses
                the ID ``self.n_items``.
            item_seq_len (Tensor): Number of valid items of each sequence.

        Returns:
            Tensor: The encoder output gathered at position
                ``item_seq_len - 1`` of each sequence.
        """
        seq_len = item_seq.size(1)
        padding_mask = item_seq == self.n_items

        position_ids = torch.arange(seq_len, dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

        item_emb = self.item_embedding(item_seq)
        pos_emb = self.position_embedding(position_ids)
        seq_emb = self.layernorm(item_emb + pos_emb)
        seq_emb = self.emb_dropout(seq_emb)

        attn_mask = self.causal_mask[:seq_len, :seq_len]  # type: ignore[index]

        if self.use_ligr:
            # LiGR blocks are applied one after the other
            output = seq_emb
            for layer in self.encoder:
                output = layer(output, attn_mask, padding_mask)
            transformer_output = output
        else:
            transformer_output = self.transformer_encoder(
                src=seq_emb,
                mask=attn_mask,
                src_key_padding_mask=padding_mask,
            )

        return self._gather_indexes(transformer_output, item_seq_len - 1)

    @torch.no_grad()
    def predict(
        self,
        user_indices: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        user_seq: Optional[Tensor] = None,
        seq_len: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """Prediction using the learned session embeddings.

        Args:
            user_indices (Tensor): The batch of user indices. Not used.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                scores for all items are produced.
            user_seq (Optional[Tensor]): Padded sequences of item IDs. Passed
                to ``forward`` unconditionally, so it must be provided.
            seq_len (Optional[Tensor]): Number of valid items of each sequence.
                Passed to ``forward`` unconditionally, so it must be provided.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        seq_output = self.forward(user_seq, seq_len)

        if item_indices is None:
            # Score every real item (the padding row is excluded)
            item_embeddings = self.item_embedding.weight[:-1, :]
            einsum_string = "be,ie->bi"
        else:
            # Score only the candidates listed for each row
            item_embeddings = self.item_embedding(item_indices)
            einsum_string = "be,bse->bs"

        return torch.einsum(einsum_string, seq_output, item_embeddings)

forward(item_seq, item_seq_len)

Forward pass with shifted-sequence causal masking.

Source code in warprec/recommenders/sequential_recommender/esasrec.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Encode the sequences under a causal mask and return the last valid state.

    Args:
        item_seq (Tensor): Padded sequences of item IDs; padding uses the
            ID ``self.n_items``.
        item_seq_len (Tensor): Number of valid items of each sequence.

    Returns:
        Tensor: The encoder output gathered at position ``item_seq_len - 1``.
    """
    n_steps = item_seq.size(1)
    pad_mask = item_seq == self.n_items

    # Absolute positions 0..n_steps-1, repeated for every row of the batch
    positions = (
        torch.arange(n_steps, dtype=torch.long, device=item_seq.device)
        .unsqueeze(0)
        .expand_as(item_seq)
    )

    # Item embedding + position embedding, then layer norm and dropout
    hidden = self.item_embedding(item_seq) + self.position_embedding(positions)
    hidden = self.emb_dropout(self.layernorm(hidden))

    causal = self.causal_mask[:n_steps, :n_steps]  # type: ignore[index]

    if not self.use_ligr:
        hidden = self.transformer_encoder(
            src=hidden,
            mask=causal,
            src_key_padding_mask=pad_mask,
        )
    else:
        # LiGR blocks are chained, each consuming the previous output
        for block in self.encoder:
            hidden = block(hidden, causal, pad_mask)

    last_valid = item_seq_len - 1
    return self._gather_indexes(hidden, last_valid)

predict(user_indices, *args, item_indices=None, user_seq=None, seq_len=None, **kwargs)

Prediction using the learned session embeddings.

Source code in warprec/recommenders/sequential_recommender/esasrec.py
@torch.no_grad()
def predict(
    self,
    user_indices: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    user_seq: Optional[Tensor] = None,
    seq_len: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Score items with the learned session embeddings.

    Args:
        user_indices (Tensor): The batch of user indices. Not used.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            scores for all items are produced.
        user_seq (Optional[Tensor]): Padded sequences of item IDs, handed
            to ``forward``.
        seq_len (Optional[Tensor]): Number of valid items of each sequence,
            handed to ``forward``.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    session_repr = self.forward(user_seq, seq_len)

    score_all_items = item_indices is None
    if score_all_items:
        # Whole embedding table except its last row
        targets = self.item_embedding.weight[:-1, :]
    else:
        # Per-row candidate lists
        targets = self.item_embedding(item_indices)

    pattern = "be,ie->bi" if score_all_items else "be,bse->bs"
    return torch.einsum(pattern, session_repr, targets)

warprec.recommenders.sequential_recommender.gsasrec.gSASRec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of gSASRec algorithm from "gSASRec: Reducing Overconfidence in Sequential Recommendation Trained with Negative Sampling." in RecSys 2023.

This model adapts the SASRec architecture to predict the next item at every step of the sequence, using a Group-wise Binary Cross-Entropy (GBCE) loss function.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings (hidden_size).

n_layers int

The number of transformer encoder layers.

n_heads int

The number of attention heads in the transformer.

inner_size int

The dimensionality of the feed-forward layer in the transformer.

dropout_prob float

The probability of dropout for embeddings and other layers.

attn_dropout_prob float

The probability of dropout for the attention weights.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

gbce_t float

The temperature parameter for the Group-wise Binary Cross-Entropy loss.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

reuse_item_embeddings bool

Whether to reuse item embeddings for output or not.

Source code in warprec/recommenders/sequential_recommender/gsasrec.py
@model_registry.register(name="gSASRec")
class gSASRec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of gSASRec algorithm from
    "gSASRec: Reducing Overconfidence in Sequential Recommendation Trained with Negative Sampling." in RecSys 2023.

    This model adapts the SASRec architecture to predict the next item at every
    step of the sequence, using the generalised Binary Cross-Entropy (gBCE) loss function.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings (hidden_size).
        n_layers (int): The number of transformer encoder layers.
        n_heads (int): The number of attention heads in the transformer.
        inner_size (int): The dimensionality of the feed-forward layer in the transformer.
        dropout_prob (float): The probability of dropout for embeddings and other layers.
        attn_dropout_prob (float): The probability of dropout for the attention weights.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        gbce_t (float): The calibration parameter t of the gBCE loss.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
        reuse_item_embeddings (bool): Whether to reuse item embeddings for output or not.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.USER_HISTORY_LOADER

    # Model hyperparameters
    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    gbce_t: float
    neg_samples: int
    max_seq_len: int
    reuse_item_embeddings: bool

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Index n_items is reserved as the padding token, hence n_items + 1 rows.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)

        # A separate output table is only allocated when input embeddings
        # are not shared with the scoring layer.
        if not self.reuse_item_embeddings:
            self.output_embedding = nn.Embedding(
                self.n_items + 1, self.embedding_size, padding_idx=self.n_items
            )

        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_size,
            nhead=self.n_heads,
            dim_feedforward=self.inner_size,
            dropout=self.attn_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=False,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_layers,
        )

        # Precompute causal mask
        causal_mask = self._generate_square_subsequent_mask(self.max_seq_len)
        self.register_buffer("causal_mask", causal_mask)

        # Initialize weights
        self.apply(self._init_weights)
        self.gbce_loss = self._gbce_loss_function()
        self.reg_loss = EmbLoss()

    def _get_output_embeddings(self) -> nn.Embedding:
        """Return embeddings based on the flag value reuse_item_embeddings.

        Returns:
            nn.Embedding: The item embedding if reuse_item_embeddings is True,
                else the output embedding.
        """
        if self.reuse_item_embeddings:
            return self.item_embedding
        return self.output_embedding

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Build the sliding-window dataloader used to train the model.

        Args:
            interactions (Interactions): The interaction data (not used here).
            sessions (Sessions): The session data the dataloader is built from.
            **kwargs (Any): Extra keyword arguments forwarded to the
                dataloader factory.

        Returns:
            The dataloader returned by ``sessions.get_sliding_window_dataloader``.
        """
        return sessions.get_sliding_window_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def forward(self, item_seq: Tensor) -> Tensor:
        """Forward pass of gSASRec. Returns the output of the Transformer
        for each token in the input sequence.

        Args:
            item_seq (Tensor): Sequence of items [batch_size, seq_len].

        Returns:
            Tensor: Output of the Transformer encoder [batch_size, seq_len, embedding_size].
        """
        seq_len = item_seq.size(1)
        # True wherever the sequence holds the padding token.
        padding_mask = item_seq == self.n_items

        # Allocate the position ids directly on the input device instead of
        # creating them on CPU and copying them over.
        position_ids = torch.arange(
            seq_len, dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

        item_emb = self.item_embedding(item_seq)
        pos_emb = self.position_embedding(position_ids)

        seq_emb = self.layernorm(item_emb + pos_emb)
        seq_emb = self.emb_dropout(seq_emb)

        transformer_output = self.transformer_encoder(
            src=seq_emb,
            mask=self.causal_mask[:seq_len, :seq_len],  # type:ignore [index]
            src_key_padding_mask=padding_mask,
        )
        return transformer_output

    def _gbce_loss_function(self) -> Callable:
        """Return the generalised Binary Cross-Entropy (gBCE) loss.

        Returns:
            Callable: The gBCE loss.
        """

        def gbce_loss_fn(
            sequence_hidden_states: Tensor,
            labels: Tensor,
            negatives: Tensor,
            model_input: Tensor,
        ):
            """Compute the masked gBCE loss for one batch.

            Args:
                sequence_hidden_states (Tensor): Encoder output for every
                    position of the input sequence.
                labels (Tensor): The positive (next) item for every position.
                negatives (Tensor): The sampled negative items for every position.
                model_input (Tensor): The input sequence. Currently unused;
                    kept so existing call sites keep working.

            Returns:
                Tensor: The scalar loss averaged over non-padding positions.
            """
            # Slot 0 of the last dimension is the positive, the rest are negatives.
            pos_neg_concat = torch.cat([labels.unsqueeze(-1), negatives], dim=-1)
            pos_neg_embeddings = self._get_output_embeddings()(pos_neg_concat)

            logits = torch.einsum(
                "bse, bsne -> bsn", sequence_hidden_states, pos_neg_embeddings
            )

            # zeros_like already matches the device and dtype of logits.
            gt = torch.zeros_like(logits)
            gt[:, :, 0] = 1.0

            alpha = self.neg_samples / (self.n_items - 1)
            t = self.gbce_t
            # Algebraically identical to alpha * ((1 - 1 / alpha) * t + 1 / alpha),
            # but free of the division by alpha, which fails when neg_samples == 0.
            beta = 1.0 + (alpha - 1.0) * t

            # The transformation of the positive logit is done in float64
            # to limit precision loss in pow / log.
            positive_logits = logits[:, :, 0:1].to(torch.float64)
            negative_logits = logits[:, :, 1:].to(torch.float64)
            eps = 1e-10

            positive_probs = torch.clamp(torch.sigmoid(positive_logits), eps, 1 - eps)
            positive_probs_pow = torch.clamp(
                positive_probs.pow(-beta),
                min=1.0 + eps,
                max=torch.finfo(torch.float64).max,
            )
            to_log = torch.clamp(
                torch.div(1.0, (positive_probs_pow - 1)),
                eps,
                torch.finfo(torch.float64).max,
            )
            # Logit whose sigmoid equals sigmoid(positive_logit) ** beta.
            positive_logits_transformed = to_log.log()

            final_logits = torch.cat(
                [positive_logits_transformed, negative_logits], -1
            ).to(torch.float32)

            # Positions whose label is the padding token do not contribute.
            mask = (labels != self.n_items).float()
            loss_per_element = nn.functional.binary_cross_entropy_with_logits(
                final_logits, gt, reduction="none"
            )

            loss_per_element = loss_per_element.mean(-1) * mask
            # clamp avoids a division by zero on fully padded batches.
            total_loss = loss_per_element.sum() / mask.sum().clamp(min=1)
            return total_loss

        return gbce_loss_fn

    def training_step(self, batch: Any, batch_idx: int):
        """Run one optimization step on a batch.

        Args:
            batch (Any): A ``(positives, negatives)`` pair; the sequences are
                shifted by one position to build inputs and labels.
            batch_idx (int): The index of the batch (unused).

        Returns:
            Tensor: The gBCE loss plus the weighted L2 regularization term,
                or a zero loss when the batch cannot yield any input/label pair.
        """
        positives, negatives = batch

        # At least two positions are needed to form one (input, label) pair.
        # This also guarantees the shifted input below is never empty.
        if positives.shape[0] == 0 or positives.shape[1] < 2:
            return torch.tensor(0.0, device=positives.device, requires_grad=True)

        model_input = positives[:, :-1]
        labels = positives[:, 1:]
        negatives = negatives[:, 1:, :]

        # Calculate gBCE loss
        sequence_hidden_states = self.forward(model_input)
        gbce_loss = self.gbce_loss(
            sequence_hidden_states, labels, negatives, model_input
        )

        # Calculate L2 regularization
        output_embeddings = self._get_output_embeddings()
        reg_loss = self.reg_weight * self.reg_loss(
            self.item_embedding(model_input),
            output_embeddings(labels),
            output_embeddings(negatives),
        )

        # Loss logging
        loss = gbce_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def predict(
        self,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get the transformer output for all tokens in the sequence
        transformer_output = self.forward(user_seq)

        # Get the embedding of the LAST relevant item for prediction
        seq_output = self._gather_indexes(
            transformer_output, seq_len - 1
        )  # [batch_size, embedding_size]

        target_item_embeddings = self._get_output_embeddings()
        if item_indices is None:
            # Case 'full': prediction on all items (the padding row is dropped)
            item_embeddings = target_item_embeddings.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = target_item_embeddings(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(
            einsum_string, seq_output, item_embeddings
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(item_seq)

Forward pass of gSASRec. Returns the output of the Transformer for each token in the input sequence.

Parameters:

Name Type Description Default
item_seq Tensor

Sequence of items [batch_size, seq_len].

required

Returns:

Name Type Description
Tensor Tensor

Output of the Transformer encoder [batch_size, seq_len, embedding_size].

Source code in warprec/recommenders/sequential_recommender/gsasrec.py
def forward(self, item_seq: Tensor) -> Tensor:
    """Forward pass of gSASRec. Returns the output of the Transformer
    for each token in the input sequence.

    Args:
        item_seq (Tensor): Sequence of items [batch_size, seq_len].

    Returns:
        Tensor: Output of the Transformer encoder [batch_size, seq_len, embedding_size].
    """
    seq_len = item_seq.size(1)
    # True wherever the sequence holds the padding token (index n_items).
    padding_mask = item_seq == self.n_items

    # Allocate the position ids directly on the input device instead of
    # creating them on CPU and copying them over.
    position_ids = torch.arange(seq_len, dtype=torch.long, device=item_seq.device)
    position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

    item_emb = self.item_embedding(item_seq)
    pos_emb = self.position_embedding(position_ids)

    seq_emb = self.layernorm(item_emb + pos_emb)
    seq_emb = self.emb_dropout(seq_emb)

    # The precomputed causal mask is sliced down to the current sequence length.
    transformer_output = self.transformer_encoder(
        src=seq_emb,
        mask=self.causal_mask[:seq_len, :seq_len],  # type:ignore [index]
        src_key_padding_mask=padding_mask,
    )
    return transformer_output

predict(user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/gsasrec.py
def predict(
    self,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Score items for each user from the last valid state of their sequence.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # Encode every position, then keep only the state at the last real item.
    hidden_states = self.forward(user_seq)
    last_state = self._gather_indexes(
        hidden_states, seq_len - 1
    )  # [batch_size, embedding_size]

    scoring_table = self._get_output_embeddings()

    if item_indices is not None:
        # Sampled scoring: one candidate list per user.
        candidates = scoring_table(
            item_indices
        )  # [batch_size, pad_seq, embedding_size]
        # b: batch, e: embedding, s: sample
        return torch.einsum("be,bse->bs", last_state, candidates)

    # Full scoring: every item except the trailing padding row.
    catalogue = scoring_table.weight[:-1, :]  # [n_items, embedding_size]
    # b: batch, e: embedding, i: item
    return torch.einsum("be,ie->bi", last_state, catalogue)

warprec.recommenders.sequential_recommender.lightsans.LightSANs

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of LightSANs algorithm from "Lighter and Better: Low-Rank Decomposed Self-Attention Networks for Next-Item Recommendation" (SIGIR 2021).

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings.

n_layers int

The number of attention layers.

n_heads int

The number of attention heads.

k_interests int

The number of latent interests (k).

inner_size int

The dimensionality of the feed-forward layer.

dropout_prob float

The probability of dropout.

attn_dropout_prob float

The probability of dropout for attention.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/lightsans.py
@model_registry.register(name="LightSANs")
class LightSANs(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of LightSANs algorithm from
    "Lighter and Better: Low-Rank Decomposed Self-Attention Networks for Next-Item Recommendation" (SIGIR 2021).

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings.
        n_layers (int): The number of attention layers.
        n_heads (int): The number of attention heads.
        k_interests (int): The number of latent interests (k).
        inner_size (int): The dimensionality of the feed-forward layer.
        dropout_prob (float): The probability of dropout.
        attn_dropout_prob (float): The probability of dropout for attention.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Model hyperparameters
    embedding_size: int
    n_layers: int
    n_heads: int
    k_interests: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Index n_items is reserved as the padding token, hence n_items + 1 rows.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)

        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)

        # LightSANs Layers
        self.layers = nn.ModuleList(
            [
                LightSANsLayer(
                    n_heads=self.n_heads,
                    k_interests=self.k_interests,
                    hidden_size=self.embedding_size,
                    inner_size=self.inner_size,
                    dropout_prob=self.dropout_prob,
                    attn_dropout_prob=self.attn_dropout_prob,
                    layer_norm_eps=1e-8,
                )
                for _ in range(self.n_layers)
            ]
        )

        # Precompute causal mask
        causal_mask = self._generate_square_subsequent_mask(self.max_seq_len)
        self.register_buffer("causal_mask", causal_mask)

        # Initialize weights
        self.apply(self._init_weights)

        # Loss function: pairwise BPR when negatives are sampled,
        # cross-entropy over the item table otherwise.
        self.main_loss: nn.Module
        if self.neg_samples > 0:
            self.main_loss = BPRLoss()
        else:
            self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Build the sequential dataloader used to train the model.

        Args:
            interactions (Interactions): The interaction data (not used here).
            sessions (Sessions): The session data the dataloader is built from.
            **kwargs (Any): Extra keyword arguments forwarded to the
                dataloader factory.

        Returns:
            The dataloader returned by ``sessions.get_sequential_dataloader``.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Run one optimization step on a batch.

        Args:
            batch (Any): ``(item_seq, item_seq_len, pos_item, neg_item)`` when
                ``neg_samples > 0``, otherwise ``(item_seq, item_seq_len, pos_item)``.
            batch_idx (int): The index of the batch (unused).

        Returns:
            Tensor: The main loss plus the weighted L2 regularization term.
        """
        use_negatives = self.neg_samples > 0
        if use_negatives:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch

        seq_output = self.forward(item_seq, item_seq_len)

        # Each embedding lookup is done once and shared between the
        # main loss and the L2 regularization term.
        seq_items_emb = self.item_embedding(item_seq)
        pos_items_emb = self.item_embedding(pos_item)

        # Calculate main loss and L2 regularization
        if use_negatives:
            neg_items_emb = self.item_embedding(neg_item)

            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)
            # unsqueeze broadcasts the sequence state over the negatives axis.
            neg_score = torch.sum(seq_output.unsqueeze(1) * neg_items_emb, dim=-1)
            main_loss = self.main_loss(pos_score, neg_score)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                seq_items_emb,
                pos_items_emb,
                neg_items_emb,
            )
        else:
            test_item_emb = self.item_embedding.weight
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            main_loss = self.main_loss(logits, pos_item)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                seq_items_emb,
                pos_items_emb,
            )

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Forward pass of the LightSANs model.

        Args:
            item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
            item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

        Returns:
            Tensor: The embedding of the predicted item (last session state).
        """
        seq_len = item_seq.size(1)

        # Padding mask (True where padding exists)
        padding_mask = item_seq == self.n_items

        # Get embeddings
        x = self.item_embedding(item_seq)
        x = self.layernorm(x)
        x = self.emb_dropout(x)

        # Get Position Embeddings for the current sequence length
        # These are shared across the batch and handed to each layer
        # separately rather than being added to the item embeddings.
        position_ids = torch.arange(seq_len, dtype=torch.long, device=item_seq.device)
        pos_emb = self.position_embedding(position_ids)

        # Causal mask for the current sequence length
        # We slice the precomputed mask
        curr_causal_mask = self.causal_mask[:seq_len, :seq_len]  # type: ignore[index]

        # Pass through LightSANs Layers
        for layer in self.layers:
            x = layer(
                x,
                pos_emb=pos_emb,
                padding_mask=padding_mask,
                causal_mask=curr_causal_mask,
            )

        # Gather the output of the last relevant item in each sequence
        seq_output = self._gather_indexes(x, item_seq_len - 1)

        return seq_output

    def predict(
        self,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get sequence output embeddings
        seq_output = self.forward(user_seq, seq_len)  # [batch_size, embedding_size]

        if item_indices is None:
            # Case 'full': prediction on all items (the padding row is dropped)
            item_embeddings = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = self.item_embedding(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(
            einsum_string, seq_output, item_embeddings
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(item_seq, item_seq_len)

Forward pass of the LightSANs model.

Parameters:

Name Type Description Default
item_seq Tensor

Padded sequences of item IDs [batch_size, max_seq_len].

required
item_seq_len Tensor

Actual lengths of sequences [batch_size,].

required

Returns:

Name Type Description
Tensor Tensor

The embedding of the predicted item (last session state).

Source code in warprec/recommenders/sequential_recommender/lightsans.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Encode a batch of item sequences and return their last valid state.

    Args:
        item_seq (Tensor): Padded sequences of item IDs [batch_size, max_seq_len].
        item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

    Returns:
        Tensor: The embedding of the predicted item (last session state).
    """
    length = item_seq.size(1)

    # Item embeddings -> layer norm -> dropout form the initial hidden state.
    hidden = self.emb_dropout(self.layernorm(self.item_embedding(item_seq)))

    # Position embeddings cover only the current length and are shared by
    # the whole batch; they are passed to the layers, not added to `hidden`.
    positions = self.position_embedding(
        torch.arange(length, dtype=torch.long, device=item_seq.device)
    )

    # Masks common to every layer: padding positions (True where the padding
    # token sits) and the precomputed causal mask cut to the current length.
    layer_kwargs = {
        "pos_emb": positions,
        "padding_mask": item_seq == self.n_items,
        "causal_mask": self.causal_mask[:length, :length],  # type: ignore[index]
    }

    for block in self.layers:
        hidden = block(hidden, **layer_kwargs)

    # Keep only the state at the last real item of each sequence.
    return self._gather_indexes(hidden, item_seq_len - 1)

predict(user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/lightsans.py
def predict(
    self,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Score items for each user from the encoded state of their sequence.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # One state vector per user.
    user_state = self.forward(user_seq, seq_len)  # [batch_size, embedding_size]

    if item_indices is not None:
        # Sampled scoring: one candidate list per user.
        candidates = self.item_embedding(
            item_indices
        )  # [batch_size, pad_seq, embedding_size]
        # b: batch, e: embedding, s: sample
        return torch.einsum("be,bse->bs", user_state, candidates)

    # Full scoring: every item except the trailing padding row.
    catalogue = self.item_embedding.weight[:-1, :]  # [n_items, embedding_size]
    # b: batch, e: embedding, i: item
    return torch.einsum("be,ie->bi", user_state, catalogue)

warprec.recommenders.sequential_recommender.linrec.LinRec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of LinRec algorithm from "LinRec: Linear Attention Mechanism for Long-term Sequential Recommender Systems" in SIGIR 2023.

LinRec replaces the quadratic Dot-Product Attention with an O(N) Linear Attention mechanism based on L2 Normalization and ELU activation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

Item embedding dimensions.

n_layers int

The number of transformer encoder layers.

n_heads int

The number of attention heads in the transformer.

inner_size int

The dimensionality of the feed-forward layer in the transformer.

dropout_prob float

Dropout probability.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/linrec.py
@model_registry.register(name="LinRec")
class LinRec(IterativeRecommender, SequentialRecommenderUtils):
    """LinRec sequential recommender.

    Follows "LinRec: Linear Attention Mechanism for Long-term Sequential
    Recommender Systems" (SIGIR 2023): the quadratic Dot-Product Attention is
    replaced by an O(N) Linear Attention mechanism based on L2 Normalization
    and ELU activation.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): Item embedding dimensions.
        n_layers (int): The number of transformer encoder layers.
        n_heads (int): The number of attention heads in the transformer.
        inner_size (int): The dimensionality of the feed-forward layer in the transformer.
        dropout_prob (float): Dropout probability.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Hyperparameters; only declared here (presumably filled in from
    # ``params`` by the base class -- verify against IterativeRecommender).
    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Item table has one extra row: index ``n_items`` is the padding token.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        # One learned vector per position in the (bounded-length) sequence.
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)
        self.emb_dropout = nn.Dropout(self.dropout_prob)

        # ``n_layers`` identically configured LinRec blocks, applied in order.
        blocks = []
        for _ in range(self.n_layers):
            blocks.append(
                LinRecBlock(
                    self.embedding_size,
                    self.n_heads,
                    self.inner_size,
                    self.dropout_prob,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.layer_norm = nn.LayerNorm(self.embedding_size)

        self.apply(self._init_weights)

        # Pairwise BPR when negatives are sampled, full softmax otherwise.
        self.main_loss: nn.Module = (
            BPRLoss() if self.neg_samples > 0 else nn.CrossEntropyLoss()
        )
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self, interactions: Interactions, sessions: Sessions, **kwargs: Any
    ):
        """Build the sequential dataloader from ``sessions``.

        ``interactions`` is accepted but not used by this model; extra keyword
        arguments are forwarded to ``sessions.get_sequential_dataloader``.
        """
        loader_options = dict(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
        )
        return sessions.get_sequential_dataloader(**loader_options, **kwargs)

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Encode each sequence and return the state at its last valid step.

        Args:
            item_seq (Tensor): Item IDs, shape [B, L].
            item_seq_len (Tensor): Sequence lengths; the representation at
                index ``item_seq_len - 1`` is gathered for every row.

        Returns:
            Tensor: The gathered sequence representation.
        """
        n_rows, n_steps = item_seq.shape

        # Positions 0..L-1, shared by every row of the batch.
        positions = torch.arange(n_steps, dtype=torch.long, device=item_seq.device)
        positions = positions.unsqueeze(0).expand(n_rows, -1)

        hidden = self.item_embedding(item_seq) + self.position_embedding(positions)
        hidden = self.emb_dropout(hidden)

        for block in self.layers:
            hidden = block(hidden)

        normalized = self.layer_norm(hidden)

        # Keep only the state at the last non-padded index of each sequence.
        return self._gather_indexes(normalized, item_seq_len - 1)

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Args:
            batch (Any): ``(item_seq, item_seq_len, pos_item, neg_item)`` when
                ``neg_samples > 0``, otherwise ``(item_seq, item_seq_len, pos_item)``.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The main loss plus the weighted embedding regularization.
        """
        with_negatives = self.neg_samples > 0
        if with_negatives:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch
            neg_item = None

        # Drop the last column and shorten the lengths by one (never below 1).
        # NOTE(review): predict() feeds the full sequence to forward() while
        # training uses this shortened view -- confirm the asymmetry is intended.
        input_seq = item_seq[:, :-1]
        input_len = torch.clamp(item_seq_len - 1, min=1)

        seq_output = self.forward(input_seq, input_len)
        pos_items_emb = self.item_embedding(pos_item)

        if with_negatives:
            neg_items_emb = self.item_embedding(neg_item)
            pos_score = (seq_output * pos_items_emb).sum(dim=-1)
            neg_score = (seq_output.unsqueeze(1) * neg_items_emb).sum(dim=-1)
            main_loss = self.main_loss(pos_score, neg_score)
            reg_terms = (
                self.item_embedding(input_seq),
                pos_items_emb,
                neg_items_emb,
            )
        else:
            # Full softmax over the catalogue; the padding row is excluded.
            catalog_emb = self.item_embedding.weight[:-1, :]
            logits = seq_output @ catalog_emb.t()
            main_loss = self.main_loss(logits, pos_item)
            reg_terms = (self.item_embedding(input_seq), pos_items_emb)

        reg_loss = self.reg_weight * self.reg_loss(*reg_terms)

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def predict(
        self,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        session_repr = self.forward(user_seq, seq_len)  # [batch_size, embedding_size]

        if item_indices is not None:
            # Sampled scoring: one candidate list per row.
            candidate_emb = self.item_embedding(item_indices)
            return torch.einsum("be,bse->bs", session_repr, candidate_emb)

        # Full scoring: every real item (padding row dropped).
        catalog_emb = self.item_embedding.weight[:-1, :]
        return torch.einsum("be,ie->bi", session_repr, catalog_emb)

predict(user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/linrec.py
def predict(
    self,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Score items for each input sequence with the learned session embeddings.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # One vector per sequence: [batch_size, embedding_size]
    session_repr = self.forward(user_seq, seq_len)

    if item_indices is not None:
        # Sampled scoring: each row is scored against its own candidate list.
        candidate_emb = self.item_embedding(item_indices)
        return torch.einsum("be,bse->bs", session_repr, candidate_emb)

    # Full scoring: all rows of the item table except the last (padding) one.
    catalog_emb = self.item_embedding.weight[:-1, :]
    return torch.einsum("be,ie->bi", session_repr, catalog_emb)

warprec.recommenders.sequential_recommender.sasrec.SASRec

Bases: IterativeRecommender, SequentialRecommenderUtils

Implementation of SASRec algorithm from "Self-Attentive Sequential Recommendation." in ICDM 2018.

This implementation is adapted to the WarpRec framework, using PyTorch's native nn.TransformerEncoder for the self-attention mechanism.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
DATALOADER_TYPE

The type of dataloader used.

embedding_size int

The dimension of the item embeddings (hidden_size).

n_layers int

The number of transformer encoder layers.

n_heads int

The number of attention heads in the transformer.

inner_size int

The dimensionality of the feed-forward layer in the transformer.

dropout_prob float

The probability of dropout for embeddings and other layers.

attn_dropout_prob float

The probability of dropout for the attention weights.

reg_weight float

The L2 regularization weight.

weight_decay float

The value of weight decay used in the optimizer.

batch_size int

The batch size used during training.

epochs int

The number of training epochs.

learning_rate float

The learning rate value.

neg_samples int

The number of negative samples.

max_seq_len int

The maximum length of sequences.

Source code in warprec/recommenders/sequential_recommender/sasrec.py
@model_registry.register(name="SASRec")
class SASRec(IterativeRecommender, SequentialRecommenderUtils):
    """Implementation of SASRec algorithm from
    "Self-Attentive Sequential Recommendation." in ICDM 2018.

    This implementation is adapted to the WarpRec framework, using PyTorch's
    native nn.TransformerEncoder for the self-attention mechanism.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        DATALOADER_TYPE: The type of dataloader used.
        embedding_size (int): The dimension of the item embeddings (hidden_size).
        n_layers (int): The number of transformer encoder layers.
        n_heads (int): The number of attention heads in the transformer.
        inner_size (int): The dimensionality of the feed-forward layer in the transformer.
        dropout_prob (float): The probability of dropout for embeddings and other layers.
        attn_dropout_prob (float): The probability of dropout for the attention weights.
        reg_weight (float): The L2 regularization weight.
        weight_decay (float): The value of weight decay used in the optimizer.
        batch_size (int): The batch size used during training.
        epochs (int): The number of training epochs.
        learning_rate (float): The learning rate value.
        neg_samples (int): The number of negative samples.
        max_seq_len (int): The maximum length of sequences.
    """

    # Dataloader definition
    DATALOADER_TYPE = DataLoaderType.SEQUENTIAL_LOADER

    # Model hyperparameters
    embedding_size: int
    n_layers: int
    n_heads: int
    inner_size: int
    dropout_prob: float
    attn_dropout_prob: float
    reg_weight: float
    weight_decay: float
    batch_size: int
    epochs: int
    learning_rate: float
    neg_samples: int
    max_seq_len: int

    def __init__(
        self,
        params: dict,
        info: dict,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Item table has one extra row: index ``n_items`` is the padding token.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.embedding_size, padding_idx=self.n_items
        )
        self.position_embedding = nn.Embedding(self.max_seq_len, self.embedding_size)
        self.emb_dropout = nn.Dropout(self.dropout_prob)
        self.layernorm = nn.LayerNorm(self.embedding_size, eps=1e-8)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_size,
            nhead=self.n_heads,
            dim_feedforward=self.inner_size,
            dropout=self.attn_dropout_prob,
            activation="gelu",  # GELU is a common choice in Transformers
            batch_first=True,  # Input tensors are (batch, seq_len, features)
            norm_first=False,  # Following the original Transformer paper (post-LN)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=self.n_layers,
        )

        # Precompute the causal mask once for the longest supported sequence;
        # forward() slices it down to the actual input length.
        causal_mask = self._generate_square_subsequent_mask(self.max_seq_len)
        self.register_buffer("causal_mask", causal_mask)

        # Initialize weights
        self.apply(self._init_weights)

        # Loss function will be based on number of
        # negative samples
        self.main_loss: nn.Module
        if self.neg_samples > 0:
            self.main_loss = BPRLoss()
        else:
            self.main_loss = nn.CrossEntropyLoss()
        self.reg_loss = EmbLoss()

    def get_dataloader(
        self,
        interactions: Interactions,
        sessions: Sessions,
        **kwargs: Any,
    ):
        """Build the sequential dataloader from ``sessions``.

        ``interactions`` is accepted but not used by this model; extra keyword
        arguments are forwarded to ``sessions.get_sequential_dataloader``.
        """
        return sessions.get_sequential_dataloader(
            max_seq_len=self.max_seq_len,
            neg_samples=self.neg_samples,
            batch_size=self.batch_size,
            **kwargs,
        )

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Args:
            batch (Any): ``(item_seq, item_seq_len, pos_item, neg_item)`` when
                ``neg_samples > 0``, otherwise ``(item_seq, item_seq_len, pos_item)``.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The main loss plus the weighted embedding regularization.
        """
        if self.neg_samples > 0:
            item_seq, item_seq_len, pos_item, neg_item = batch
        else:
            item_seq, item_seq_len, pos_item = batch
            neg_item = None

        seq_output = self.forward(item_seq, item_seq_len)

        # Looked up once and shared by the main loss and the regularizer.
        pos_items_emb = self.item_embedding(pos_item)  # [batch_size, embedding_size]

        # Calculate main loss and L2 regularization
        if self.neg_samples > 0:
            # The unsqueeze(1) broadcast below implies one extra negatives
            # axis, presumably [batch_size, neg_samples, embedding_size].
            neg_items_emb = self.item_embedding(neg_item)

            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [batch_size]
            neg_score = torch.sum(
                seq_output.unsqueeze(1) * neg_items_emb, dim=-1
            )  # one score per negative
            main_loss = self.main_loss(pos_score, neg_score)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                pos_items_emb,
                neg_items_emb,
            )
        else:
            # Full softmax over real items only: the last row of the table is
            # the padding token and must not compete as a class, matching the
            # item set scored in predict().
            test_item_emb = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            logits = torch.matmul(
                seq_output, test_item_emb.transpose(0, 1)
            )  # [batch_size, n_items]
            main_loss = self.main_loss(logits, pos_item)

            # L2 regularization
            reg_loss = self.reg_weight * self.reg_loss(
                self.item_embedding(item_seq),
                pos_items_emb,
            )

        # Loss logging
        loss = main_loss + reg_loss
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
        """Forward pass of the SASRec model.

        Args:
            item_seq (Tensor): Padded sequences of item IDs [batch_size, seq_len],
                with ``seq_len <= max_seq_len``.
            item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

        Returns:
            Tensor: The embedding of the predicted item (last session state)
                    [batch_size, embedding_size].
        """
        seq_len = item_seq.size(1)

        # Padding mask to ignore padding tokens
        padding_mask = item_seq == self.n_items  # [batch_size, seq_len]

        # Create position IDs directly on the input's device
        position_ids = torch.arange(
            seq_len, dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

        # Get embeddings
        item_emb = self.item_embedding(item_seq)
        pos_emb = self.position_embedding(position_ids)

        # Combine embeddings and apply LayerNorm + Dropout
        seq_emb = self.layernorm(item_emb + pos_emb)
        seq_emb = self.emb_dropout(seq_emb)

        # The buffer was built for max_seq_len; the encoder requires the mask
        # to match the actual sequence length, so take the top-left square
        # (assumes the buffer is a 2-D [max_seq_len, max_seq_len] mask).
        causal_mask = self.causal_mask[:seq_len, :seq_len]

        # Pass through Transformer Encoder
        transformer_output = self.transformer_encoder(
            src=seq_emb,
            mask=causal_mask,
            src_key_padding_mask=padding_mask,
        )  # [batch_size, seq_len, embedding_size]

        # Gather the output of the last relevant item in each sequence
        seq_output = self._gather_indexes(
            transformer_output, item_seq_len - 1
        )  # [batch_size, embedding_size]

        return seq_output

    def predict(
        self,
        user_seq: Tensor,
        seq_len: Tensor,
        *args: Any,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """
        Prediction using the learned session embeddings.

        Args:
            user_seq (Tensor): Padded sequences of item IDs for users to predict for.
            seq_len (Tensor): Actual lengths of these sequences, before padding.
            *args (Any): List of arguments.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: The score matrix {user x item}.
        """
        # Get sequence output embeddings
        seq_output = self.forward(user_seq, seq_len)  # [batch_size, embedding_size]

        if item_indices is None:
            # Case 'full': prediction on all items
            item_embeddings = self.item_embedding.weight[
                :-1, :
            ]  # [n_items, embedding_size]
            einsum_string = "be,ie->bi"  # b: batch, e: embedding, i: item
        else:
            # Case 'sampled': prediction on a sampled set of items
            item_embeddings = self.item_embedding(
                item_indices
            )  # [batch_size, pad_seq, embedding_size]
            einsum_string = "be,bse->bs"  # b: batch, e: embedding, s: sample

        predictions = torch.einsum(
            einsum_string, seq_output, item_embeddings
        )  # [batch_size, n_items] or [batch_size, pad_seq]
        return predictions

forward(item_seq, item_seq_len)

Forward pass of the SASRec model.

Parameters:

Name Type Description Default
item_seq Tensor

Padded sequences of item IDs [batch_size, max_seq_len].

required
item_seq_len Tensor

Actual lengths of sequences [batch_size,].

required

Returns:

Name Type Description
Tensor Tensor

The embedding of the predicted item (last session state) [batch_size, embedding_size].

Source code in warprec/recommenders/sequential_recommender/sasrec.py
def forward(self, item_seq: Tensor, item_seq_len: Tensor) -> Tensor:
    """Forward pass of the SASRec model.

    Args:
        item_seq (Tensor): Padded sequences of item IDs [batch_size, seq_len],
            with ``seq_len <= max_seq_len``.
        item_seq_len (Tensor): Actual lengths of sequences [batch_size,].

    Returns:
        Tensor: The embedding of the predicted item (last session state)
                [batch_size, embedding_size].
    """
    seq_len = item_seq.size(1)

    # Padding mask to ignore padding tokens
    padding_mask = item_seq == self.n_items  # [batch_size, seq_len]

    # Create position IDs directly on the input's device
    position_ids = torch.arange(seq_len, dtype=torch.long, device=item_seq.device)
    position_ids = position_ids.unsqueeze(0).expand_as(item_seq)

    # Get embeddings
    item_emb = self.item_embedding(item_seq)
    pos_emb = self.position_embedding(position_ids)

    # Combine embeddings and apply LayerNorm + Dropout
    seq_emb = self.layernorm(item_emb + pos_emb)
    seq_emb = self.emb_dropout(seq_emb)

    # The stored mask covers max_seq_len; the encoder requires the mask to
    # match the actual sequence length, so take the top-left square
    # (assumes the buffer is a 2-D [max_seq_len, max_seq_len] mask).
    causal_mask = self.causal_mask[:seq_len, :seq_len]

    # Pass through Transformer Encoder
    transformer_output = self.transformer_encoder(
        src=seq_emb,
        mask=causal_mask,
        src_key_padding_mask=padding_mask,
    )  # [batch_size, seq_len, embedding_size]

    # Gather the output of the last relevant item in each sequence
    seq_output = self._gather_indexes(
        transformer_output, item_seq_len - 1
    )  # [batch_size, embedding_size]

    return seq_output

predict(user_seq, seq_len, *args, item_indices=None, **kwargs)

Prediction using the learned session embeddings.

Parameters:

Name Type Description Default
user_seq Tensor

Padded sequences of item IDs for users to predict for.

required
seq_len Tensor

Actual lengths of these sequences, before padding.

required
*args Any

List of arguments.

()
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

The score matrix {user x item}.

Source code in warprec/recommenders/sequential_recommender/sasrec.py
def predict(
    self,
    user_seq: Tensor,
    seq_len: Tensor,
    *args: Any,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """
    Score items for each input sequence with the learned session embeddings.

    Args:
        user_seq (Tensor): Padded sequences of item IDs for users to predict for.
        seq_len (Tensor): Actual lengths of these sequences, before padding.
        *args (Any): List of arguments.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: The score matrix {user x item}.
    """
    # Last-state vector of every sequence: [batch_size, embedding_size]
    state = self.forward(user_seq, seq_len)

    full_ranking = item_indices is None
    if full_ranking:
        # Every row of the item table except the trailing padding row.
        targets = self.item_embedding.weight[:-1, :]
        pattern = "be,ie->bi"  # b: batch, e: embedding, i: item
    else:
        # Per-row candidate lists supplied by the caller.
        targets = self.item_embedding(item_indices)
        pattern = "be,bse->bs"  # b: batch, e: embedding, s: sample

    # [batch_size, n_items] or [batch_size, pad_seq]
    return torch.einsum(pattern, state, targets)

warprec.recommenders.sequential_recommender.stan.STAN

Bases: Recommender, SequentialRecommenderUtils

Implementation of STAN model from "Sequence and Time Aware Neighborhood for Session-based Recommendations" (SIGIR'19).

Parameters:

Name Type Description Default
params dict

Model parameters.

required
info dict

The dictionary containing dataset information.

required
sessions Sessions

Training sessions — the primary data source; we pull (flat_items, flat_users, user_offsets, timestamps) from it.

required
*args Any

Variable length argument list.

()
seed int

The seed to use for reproducibility.

42
**kwargs Any

Arbitrary keyword arguments.

{}

Attributes:

Name Type Description
k int

Neighborhood size N (Section 4).

lambda_1 float

Eq. 3 decay.

lambda_2 float

Eq. 5 decay (seconds).

lambda_3 float

Eq. 7 decay.

max_seq_len int

Upper bound on current-session length (controls how much history the evaluator feeds to predict).

Source code in warprec/recommenders/sequential_recommender/stan.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
@model_registry.register(name="STAN")
class STAN(Recommender, SequentialRecommenderUtils):
    """Implementation of the STAN model from
    "Sequence and Time Aware Neighborhood for Session-based Recommendations"
    (SIGIR'19).

    STAN is a neighborhood (kNN) method: it has no trainable parameters and
    scores items by aggregating weighted contributions from the most similar
    stored sessions. In this implementation every user's training history is
    treated as one past session.

    Args:
        params (dict): Model parameters.
        info (dict): The dictionary containing dataset information.
        sessions (Sessions): Training sessions — the primary data source; we
            pull (flat_items, flat_users, user_offsets, timestamps) from it.
        *args (Any): Variable length argument list.
        seed (int): The seed to use for reproducibility.
        **kwargs (Any): Arbitrary keyword arguments.

    Attributes:
        k (int): Neighborhood size N (Section 4).
        lambda_1 (float): Eq. 3 decay (position of an item inside the
            current session).
        lambda_2 (float): Eq. 5 decay (seconds; recency of a past session).
        lambda_3 (float): Eq. 7 decay (distance of an item from the shared
            anchor item inside a neighbor session).
        max_seq_len (int): Upper bound on current-session length (controls how much
            history the evaluator feeds to predict).
    """

    # Model hyperparameters
    k: int
    lambda_1: float
    lambda_2: float
    lambda_3: float
    max_seq_len: int

    @classmethod
    def estimate_space(
        cls,
        params: dict,
        info: dict,
        interactions: Optional[Interactions] = None,
        **kwargs: Any,
    ) -> dict:
        """Analytically estimate the RAM needed to build a STAN model.

        The estimate adds up the structures kept by the model (flat session
        arrays, per-session timestamps, the CSR/CSC session-item matrices and
        the per-session unique-item counts) and accounts for the temporary
        COO matrix that exists while the sparse matrices are being built.

        Args:
            params (dict): Model parameters (not used by the estimate).
            info (dict): Dataset information; ``n_users`` and ``n_items``
                are read from it.
            interactions (Optional[Interactions]): When given, the number of
                stored entries of its sparse matrix is used as ``nnz``;
                otherwise the number of session events is used.
            **kwargs (Any): Must contain ``sessions``.

        Returns:
            dict: ``train_ram_mb`` with the estimated peak size and a
                ``notes`` string.

        Raises:
            ValueError: If ``sessions`` is missing from ``kwargs``.
        """
        sessions = kwargs.get("sessions")
        if sessions is None:
            raise ValueError("STAN requires sessions to estimate space.")

        n_users = info["n_users"]
        n_items = info["n_items"]
        n_events = len(sessions._flat_items)
        if interactions is not None:
            nnz = interactions.get_sparse().nnz
        else:
            nnz = n_events

        # Flat per-event arrays held by the model.
        items_mb = cls._bytes_to_mb(sessions._flat_items.nbytes)
        users_mb = cls._bytes_to_mb(sessions._flat_users.nbytes)
        offsets_mb = cls._bytes_to_mb(sessions._user_offsets.nbytes)

        # One float64 timestamp per session, only when timestamps exist.
        if sessions.timestamp_label in sessions._inter_df.columns:
            timestamps_mb = cls._dense_size_mb((n_users,), np.float64)
        else:
            timestamps_mb = 0.0

        # Transient COO triplets: float64 value + two 4-byte indices per entry.
        coo_mb = cls._bytes_to_mb(nnz * (8 + 4 + 4))
        csr_mb = cls._compressed_sparse_size_mb(
            nnz=nnz,
            ptr_len=n_users + 1,
            data_dtype=np.float64,
        )
        csc_mb = cls._compressed_sparse_size_mb(
            nnz=nnz,
            ptr_len=n_items + 1,
            data_dtype=np.float64,
        )
        counts_mb = cls._dense_size_mb((n_users,), np.float64)

        # Accumulate in a fixed order so the float result is reproducible.
        resident_mb = items_mb
        resident_mb = resident_mb + users_mb
        resident_mb = resident_mb + offsets_mb
        resident_mb = resident_mb + timestamps_mb
        resident_mb = resident_mb + csc_mb
        resident_mb = resident_mb + csr_mb
        resident_mb = resident_mb + counts_mb

        peak_mb = cls._peak_size_mb(resident_mb, resident_mb + coo_mb)
        return {
            "train_ram_mb": peak_mb,
            "notes": "STAN analytical train-space estimate",
        }

    def __init__(
        self,
        params: dict,
        info: dict,
        sessions: Sessions,
        *args: Any,
        seed: int = 42,
        **kwargs: Any,
    ):
        """Precompute every read-only structure STAN needs at prediction time.

        Args:
            params (dict): Model parameters.
            info (dict): The dictionary containing dataset information.
            sessions (Sessions): Training sessions to index.
            *args (Any): Variable length argument list.
            seed (int): The seed to use for reproducibility.
            **kwargs (Any): Arbitrary keyword arguments.
        """
        super().__init__(params, info, *args, seed=seed, **kwargs)

        # Flat training data, stored as int64 copies. Each user's history
        # is treated as one past session (WarpRec's session convention).
        self._flat_items, self._flat_users, self._user_offsets = (
            source.astype(np.int64)
            for source in (
                sessions._flat_items,
                sessions._flat_users,
                sessions._user_offsets,
            )
        )

        # t(s_j) of Eq. 5: the most recent timestamp of every stored
        # session, or None when the data carries no timestamps.
        self._session_timestamps = self._extract_session_timestamps(sessions)

        # Binary session-by-item matrix in two layouts:
        #   * CSR — row slices feed the Eq. 4 numerator as one sparse matmul;
        #   * CSC — column slices act as the item -> sessions inverted index
        #     (paper Section 5.3 footnote).
        (
            self._session_item_csr,
            self._session_item_csc,
        ) = self._build_session_item_matrices(
            self._flat_items, self._flat_users, self.n_users, self.n_items
        )

        # Eq. 4 denominator term: number of distinct items per session,
        # i.e. the stored entries per CSR row of the binary matrix.
        entries_per_row = self._session_item_csr.getnnz(axis=1)
        self._session_unique_counts = np.asarray(entries_per_row, dtype=np.float64)

    @staticmethod
    def _extract_session_timestamps(sessions: Sessions) -> Optional[np.ndarray]:
        """Return t(s_j), the latest timestamp of each session (= user).

        Args:
            sessions (Sessions): Entity providing the processed interaction
                frame, the timestamp column label and the user offsets.

        Returns:
            Optional[np.ndarray]: float64 array with one entry per user
                (``len(offsets) - 1`` entries) holding the timestamp of the
                user's last row, with 0 for users without events; ``None``
                when the frame has no timestamp column.
        """
        ts_label = sessions.timestamp_label
        frame = sessions._get_processed_data()  # cached, mapped, sorted
        if ts_label not in frame.columns:
            return None

        ts_column = frame.select(nw.col(ts_label)).to_numpy()
        timestamps = ts_column.ravel().astype(np.float64)

        offsets = sessions._user_offsets
        if timestamps.size == 0:
            return np.zeros(len(offsets) - 1, dtype=np.float64)

        ends = offsets[1:]
        non_empty = (ends - offsets[:-1]) > 0

        # Rows are sorted by (user, timestamp), so the last row of a user's
        # slice carries t(s_j). Clipping keeps the gather index valid for
        # empty users; their entries are zeroed right after.
        last_row = np.clip(ends - 1, 0, timestamps.size - 1)
        latest = timestamps[last_row]
        latest[~non_empty] = 0.0
        return latest

    @staticmethod
    def _build_session_item_matrices(
        flat_items: np.ndarray,
        flat_users: np.ndarray,
        n_users: int,
        n_items: int,
    ) -> tuple:
        """Build the binary session-by-item CSR and its CSC transpose.

        A session's vector in Eq. 4 is binary, so interaction-stream
        duplicates collapse to 1. This replaces the old per-interaction
        Python loop with a single COO → CSR conversion.

        Args:
            flat_items (np.ndarray): Flat array of item ids (one entry per interaction).
            flat_users (np.ndarray): Corresponding session (= user) ids, same length.
            n_users (int): Total number of sessions (CSR rows).
            n_items (int): Total number of items (CSR cols).

        Returns:
            tuple: (csr, csc) with shapes [n_users, n_items], binary.
        """
        # Defensive guard against out-of-range items (mirrors the old build).
        valid = (flat_items >= 0) & (flat_items < n_items)
        rows = flat_users[valid]
        cols = flat_items[valid]
        data = np.ones(rows.shape[0], dtype=np.float64)

        coo = coo_matrix((data, (rows, cols)), shape=(n_users, n_items))
        csr = coo.tocsr()
        csr.sum_duplicates()
        # Binarize: an item appears in a session either 0 or 1 times.
        csr.data = np.ones_like(csr.data)
        csc = csr.tocsc()
        return csr, csc

    def _get_user_items(self, user_id: int) -> np.ndarray:
        """Return the item sequence of user ``user_id`` in chronological order."""
        start, end = (
            int(self._user_offsets[user_id]),
            int(self._user_offsets[user_id + 1]),
        )
        return self._flat_items[start:end]

    def _compute_w1(self, current_seq: np.ndarray) -> np.ndarray:
        """Eq. 3 — w1(i, s) = exp((p(i, s) - l(s)) / lambda_1).

        Produces the real-valued weight vector for the current session. The
        last position has weight 1; earlier positions decay exponentially.
        """
        length = current_seq.shape[0]
        if length == 0:
            return np.empty(0, dtype=np.float64)
        # positions are 1-indexed in the paper (p(i, s) in [1, l(s)]).
        positions = np.arange(1, length + 1, dtype=np.float64)
        return np.exp((positions - length) / float(self.lambda_1))

    def _compute_w2(
        self,
        current_ts: float,
        neighbor_ids: np.ndarray,
    ) -> np.ndarray:
        """Eq. 5 — w2(s_j | s) = exp(-(t(s) - t(s_j)) / lambda_2).

        Returns a vector of shape ``[len(neighbor_ids)]``.
        """
        if self._session_timestamps is None:
            # ASSUMPTION: dataset without timestamps -> disable Factor-2.
            return np.ones(neighbor_ids.shape[0], dtype=np.float64)

        ts_neighbors = self._session_timestamps[neighbor_ids]
        # Paper states t(s) > t(s_j); we clip negatives to 0 so that a
        # neighbor session with ts == current_ts gets weight 1 (rather than
        # blowing up via a negative delta). This is consistent with the
        # paper's intention to "decay" the past.
        delta = np.clip(current_ts - ts_neighbors, a_min=0.0, a_max=None)
        return np.exp(-delta / float(self.lambda_2))

    def _compute_w3_and_items(
        self,
        current_items: np.ndarray,
        neighbor_seq: np.ndarray,
    ) -> tuple:
        """Eq. 7 — w3(i | s, n) = exp(-|p(i, n) - p(i*, n)| / lambda_3) * I_n(i).

        Here i* is the co-occurring item between s and n that is
        most recent in s (paper: "the item i* that occurs in both s and n,
        and is most recent w.r.t. s").

        Args:
            current_items (np.ndarray): Item sequence of the current session s.
            neighbor_seq (np.ndarray): Item sequence of the neighbor session n.

        Returns:
            tuple: (items, weights) — arrays of equal length holding the
                recommendable items in n and their per-item Eq. 7 weights.
                Returns empty arrays if no co-occurring item exists (in which
                case the neighbor contributes nothing, consistent with I_n(.)
                being zero for all items).
        """
        if neighbor_seq.size == 0 or current_items.size == 0:
            return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float64)

        # Find i* = most recent item in current_items that also appears in
        # neighbor_seq. Vectorized "in" check + take the last True index.
        mask = np.isin(current_items, neighbor_seq)
        hit = np.flatnonzero(mask)
        if hit.size == 0:
            # No co-occurring item -> I_n(i*) would be 0. Drop the neighbor.
            return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float64)
        i_star = int(current_items[hit[-1]])

        # Position of i* inside the neighbor session (1-indexed, paper
        # convention). If i* appears multiple times in n, use the LAST
        # occurrence.
        # ASSUMPTION: The paper defines p(i*, n) for a single position but
        # does not state which occurrence to use if i* repeats inside n.
        # Using the latest occurrence is consistent with "most recent" being
        # the operative choice throughout the paper (cf. Factor-1). We scan
        # the reversed array to avoid materializing a full index list.
        rev_hit = int((neighbor_seq[::-1] == i_star).argmax())
        pos_i_star = neighbor_seq.shape[0] - rev_hit

        # Eq. 7 for every item in n (I_n(i) == 1 here by construction because
        # we iterate items that are in n).
        positions = np.arange(1, neighbor_seq.shape[0] + 1, dtype=np.float64)
        weights = np.exp(-np.abs(positions - float(pos_i_star)) / float(self.lambda_3))
        return neighbor_seq.astype(np.int64), weights

    def _score_session(
        self,
        current_items: np.ndarray,
        current_ts: float,
        scores: np.ndarray,
    ) -> np.ndarray:
        """Produce the STAN score vector over all items for one session.

        Implements Eqs. 4, 6, 8, 9 of the paper. Writes into the provided
        scores buffer (zeroed on entry) and returns it.

        Args:
            current_items (np.ndarray): Ordered item ids of the current
                session s (length l(s)).
            current_ts (float): t(s) — most-recent timestamp in s.
            scores (np.ndarray): Pre-allocated buffer of shape [n_items]
                that this method fills in-place.

        Returns:
            np.ndarray: The same scores buffer, now holding the score
                vector of shape [n_items].
        """
        scores.fill(0.0)
        length = current_items.shape[0]
        if length == 0:
            # Empty session — nothing to match. Return a zero vector; the
            # evaluator will still produce a ranking, equivalent to random
            # tie-breaking on an unseen user.
            return scores

        # ---- Eq. 3 & Eq. 4 ---------------------------------------------
        # s_w on the full item space. Duplicates in s accumulate their
        # Eq. 3 weights correctly via np.add.at.
        w1_vec = self._compute_w1(current_items)
        s_w = np.zeros(self.n_items, dtype=np.float64)
        np.add.at(s_w, current_items, w1_vec)

        unique_items = np.unique(current_items)

        # ---- Candidate neighbor set via inverted index -----------------
        # Paper Section 5.3 footnote: inverted index of item -> sessions.
        # Union across unique items in s via CSC column slices, then dedup.
        indptr = self._session_item_csc.indptr
        indices = self._session_item_csc.indices
        pieces = [indices[indptr[i] : indptr[i + 1]] for i in unique_items.tolist()]
        cand_ids = np.unique(np.concatenate(pieces))
        if cand_ids.size == 0:
            return scores

        # ---- Eq. 4 (sim1) vectorized over all candidates ---------------
        # Numerator:   sim1_num(s, s_j) = s_w . s_j = (P @ s_w)[j]  (binary s_j)
        # Denominator: sqrt(l(s)) * sqrt(l(s_j))   (unique cardinalities)
        p_cand = self._session_item_csr[cand_ids]
        dot = np.asarray(p_cand.dot(s_w)).ravel()
        cand_lengths = self._session_unique_counts[cand_ids]
        denom = float(np.sqrt(float(unique_items.shape[0]))) * np.sqrt(cand_lengths)
        sim1_scores = np.zeros(cand_ids.shape[0], dtype=np.float64)
        safe = denom > 0.0
        sim1_scores[safe] = dot[safe] / denom[safe]

        # ---- Eq. 5 & Eq. 6 (sim2) --------------------------------------
        w2_vec = self._compute_w2(current_ts, cand_ids)
        sim2_scores = sim1_scores * w2_vec

        # ---- Top-N neighborhood selection ------------------------------
        # Paper: "The neighborhood N(s) is then found by taking the top N
        # most similar sessions using the similarity measure sim2".
        n_nbrs = min(self.k, sim2_scores.shape[0])
        if n_nbrs <= 0:
            return scores
        # argpartition is O(n) vs. O(n log n) for argsort. We then refine.
        top_idx = np.argpartition(-sim2_scores, n_nbrs - 1)[:n_nbrs]
        # Drop neighbors with zero similarity (they cannot contribute).
        top_idx = top_idx[sim2_scores[top_idx] > 0.0]

        # ---- Eq. 7 & Eq. 8 & Eq. 9 -------------------------------------
        # scoreSTAN(i, s) = sum_{n in N(s)} sim2(s, n) * w3(i | s, n)
        for idx in top_idx:
            n_id = int(cand_ids[idx])
            n_seq = self._get_user_items(n_id)
            if n_seq.size == 0:
                continue
            items, w3_vec = self._compute_w3_and_items(current_items, n_seq)
            if items.size == 0:
                continue
            sim2_val = float(sim2_scores[idx])
            # np.add.at handles duplicates inside the neighbor session
            # correctly — Eq. 8/9 summation.
            np.add.at(scores, items, sim2_val * w3_vec)

        return scores

    def predict(
        self,
        user_indices: Tensor,
        *args: Any,
        user_seq: Optional[Tensor] = None,
        seq_len: Optional[Tensor] = None,
        item_indices: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """Compute STAN scores for a batch of current sessions.

        Args:
            user_indices (Tensor): User indices for which to produce scores.
            *args (Any): List of arguments.
            user_seq (Optional[Tensor]): Padded sequences of item IDs for users to predict for.
            seq_len (Optional[Tensor]): Actual lengths of these sequences, before padding.
            item_indices (Optional[Tensor]): The batch of item indices. If None,
                full prediction will be produced.
            **kwargs (Any): The dictionary of keyword arguments.

        Returns:
            Tensor: Score matrix [batch_size, n_items] or
                [batch_size, n_samples].
        """
        batch_size = int(user_indices.shape[0])
        # Output allocated directly as float32 (the returned dtype); halves
        # memory bandwidth versus building a float64 copy first.
        all_scores = np.zeros((batch_size, self.n_items), dtype=np.float32)
        # Reusable float64 accumulation buffer — keeps Eq. 8/9 precision.
        scores_buf = np.zeros(self.n_items, dtype=np.float64)
        padding_idx = self.n_items

        user_indices_cpu = user_indices.detach().cpu().tolist()
        if user_seq is not None and seq_len is not None:
            user_seq_np = user_seq.detach().cpu().numpy()
            seq_len_np = seq_len.detach().cpu().numpy()
        else:
            user_seq_np = None
            seq_len_np = None

        for b, user_id in enumerate(user_indices_cpu):
            user_id = int(user_id)

            # Recover current session items (most-recent first order
            # preserved — the evaluator already truncates to max_seq_len).
            if user_seq_np is not None and seq_len_np is not None:
                real_len = int(seq_len_np[b])
                row = user_seq_np[b, :real_len]
                # Strip any padding ids defensively.
                current_items = row[row != padding_idx].astype(np.int64)
            else:
                current_items = self._get_user_items(user_id).astype(np.int64)

            # Current session timestamp = max timestamp observed for this
            # user (matches paper's t(s) definition). Uses cached array.
            if self._session_timestamps is not None:
                current_ts = float(self._session_timestamps[user_id])
            else:
                current_ts = 0.0  # disables Factor-2 (see _compute_w2)

            self._score_session(current_items, current_ts, scores_buf)
            all_scores[b, :] = scores_buf  # implicit float64 -> float32 cast

        predictions = torch.from_numpy(all_scores)

        if item_indices is None:
            return predictions  # [batch_size, n_items]

        # Sampled evaluation — gather only requested item slots.
        return predictions.gather(
            1,
            item_indices.to(predictions.device).clamp(max=self.n_items - 1),
        )

predict(user_indices, *args, user_seq=None, seq_len=None, item_indices=None, **kwargs)

Compute STAN scores for a batch of current sessions.

Parameters:

Name Type Description Default
user_indices Tensor

User indices for which to produce scores.

required
*args Any

List of arguments.

()
user_seq Optional[Tensor]

Padded sequences of item IDs for users to predict for.

None
seq_len Optional[Tensor]

Actual lengths of these sequences, before padding.

None
item_indices Optional[Tensor]

The batch of item indices. If None, full prediction will be produced.

None
**kwargs Any

The dictionary of keyword arguments.

{}

Returns:

Name Type Description
Tensor Tensor

Score matrix [batch_size, n_items] or [batch_size, n_samples].

Source code in warprec/recommenders/sequential_recommender/stan.py
def predict(
    self,
    user_indices: Tensor,
    *args: Any,
    user_seq: Optional[Tensor] = None,
    seq_len: Optional[Tensor] = None,
    item_indices: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Compute STAN scores for a batch of current sessions.

    The current session of each batch row is taken from ``user_seq``
    whenever it is supplied (cut to ``seq_len`` if that is supplied as
    well); otherwise the user's stored training sequence is used. Ids
    outside ``[0, n_items)`` — which includes the padding id
    ``n_items`` — are removed before scoring.

    Args:
        user_indices (Tensor): User indices for which to produce scores.
        *args (Any): List of arguments.
        user_seq (Optional[Tensor]): Padded sequences of item IDs for users to predict for.
        seq_len (Optional[Tensor]): Actual lengths of these sequences, before padding.
            When omitted, the whole row of ``user_seq`` is used.
        item_indices (Optional[Tensor]): The batch of item indices. If None,
            full prediction will be produced.
        **kwargs (Any): The dictionary of keyword arguments.

    Returns:
        Tensor: float32 score matrix [batch_size, n_items] or
            [batch_size, n_samples].
    """
    n_items = self.n_items
    batch_size = int(user_indices.shape[0])
    # Output allocated directly as float32 (the returned dtype).
    all_scores = np.zeros((batch_size, n_items), dtype=np.float32)
    # Reusable float64 accumulation buffer — keeps Eq. 8/9 precision.
    scores_buf = np.zeros(n_items, dtype=np.float64)

    user_ids = user_indices.detach().cpu().tolist()
    # A sequence batch is honoured even when seq_len is missing, rather
    # than being silently replaced by the stored training history.
    user_seq_np = None if user_seq is None else user_seq.detach().cpu().numpy()
    seq_len_np = None if seq_len is None else seq_len.detach().cpu().numpy()
    session_ts = self._session_timestamps

    for b, user_id in enumerate(user_ids):
        user_id = int(user_id)

        if user_seq_np is not None:
            row = user_seq_np[b]
            if seq_len_np is not None:
                row = row[: int(seq_len_np[b])]
            # Keep valid item ids only. A negative id would silently wrap
            # around inside np.add.at and an id >= n_items would raise,
            # so filter on the full [0, n_items) range (this also drops
            # the padding id n_items).
            # NOTE(review): the weights of Eq. 3 give the LAST position
            # the largest weight, so the row is assumed to be in
            # chronological order — confirm against the evaluator.
            in_range = (row >= 0) & (row < n_items)
            current_items = row[in_range].astype(np.int64)
        else:
            current_items = self._get_user_items(user_id).astype(np.int64)

        # t(s): latest stored timestamp of this user. Without timestamps
        # the value is irrelevant because _compute_w2 returns all ones.
        if session_ts is not None:
            current_ts = float(session_ts[user_id])
        else:
            current_ts = 0.0

        self._score_session(current_items, current_ts, scores_buf)
        all_scores[b, :] = scores_buf  # implicit float64 -> float32 cast

    predictions = torch.from_numpy(all_scores)

    if item_indices is None:
        return predictions  # [batch_size, n_items]

    # Sampled evaluation — gather only requested item slots. Indices
    # above the last item (e.g. padding) are clamped onto the last item.
    return predictions.gather(
        1,
        item_indices.to(predictions.device).clamp(max=n_items - 1),
    )