Skip to content

Entities - API Reference

warprec.data.entities.interactions.Interactions

Interactions class will handle the data of the transactions.

Parameters:

Name Type Description Default
data DataFrame[Any]

Transaction data in DataFrame format.

required
original_dims Tuple[int, int]

A tuple of (number of users, number of items).

required
user_mapping dict

Mapping of user ID -> user idx.

required
item_mapping dict

Mapping of item ID -> item idx.

required
side_data Optional[DataFrame[Any]]

The side information features in DataFrame format.

None
user_cluster Optional[dict]

The user cluster information.

None
item_cluster Optional[dict]

The item cluster information.

None
batch_size int

The batch size that will be used to iterate over the interactions.

1024
rating_type RatingType

The type of rating to be used.

IMPLICIT
rating_label str

The label of the rating column.

None
timestamp_label str

The label of the timestamp column.

None
context_labels Optional[List[str]]

The list of labels of the contextual data.

None
Source code in warprec/data/entities/interactions.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
class Interactions:
    """Interactions class will handle the data of the transactions.

    Args:
        data (DataFrame[Any]): Transaction data in DataFrame format. The first
            column is used as the user label and the second as the item label.
        original_dims (Tuple[int, int]):
            int: Number of users.
            int: Number of items.
        user_mapping (dict): Mapping of user ID -> user idx.
        item_mapping (dict): Mapping of item ID -> item idx.
        side_data (Optional[DataFrame[Any]]): The side information features in DataFrame format.
        user_cluster (Optional[dict]): The user cluster information.
        item_cluster (Optional[dict]): The item cluster information.
        batch_size (int): The batch size that will be used to
            iterate over the interactions.
        rating_type (RatingType): The type of rating to be used.
        rating_label (Optional[str]): The label of the rating column. It is only
            kept when ``rating_type`` is ``RatingType.EXPLICIT``.
        timestamp_label (Optional[str]): The label of the timestamp column.
        context_labels (Optional[List[str]]): The list of labels of the
            contextual data.
    """

    def __init__(
        self,
        data: DataFrame[Any],
        original_dims: Tuple[int, int],
        user_mapping: dict,
        item_mapping: dict,
        side_data: Optional[DataFrame[Any]] = None,
        user_cluster: Optional[dict] = None,
        item_cluster: Optional[dict] = None,
        batch_size: int = 1024,
        rating_type: RatingType = RatingType.IMPLICIT,
        rating_label: Optional[str] = None,
        timestamp_label: Optional[str] = None,
        context_labels: Optional[List[str]] = None,
    ) -> None:
        # Setup the variables (side data is defensively copied)
        self._inter_df = data
        self._inter_side = side_data.clone() if side_data is not None else None
        self._inter_user_cluster = user_cluster
        self._inter_item_cluster = item_cluster
        self.batch_size = batch_size
        self.rating_type = rating_type

        # Lazily-built caches for the training representations
        self._inter_dict: Optional[dict] = None
        self._inter_sparse: Optional[csr_matrix] = None
        self._inter_side_sparse: Optional[csr_matrix] = None
        self._inter_side_tensor: Optional[Tensor] = None
        self._inter_side_labels: List[str] = []
        self._history_matrix: Optional[Tensor] = None
        self._history_lens: Optional[Tensor] = None
        self._history_mask: Optional[Tensor] = None

        # Set DataFrame labels
        self.user_label = data.columns[0]
        self.item_label = data.columns[1]
        # The rating column is only meaningful for explicit feedback
        self.rating_label = rating_label if rating_type == RatingType.EXPLICIT else None
        self.timestamp_label = timestamp_label
        self.context_labels = context_labels if context_labels else []

        # Setup flat views cache (filled by get_flat)
        self._flat_users: Optional[np.ndarray] = None
        self._flat_items: Optional[np.ndarray] = None
        self._flat_ratings: Optional[np.ndarray] = None
        self._flat_timestamps: Optional[np.ndarray] = None

        # Set mappings
        self._umap = user_mapping
        self._imap = item_mapping

        # Filter side information (if present)
        if self._inter_side is not None:
            # Keep only side rows for items that appear in the interactions
            valid_items = self._inter_df.select(self.item_label).unique()
            self._inter_side = self._inter_side.join(
                valid_items, on=self.item_label, how="inner"
            )

            # Order side information by item index (as given by item_mapping)
            imap_df = nw.from_dict(
                {
                    self.item_label: list(item_mapping.keys()),
                    "__order__": list(item_mapping.values()),
                },
                native_namespace=nw.get_native_namespace(self._inter_side),
            )

            # Join to get the order, sort, and drop temp column
            self._inter_side = (
                self._inter_side.join(imap_df, on=self.item_label, how="left")
                .sort("__order__")
                .drop("__order__")
            )

            # NOTE(review): row i matches item index i only if every mapped item
            # has a side row; items missing from side_data (or from the
            # interactions) shift later rows. Confirm upstream guarantees this.
            feature_cols = [c for c in self._inter_side.columns if c != self.item_label]

            # Create the lookup tensor for side information
            side_values = self._inter_side.select(feature_cols).to_numpy()
            side_tensor = torch.tensor(side_values, dtype=torch.long)

            # Append a zero padding row at the END of the lookup, i.e. at index
            # equal to the number of side rows. This presumably pairs with the
            # ``n_items`` padding index used by ``_to_history``; confirm.
            padding_row = torch.zeros((1, side_tensor.shape[1]), dtype=torch.long)
            self._inter_side_tensor = torch.cat([side_tensor, padding_row], dim=0)

            # Store the feature labels
            self._inter_side_labels = feature_cols

        # Definition of dimensions
        self._uid = self._inter_df.select(self.user_label).unique().to_numpy().flatten()
        self._nuid = self._inter_df.select(nw.col(self.user_label).n_unique()).item()
        self._niid = self._inter_df.select(nw.col(self.item_label).n_unique()).item()
        self._og_nuid, self._og_niid = original_dims
        self._transactions = self._inter_df.select(nw.len()).item()

    def _mapping_frames(
        self, frame: DataFrame[Any]
    ) -> Tuple[DataFrame[Any], DataFrame[Any]]:
        """Build the user/item ID -> index lookup frames in ``frame``'s backend.

        Args:
            frame (DataFrame[Any]): Frame whose native namespace the lookup
                frames must share so they can be joined with it.

        Returns:
            Tuple[DataFrame[Any], DataFrame[Any]]: ``(umap_df, imap_df)`` with
                columns ``[user_label, "__uidx__"]`` and ``[item_label, "__iidx__"]``.
        """
        namespace = nw.get_native_namespace(frame)
        umap_df = nw.from_dict(
            {
                self.user_label: list(self._umap.keys()),
                "__uidx__": list(self._umap.values()),
            },
            native_namespace=namespace,
        )
        imap_df = nw.from_dict(
            {
                self.item_label: list(self._imap.keys()),
                "__iidx__": list(self._imap.values()),
            },
            native_namespace=namespace,
        )
        return umap_df, imap_df

    def _map_and_sort(self, frame: DataFrame[Any]) -> DataFrame[Any]:
        """Attach ``__uidx__``/``__iidx__`` columns to ``frame`` and sort by them.

        Rows whose user or item is missing from the mappings are dropped
        (inner joins).

        Args:
            frame (DataFrame[Any]): Frame containing the user and item labels.

        Returns:
            DataFrame[Any]: The mapped frame sorted by user index, then item index.
        """
        umap_df, imap_df = self._mapping_frames(frame)
        mapped_df = frame.join(umap_df, on=self.user_label, how="inner").join(
            imap_df, on=self.item_label, how="inner"
        )
        return mapped_df.sort(["__uidx__", "__iidx__"])

    def _get_mapped_indices(self) -> Tuple[Tensor, Tensor]:
        """Retrieves mapped user and item indices directly from the sparse matrix structure.

        Returns:
            Tuple[Tensor, Tensor]: (user_indices, item_indices) aligned as LongTensors.
        """
        mat = self.get_sparse()

        if not mat.has_sorted_indices:
            mat.sort_indices()

        # Extract the positive items
        pos_items = mat.indices.astype(np.int64)

        # Reconstruct the users: row u owns indptr[u + 1] - indptr[u] entries
        n_users = mat.shape[0]
        interactions_per_user = np.diff(mat.indptr)
        users = np.repeat(np.arange(n_users), interactions_per_user).astype(np.int64)

        # Return Tensors directly
        return torch.from_numpy(users), torch.from_numpy(pos_items)

    def _get_mapped_contexts(self) -> Tuple[Tensor, Tensor, Tensor]:
        """Map every interaction row to (user idx, item idx, context) in one pass.

        All three tensors are extracted from the same sorted frame, so position i
        of each tensor refers to the same interaction row. Unlike
        ``_get_mapped_indices``, repeated (user, item) pairs are kept, because
        each row may carry a different context.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: ``(user_indices, item_indices, contexts)``
                as LongTensors; ``contexts`` has one column per context label.
        """
        mapped_df = self._map_and_sort(self._inter_df)
        users = mapped_df.select("__uidx__").to_numpy().flatten().astype(np.int64)
        items = mapped_df.select("__iidx__").to_numpy().flatten().astype(np.int64)
        contexts = mapped_df.select(self.context_labels).to_numpy()
        return (
            torch.from_numpy(users),
            torch.from_numpy(items),
            torch.tensor(contexts, dtype=torch.long),
        )

    def get_dict(self) -> dict:
        """This method will return the transaction information in dict format.

        The result is cached after the first call.

        Returns:
            dict: The transaction information in the current
                representation {user ID: {item ID: rating}}.
        """
        if self._inter_dict is not None:
            return self._inter_dict

        u_vals = self._inter_df.select(self.user_label).to_numpy().flatten()
        i_vals = self._inter_df.select(self.item_label).to_numpy().flatten()

        self._inter_dict = {}

        if self.rating_type == RatingType.EXPLICIT:
            r_vals = self._inter_df.select(self.rating_label).to_numpy().flatten()
            for u, i, r in zip(u_vals, i_vals, r_vals):
                self._inter_dict.setdefault(u, {})[i] = r
        elif self.rating_type == RatingType.IMPLICIT:
            for u, i in zip(u_vals, i_vals):
                self._inter_dict.setdefault(u, {})[i] = 1

        return self._inter_dict

    def get_df(self) -> DataFrame[Any]:
        """This method will return the raw data.

        Returns:
            DataFrame[Any]: The raw data in tabular format.
        """
        return self._inter_df

    def get_sparse(self) -> csr_matrix:
        """This method retrieves the sparse representation of data.

        If the sparse structure has not been created yet, it is built first
        and cached.

        Returns:
            csr_matrix: Sparse representation of the transactions (CSR Format).
        """
        if isinstance(self._inter_sparse, csr_matrix):
            return self._inter_sparse
        return self._to_sparse()

    def get_flat(
        self,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Returns the flattened, aligned arrays of users, items, ratings, and timestamps.

        This method caches the arrays after the first computation to ensure fast
        subsequent retrievals. The arrays are sorted by user and item indices.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
                - users array
                - items array
                - ratings array (or ones for implicit)
                - timestamps array (if timestamp_label is provided, else None)
        """
        # Return cached views if already computed
        if self._flat_users is not None:
            return (
                self._flat_users,
                self._flat_items,
                self._flat_ratings,
                self._flat_timestamps,
            )

        # Join and sort to ensure reproducibility and alignment
        mapped_df = self._map_and_sort(self._inter_df)

        # Extract arrays
        self._flat_users = mapped_df.select("__uidx__").to_numpy().flatten()
        self._flat_items = mapped_df.select("__iidx__").to_numpy().flatten()

        if self.rating_type == RatingType.EXPLICIT:
            self._flat_ratings = (
                mapped_df.select(self.rating_label).to_numpy().flatten()
            )
        else:
            self._flat_ratings = np.ones(len(self._flat_users), dtype=np.float32)

        if self.timestamp_label and self.timestamp_label in mapped_df.columns:
            self._flat_timestamps = (
                mapped_df.select(self.timestamp_label).to_numpy().flatten()
            )
        else:
            self._flat_timestamps = None

        return (
            self._flat_users,
            self._flat_items,
            self._flat_ratings,
            self._flat_timestamps,
        )

    def get_sparse_by_rating(self, rating_value: float) -> coo_matrix:
        """Returns a sparse matrix (COO format) containing only the interactions
        that match a specific rating value.

        Args:
            rating_value (float): The rating value to filter by.

        Returns:
            coo_matrix: A sparse matrix of shape [num_users, num_items] for the given rating.

        Raises:
            ValueError: If interactions are not explicit or if
                rating label is None.
        """
        if self.rating_type != RatingType.EXPLICIT or self.rating_label is None:
            raise ValueError(
                "Filtering by rating is only supported for explicit feedback data."
            )

        # Filter original DataFrame for the specified rating value
        rating_df = self._inter_df.filter(nw.col(self.rating_label) == rating_value)

        # Edge case: No interactions with the specified rating
        if rating_df.select(nw.len()).item() == 0:
            return coo_matrix((self._og_nuid, self._og_niid))

        # Map IDs to indices and sort to ensure reproducibility
        mapped_df = self._map_and_sort(rating_df)

        # Extract indices
        users = mapped_df.select("__uidx__").to_numpy().flatten()
        items = mapped_df.select("__iidx__").to_numpy().flatten()

        # Values are all ones for the presence of interaction
        values = np.ones(len(users))

        return coo_matrix(
            (values, (users, items)), shape=(self._og_nuid, self._og_niid)
        )

    def get_side_sparse(self) -> Optional[csr_matrix]:
        """This method retrieves the sparse representation of side data.

        If the sparse structure has not been created yet, it is built first
        and cached.

        Returns:
            Optional[csr_matrix]: Sparse representation of the features (CSR Format),
                or None when no side data was provided.
        """
        if isinstance(self._inter_side_sparse, csr_matrix):
            return self._inter_side_sparse
        if self._inter_side is None:
            return None

        # Drop item label and convert to sparse via numpy
        side_features = self._inter_side.drop(self.item_label)
        side_np = side_features.to_numpy()

        self._inter_side_sparse = csr_matrix(side_np)
        return self._inter_side_sparse

    def get_side_tensor(self) -> Optional[Tensor]:
        """This method retrieves the tensor representation of side data.

        Returns:
            Optional[Tensor]: Tensor representation of the features if available,
                otherwise None.
        """
        return self._inter_side_tensor

    def get_interaction_dataloader(
        self,
        include_user_id: bool = False,
        batch_size: int = 1024,
        shuffle: bool = True,
        **kwargs: Any,
    ) -> DataLoader:
        """Create a PyTorch DataLoader that yields dense tensors of interaction batches.

        This method retrieves the sparse interaction matrix, wraps it in an
        ``InteractionDataset``, and then wraps that in a DataLoader. The batches are
        provided as dense tensors of shape [batch_size, num_items].

        Args:
            include_user_id (bool): Whether to include user IDs in the output.
            batch_size (int): The batch size to be used for the DataLoader.
            shuffle (bool): Whether to shuffle the data when loading.
            **kwargs (Any): The additional keyword arguments to pass the Dataloader.

        Returns:
            DataLoader: A DataLoader that yields batches of dense interaction tensors.
        """
        # Get the sparse matrix, which is memory-efficient.
        sparse_matrix = self.get_sparse()

        # Create the lazy dataset which just holds a reference to the sparse matrix.
        lazy_dataset = InteractionDataset(
            sparse_matrix, include_user_id=include_user_id
        )
        return DataLoader(
            lazy_dataset, batch_size=batch_size, shuffle=shuffle, **kwargs
        )

    def get_pointwise_dataloader(
        self,
        neg_samples: int = 0,
        include_side_info: bool = False,
        include_context: bool = False,
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
        **kwargs: Any,
    ) -> DataLoader:
        """Create a PyTorch DataLoader with implicit feedback and negative sampling.

        Args:
            neg_samples (int): Number of negative samples per user.
            include_side_info (bool): Whether to include side information features in the output.
            include_context (bool): Whether to include the context in the output.
            batch_size (int): The batch size that will be used by the DataLoader.
            shuffle (bool): Whether to shuffle the data.
            seed (int): Seed for the DataLoader's random generator, for reproducibility.
            **kwargs (Any): The additional keyword arguments to pass the Dataloader.

        Returns:
            DataLoader: Yields (user, item, rating) with negative samples or
                (user, item, rating, context) if flagged.
        """
        # Prepare side information if requested
        side_info_tensor = None
        if include_side_info and self._inter_side_tensor is not None:
            side_info_tensor = self._inter_side_tensor

        context_tensor = None
        if include_context and self.context_labels:
            # Contexts are per-row attributes: build users, items and contexts from
            # the same sorted frame so index i of each tensor is one interaction.
            # The CSR-derived indices are sorted and deduplicated, so they would
            # not line up with context rows taken in raw DataFrame order.
            pos_users, pos_items, context_tensor = self._get_mapped_contexts()
        else:
            pos_users, pos_items = self._get_mapped_indices()

        # Create the Dataset
        dataset = PointWiseDataset(
            user_ids=pos_users,
            item_ids=pos_items,
            sparse_matrix=self.get_sparse(),
            neg_samples=neg_samples,
            niid=self._niid,
            side_information=side_info_tensor,
            contexts=context_tensor,
        )

        # Set the generator for the Dataloader for reproducibility
        g = torch.Generator()
        g.manual_seed(seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            worker_init_fn=seed_worker,
            generator=g,
            **kwargs,
        )

    def get_contrastive_dataloader(
        self,
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
        **kwargs: Any,
    ) -> DataLoader:
        """Create a PyTorch DataLoader with triplets for implicit feedback.

        Args:
            batch_size (int): The batch size.
            shuffle (bool): Whether to shuffle the data.
            seed (int): Seed for reproducibility.
            **kwargs (Any): The additional keyword arguments to pass the Dataloader.

        Returns:
            DataLoader: Yields triplets of (user, positive_item, negative_item).
        """
        pos_users, pos_items = self._get_mapped_indices()

        # Create the Dataset
        dataset = ContrastiveDataset(
            user_ids=pos_users,
            item_ids=pos_items,
            sparse_matrix=self.get_sparse(),
            niid=self._niid,
        )

        # Set the generator for the Dataloader for reproducibility
        g = torch.Generator()
        g.manual_seed(seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            worker_init_fn=seed_worker,
            generator=g,
            **kwargs,
        )

    def get_positive_dataloader(
        self,
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
        **kwargs: Any,
    ) -> DataLoader:
        """Create a PyTorch DataLoader with only positive (user, item) pairs.

        Args:
            batch_size (int): The batch size.
            shuffle (bool): Whether to shuffle the data.
            seed (int): Seed for reproducibility.
            **kwargs (Any): The additional keyword arguments to pass the Dataloader.

        Returns:
            DataLoader: Yields pairs of (user, positive_item).
        """
        # Extract directly the positive interactions
        pos_users, pos_items = self._get_mapped_indices()

        # Create the Dataset
        dataset = PositiveDataset(
            user_ids=pos_users,
            item_ids=pos_items,
        )

        # Set the generator for the Dataloader for reproducibility
        g = torch.Generator()
        g.manual_seed(seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            worker_init_fn=seed_worker,
            generator=g,
            **kwargs,
        )

    # If you've reached this point in the code you deserve to see
    # Dris the cat! Please let him survive the AI-generated PRs.
    # Say hi to Dris!
    #
    #  /\_/\
    # ( o.o )
    #  > ^ <

    def get_history(self) -> Tuple[Tensor, Tensor, Tensor]:
        """Return the history representation as three Tensors.

        This method also checks if this representation has been already computed,
        if so then it just returns it without computing it again.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: A tuple containing:
                - Tensor: A matrix of dimension [num_user, max_chronology_length],
                    containing transaction information.
                - Tensor: An array of dimension [num_user], containing the
                    length of each chronology (before padding).
                - Tensor: A binary mask that identifies where the real
                    transaction information are, ignoring padding.
        """
        if (
            isinstance(self._history_matrix, Tensor)
            and isinstance(self._history_lens, Tensor)
            and isinstance(self._history_mask, Tensor)
        ):
            return self._history_matrix, self._history_lens, self._history_mask
        return self._to_history()

    def get_dims(self) -> Tuple[int, int]:
        """This method will return the dimensions of the data.

        Returns:
            Tuple[int, int]: A tuple containing:
                int: Number of unique users.
                int: Number of unique items.
        """
        return (self._nuid, self._niid)

    def get_transactions(self) -> int:
        """This method will return the number of transactions.

        Returns:
            int: Number of transactions.
        """
        return self._transactions

    def get_unique_ratings(self) -> np.ndarray:
        """Returns a sorted array of unique rating values present in the dataset.
        This is useful for models that operate on explicit feedback.

        Returns:
            np.ndarray: A sorted array of unique rating values (empty for
                non-explicit data).
        """
        if self.rating_type != RatingType.EXPLICIT or self.rating_label is None:
            return np.array([])

        return np.sort(
            self._inter_df.select(self.rating_label).unique().to_numpy().flatten()
        )

    def _to_sparse(self) -> csr_matrix:
        """This method will create the sparse representation of the data contained.

        This method must not be called if the sparse representation has already been defined.

        Returns:
            csr_matrix: Sparse representation of the transactions (CSR Format).
        """
        users, items, ratings, _ = self.get_flat()

        # Duplicate (user, item) pairs are summed by the COO -> CSR conversion
        self._inter_sparse = coo_matrix(
            (ratings, (users, items)), shape=(self._og_nuid, self._og_niid)
        ).tocsr()

        return self._inter_sparse

    def _to_history(self) -> Tuple[Tensor, Tensor, Tensor]:
        """Creates three Tensor which contains information of the
        transaction history for each user.

        Padding positions in the history matrix hold ``n_items`` (the number of
        columns of the sparse matrix).

        Returns:
            Tuple[Tensor, Tensor, Tensor]: A tuple containing:
                - Tensor: A matrix of dimension [num_user, max_chronology_length],
                    containing transaction information.
                - Tensor: An array of dimension [num_user], containing the
                    length of each chronology (before padding).
                - Tensor: A binary mask that identifies where the real
                    transaction information are, ignoring padding.
        """
        # Get sparse interaction matrix
        sparse_matrix = self.get_sparse()
        n_users, n_items = sparse_matrix.shape
        indptr = sparse_matrix.indptr
        indices = sparse_matrix.indices

        # Calculate lengths for each user
        lens = np.diff(indptr)
        max_history_len = int(lens.max()) if len(lens) > 0 else 0

        # Initialize matrices
        self._history_matrix = torch.full(
            (n_users, max_history_len), fill_value=n_items, dtype=torch.long
        )
        self._history_lens = torch.from_numpy(lens.astype(np.int64))
        self._history_mask = torch.zeros(n_users, max_history_len, dtype=torch.float)

        # Scatter every stored entry into (row, position-within-row) in one shot,
        # instead of looping over users in Python.
        nnz = int(indptr[-1]) if len(indptr) > 0 else 0
        if nnz > 0:
            rows = np.repeat(np.arange(n_users), lens).astype(np.int64)
            cols = (np.arange(nnz) - np.repeat(indptr[:-1], lens)).astype(np.int64)
            row_t = torch.from_numpy(rows)
            col_t = torch.from_numpy(cols)
            self._history_matrix[row_t, col_t] = torch.from_numpy(
                indices[:nnz].astype(np.int64)
            )
            self._history_mask[row_t, col_t] = 1.0

        return self._history_matrix, self._history_lens, self._history_mask

    def __len__(self) -> int:
        """This method calculates the length of the interactions.

        Length will be defined as the number of ratings.

        Returns:
            int: number of ratings present in the structure.
        """
        return self._transactions

__len__()

This method calculates the length of the interactions.

Length will be defined as the number of ratings.

Returns:

Name Type Description
int int

number of ratings present in the structure.

Source code in warprec/data/entities/interactions.py
def __len__(self) -> int:
    """This method calculates the length of the interactions.

    Length will be defined as the number of ratings.

    Returns:
        int: number of ratings present in the structure.
    """
    return self._transactions

get_contrastive_dataloader(batch_size=1024, shuffle=True, seed=42, **kwargs)

Create a PyTorch DataLoader with triplets for implicit feedback.

Parameters:

Name Type Description Default
batch_size int

The batch size.

1024
shuffle bool

Whether to shuffle the data.

True
seed int

Seed for reproducibility.

42
**kwargs Any

The additional keyword arguments to pass the Dataloader.

{}

Returns:

Name Type Description
DataLoader DataLoader

Yields triplets of (user, positive_item, negative_item).

Source code in warprec/data/entities/interactions.py
def get_contrastive_dataloader(
    self,
    batch_size: int = 1024,
    shuffle: bool = True,
    seed: int = 42,
    **kwargs: Any,
) -> DataLoader:
    """Build a DataLoader of (user, positive item, negative item) triplets.

    Args:
        batch_size (int): The batch size.
        shuffle (bool): Whether to shuffle the data.
        seed (int): Seed for reproducibility.
        **kwargs (Any): The additional keyword arguments to pass the Dataloader.

    Returns:
        DataLoader: Yields triplets of (user, positive_item, negative_item).
    """
    # A dedicated, seeded generator makes the shuffling order reproducible.
    generator = torch.Generator()
    generator.manual_seed(seed)

    users, positives = self._get_mapped_indices()
    triplets = ContrastiveDataset(
        user_ids=users,
        item_ids=positives,
        sparse_matrix=self.get_sparse(),
        niid=self._niid,
    )

    return DataLoader(
        triplets,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        generator=generator,
        **kwargs,
    )

get_df()

This method will return the raw data.

Returns:

Type Description
DataFrame[Any]

DataFrame[Any]: The raw data in tabular format.

Source code in warprec/data/entities/interactions.py
def get_df(self) -> DataFrame[Any]:
    """Hand back the stored interaction frame, untouched.

    Returns:
        DataFrame[Any]: The raw data in tabular format.
    """
    raw_frame = self._inter_df
    return raw_frame

get_dict()

This method will return the transaction information in dict format.

Returns:

Name Type Description
dict dict

The transaction information in the current representation {user ID: {item ID: rating}}.

Source code in warprec/data/entities/interactions.py
def get_dict(self) -> dict:
    """Return the interactions as a nested ``{user ID: {item ID: rating}}`` dict.

    The dict is built on the first call and cached on the instance. Later calls
    return the cached object. For implicit feedback every rating is ``1``. For
    any other rating type the dict is left empty.

    Returns:
        dict: The transaction information in the current
            representation {user ID: {item ID: rating}}.
    """
    if self._inter_dict is None:
        users = self._inter_df.select(self.user_label).to_numpy().flatten()
        items = self._inter_df.select(self.item_label).to_numpy().flatten()

        nested: dict = {}
        self._inter_dict = nested

        if self.rating_type == RatingType.EXPLICIT:
            ratings = self._inter_df.select(self.rating_label).to_numpy().flatten()
            for user, item, rating in zip(users, items, ratings):
                nested.setdefault(user, {})[item] = rating
        elif self.rating_type == RatingType.IMPLICIT:
            for user, item in zip(users, items):
                nested.setdefault(user, {})[item] = 1

    return self._inter_dict

get_dims()

This method will return the dimensions of the data.

Returns:

Type Description
Tuple[int, int]

Tuple[int, int]: A tuple containing: int: Number of unique users. int: Number of unique items.

Source code in warprec/data/entities/interactions.py
def get_dims(self) -> Tuple[int, int]:
    """Return the (users, items) dimensions of the data.

    Returns:
        Tuple[int, int]: A tuple containing:
            int: Number of unique users.
            int: Number of unique items.
    """
    n_users = self._nuid
    n_items = self._niid
    return n_users, n_items

get_flat()

Returns the flattened, aligned arrays of users, items, ratings, and timestamps.

This method caches the arrays after the first computation to ensure fast subsequent retrievals. The arrays are sorted by user and item indices.

Returns:

Type Description
Tuple[ndarray, ndarray, ndarray, Optional[ndarray]]

Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: - users array - items array - ratings array (or ones for implicit) - timestamps array (if timestamp_label is provided, else None)

Source code in warprec/data/entities/interactions.py
def get_flat(
    self,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Return aligned flat arrays of users, items, ratings and timestamps.

    The arrays are computed once and cached on the instance. Later calls return
    the same array objects. Rows are ordered by (user index, item index).

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
            - users array (mapped user indices)
            - items array (mapped item indices)
            - ratings array (float32 ones for implicit feedback)
            - timestamps array (None when timestamp_label is unset or the
              column is missing)
    """
    if self._flat_users is None:
        backend = nw.get_native_namespace(self._inter_df)

        # Lookup frames translating raw IDs into contiguous indices
        user_index_df = nw.from_dict(
            {
                self.user_label: list(self._umap.keys()),
                "__uidx__": list(self._umap.values()),
            },
            native_namespace=backend,
        )
        item_index_df = nw.from_dict(
            {
                self.item_label: list(self._imap.keys()),
                "__iidx__": list(self._imap.values()),
            },
            native_namespace=backend,
        )

        # Inner joins keep only rows whose user and item are both mapped;
        # the sort fixes a deterministic (user, item) order.
        ordered = (
            self._inter_df.join(user_index_df, on=self.user_label, how="inner")
            .join(item_index_df, on=self.item_label, how="inner")
            .sort(["__uidx__", "__iidx__"])
        )

        def _column(name: str) -> np.ndarray:
            # Pull one column out as a flat 1-D numpy array
            return ordered.select(name).to_numpy().flatten()

        self._flat_users = _column("__uidx__")
        self._flat_items = _column("__iidx__")
        self._flat_ratings = (
            _column(self.rating_label)
            if self.rating_type == RatingType.EXPLICIT
            else np.ones(len(self._flat_users), dtype=np.float32)
        )

        has_timestamps = (
            bool(self.timestamp_label) and self.timestamp_label in ordered.columns
        )
        self._flat_timestamps = (
            _column(self.timestamp_label) if has_timestamps else None
        )

    return (
        self._flat_users,
        self._flat_items,
        self._flat_ratings,
        self._flat_timestamps,
    )

get_history()

Return the history representation as three Tensors.

This method also checks if this representation has been already computed, if so then it just returns it without computing it again.

Returns:

Type Description
Tuple[Tensor, Tensor, Tensor]

Tuple[Tensor, Tensor, Tensor]: A tuple containing: - Tensor: A matrix of dimension [num_user, max_chronology_length], containing transaction information. - Tensor: An array of dimension [num_user], containing the length of each chronology (before padding). - Tensor: A binary mask that identifies where the real transaction information is, ignoring padding.

Source code in warprec/data/entities/interactions.py
def get_history(self) -> Tuple[Tensor, Tensor, Tensor]:
    """Return the history representation as three Tensors.

    If all three cached tensors are already present they are returned as-is.
    Otherwise the representation is built through ``self._to_history()``.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: A tuple containing:
            - Tensor: A matrix of dimension [num_user, max_chronology_length],
                containing transaction information.
            - Tensor: An array of dimension [num_user], containing the
                length of each chronology (before padding).
            - Tensor: A binary mask that identifies where the real
                transaction information is, ignoring padding.
    """
    cache_names = ("_history_matrix", "_history_lens", "_history_mask")
    # Lazy generator keeps the short-circuit order of the attribute checks
    already_built = all(
        isinstance(getattr(self, name), Tensor) for name in cache_names
    )
    if not already_built:
        return self._to_history()
    return self._history_matrix, self._history_lens, self._history_mask

get_interaction_dataloader(include_user_id=False, batch_size=1024, shuffle=True, **kwargs)

Create a PyTorch DataLoader that yields dense tensors of interaction batches.

This method retrieves the sparse interaction matrix, converts it into a PyTorch TensorDataset, and then wraps it in a DataLoader. The batches are provided as dense tensors of shape [batch_size, num_items].

Parameters:

Name Type Description Default
include_user_id bool

Whether to include user IDs in the output.

False
batch_size int

The batch size to be used for the DataLoader.

1024
shuffle bool

Whether to shuffle the data when loading.

True
**kwargs Any

Additional keyword arguments passed to the DataLoader.

{}

Returns:

Name Type Description
DataLoader DataLoader

A DataLoader that yields batches of dense interaction tensors.

Source code in warprec/data/entities/interactions.py
def get_interaction_dataloader(
    self,
    include_user_id: bool = False,
    batch_size: int = 1024,
    shuffle: bool = True,
    **kwargs: Any,
) -> DataLoader:
    """Create a PyTorch DataLoader that yields dense tensors of interaction batches.

    The dataset keeps only a reference to the sparse interaction matrix and
    densifies one row per sample. Batches therefore arrive as dense tensors of
    shape [batch_size, num_items].

    Args:
        include_user_id (bool): Whether to include user IDs in the output.
        batch_size (int): The batch size to be used for the DataLoader.
        shuffle (bool): Whether to shuffle the data when loading.
        **kwargs (Any): Additional keyword arguments passed to the DataLoader.

    Returns:
        DataLoader: A DataLoader that yields batches of dense interaction tensors.
    """
    row_dataset = InteractionDataset(
        self.get_sparse(), include_user_id=include_user_id
    )
    return DataLoader(row_dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

get_pointwise_dataloader(neg_samples=0, include_side_info=False, include_context=False, batch_size=1024, shuffle=True, seed=42, **kwargs)

Create a PyTorch DataLoader with implicit feedback and negative sampling.

Parameters:

Name Type Description Default
neg_samples int

Number of negative samples per positive interaction.

0
include_side_info bool

Whether to include side information features in the output.

False
include_context bool

Whether to include the context in the output.

False
batch_size int

The batch size that will be used to iterate over the interactions.

1024
shuffle bool

Whether to shuffle the data.

True
seed int

Seed for Numpy random number generator for reproducibility.

42
**kwargs Any

Additional keyword arguments passed to the DataLoader.

{}

Returns:

Name Type Description
DataLoader DataLoader

Yields (user, item, rating) with negative samples or (user, item, rating, context) if flagged.

Source code in warprec/data/entities/interactions.py
def get_pointwise_dataloader(
    self,
    neg_samples: int = 0,
    include_side_info: bool = False,
    include_context: bool = False,
    batch_size: int = 1024,
    shuffle: bool = True,
    seed: int = 42,
    **kwargs: Any,
) -> DataLoader:
    """Create a PyTorch DataLoader with implicit feedback and negative sampling.

    Args:
        neg_samples (int): Number of negative samples per positive interaction.
        include_side_info (bool): Whether to include side information features in the output.
        include_context (bool): Whether to include the context in the output.
        batch_size (int): The batch size that will be used to iterate over
            the interactions.
        shuffle (bool): Whether to shuffle the data.
        seed (int): Seed of the ``torch.Generator`` handed to the DataLoader,
            used for reproducibility.
        **kwargs (Any): Additional keyword arguments passed to the DataLoader.
            ``worker_init_fn`` and ``generator`` may be given here to override
            the defaults (``seed_worker`` and a generator seeded with ``seed``).

    Returns:
        DataLoader: Yields (user, item, rating) with negative samples, extended
            with side information and/or context when those are requested and
            available.
    """
    pos_users, pos_items = self._get_mapped_indices()

    # Prepare side information and context if requested
    side_info_tensor = None
    if include_side_info and self._inter_side_tensor is not None:
        side_info_tensor = self._inter_side_tensor

    context_tensor = None
    if include_context and self.context_labels:
        # NOTE(review): contexts are read in raw ``_inter_df`` row order and are
        # indexed by positive-interaction position in the dataset. Confirm that
        # ``_get_mapped_indices`` returns pairs in that same row order (i.e. it
        # does not sort or filter rows), otherwise contexts are misaligned.
        ctx_vals = self._inter_df.select(self.context_labels).to_numpy()
        context_tensor = torch.tensor(ctx_vals, dtype=torch.long)

    # Create the Dataset
    dataset = PointWiseDataset(
        user_ids=pos_users,
        item_ids=pos_items,
        sparse_matrix=self.get_sparse(),
        neg_samples=neg_samples,
        niid=self._niid,
        side_information=side_info_tensor,
        contexts=context_tensor,
    )

    # Set the generator for the Dataloader for reproducibility
    g = torch.Generator()
    g.manual_seed(seed)

    # Defaults go first so caller-supplied kwargs override them instead of
    # raising "got multiple values for keyword argument".
    loader_kwargs = {"worker_init_fn": seed_worker, "generator": g, **kwargs}

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        **loader_kwargs,
    )

get_positive_dataloader(batch_size=1024, shuffle=True, seed=42, **kwargs)

Create a PyTorch DataLoader with only positive (user, item) pairs.

Parameters:

Name Type Description Default
batch_size int

The batch size.

1024
shuffle bool

Whether to shuffle the data.

True
seed int

Seed for reproducibility.

42
**kwargs Any

Additional keyword arguments passed to the DataLoader.

{}

Returns:

Name Type Description
DataLoader DataLoader

Yields pairs of (user, positive_item).

Source code in warprec/data/entities/interactions.py
def get_positive_dataloader(
    self,
    batch_size: int = 1024,
    shuffle: bool = True,
    seed: int = 42,
    **kwargs: Any,
) -> DataLoader:
    """Create a PyTorch DataLoader with only positive (user, item) pairs.

    Args:
        batch_size (int): The batch size.
        shuffle (bool): Whether to shuffle the data.
        seed (int): Seed of the ``torch.Generator`` handed to the DataLoader,
            used for reproducibility.
        **kwargs (Any): Additional keyword arguments passed to the DataLoader.
            ``worker_init_fn`` and ``generator`` may be given here to override
            the defaults (``seed_worker`` and a generator seeded with ``seed``).

    Returns:
        DataLoader: Yields pairs of (user, positive_item).
    """
    # Extract directly the positive interactions
    pos_users, pos_items = self._get_mapped_indices()

    # Create the Dataset
    dataset = PositiveDataset(
        user_ids=pos_users,
        item_ids=pos_items,
    )

    # Set the generator for the Dataloader for reproducibility
    g = torch.Generator()
    g.manual_seed(seed)

    # Defaults go first so caller-supplied kwargs override them instead of
    # raising "got multiple values for keyword argument".
    loader_kwargs = {"worker_init_fn": seed_worker, "generator": g, **kwargs}

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        **loader_kwargs,
    )

get_side_sparse()

This method retrieves the sparse representation of side data.

This method also checks whether the sparse structure has already been created; if not, it creates it first.

Returns:

Name Type Description
csr_matrix csr_matrix

Sparse representation of the features (CSR Format).

Source code in warprec/data/entities/interactions.py
def get_side_sparse(self) -> Optional[csr_matrix]:
    """Return the side-information features as a CSR sparse matrix.

    The matrix is built on first use from the side-data frame, with the item
    label column dropped, and cached for later calls.

    Returns:
        Optional[csr_matrix]: Sparse representation of the features (CSR
            format), or None when no side data is available.
    """
    if isinstance(self._inter_side_sparse, csr_matrix):
        return self._inter_side_sparse
    if self._inter_side is None:
        return None

    # Drop the item label column; every remaining column becomes a feature.
    # NOTE(review): row order follows the side-data frame; confirm it matches
    # the item index order expected by consumers.
    side_np = self._inter_side.drop(self.item_label).to_numpy()

    self._inter_side_sparse = csr_matrix(side_np)
    return self._inter_side_sparse

get_side_tensor()

This method retrieves the tensor representation of side data.

Returns:

Name Type Description
Tensor Tensor

Tensor representation of the features if available.

Source code in warprec/data/entities/interactions.py
def get_side_tensor(self) -> Optional[Tensor]:
    """Return the tensor representation of the side data.

    Returns:
        Optional[Tensor]: Tensor representation of the features, or None when
            no side-information tensor is stored on the instance.
    """
    return self._inter_side_tensor

get_sparse()

This method retrieves the sparse representation of data.

This method also checks whether the sparse structure has already been created; if not, it creates it first.

Returns:

Name Type Description
csr_matrix csr_matrix

Sparse representation of the transactions (CSR Format).

Source code in warprec/data/entities/interactions.py
def get_sparse(self) -> csr_matrix:
    """Return the interactions as a CSR sparse matrix.

    A cached CSR matrix is returned directly. Otherwise the matrix is built by
    ``self._to_sparse()``.

    Returns:
        csr_matrix: Sparse representation of the transactions (CSR Format).
    """
    cached = self._inter_sparse
    if not isinstance(cached, csr_matrix):
        return self._to_sparse()
    return cached

get_sparse_by_rating(rating_value)

Returns a sparse matrix (COO format) containing only the interactions that match a specific rating value.

Parameters:

Name Type Description Default
rating_value float

The rating value to filter by.

required

Returns:

Name Type Description
coo_matrix coo_matrix

A sparse matrix of shape [num_users, num_items] for the given rating.

Raises:

Type Description
ValueError

If interactions are not explicit or if rating label is None.

Source code in warprec/data/entities/interactions.py
def get_sparse_by_rating(self, rating_value: float) -> coo_matrix:
    """Return a COO matrix marking only the interactions with a given rating.

    Every matching (user, item) pair contributes a value of 1.0. Pairs whose
    user or item is absent from the ID mappings are dropped.

    Args:
        rating_value (float): The rating value to filter by.

    Returns:
        coo_matrix: A sparse matrix of shape [num_users, num_items] for the given rating.

    Raises:
        ValueError: If interactions are not explicit or if
            rating label is None.
    """
    if self.rating_type != RatingType.EXPLICIT or self.rating_label is None:
        raise ValueError(
            "Filtering by rating is only supported for explicit feedback data."
        )

    matching = self._inter_df.filter(nw.col(self.rating_label) == rating_value)
    if matching.select(nw.len()).item() == 0:
        # Nothing carries this rating: all-zero matrix with the original shape
        return coo_matrix((self._og_nuid, self._og_niid))

    backend = nw.get_native_namespace(matching)
    user_index_df = nw.from_dict(
        {
            self.user_label: list(self._umap.keys()),
            "__uidx__": list(self._umap.values()),
        },
        native_namespace=backend,
    )
    item_index_df = nw.from_dict(
        {
            self.item_label: list(self._imap.keys()),
            "__iidx__": list(self._imap.values()),
        },
        native_namespace=backend,
    )

    # Map raw IDs to indices, then sort for a reproducible entry order
    indexed = (
        matching.join(user_index_df, on=self.user_label, how="inner")
        .join(item_index_df, on=self.item_label, how="inner")
        .sort(["__uidx__", "__iidx__"])
    )
    rows = indexed.select("__uidx__").to_numpy().flatten()
    cols = indexed.select("__iidx__").to_numpy().flatten()

    return coo_matrix(
        (np.ones(len(rows)), (rows, cols)), shape=(self._og_nuid, self._og_niid)
    )

get_transactions()

This method will return the number of transactions.

Returns:

Name Type Description
int int

Number of transactions.

Source code in warprec/data/entities/interactions.py
def get_transactions(self) -> int:
    """Return the stored number of transactions.

    Returns:
        int: Number of transactions.
    """
    count = self._transactions
    return count

get_unique_ratings()

Returns a sorted array of unique rating values present in the dataset. This is useful for models that operate on explicit feedback.

Returns:

Type Description
ndarray

np.ndarray: A sorted array of unique rating values.

Source code in warprec/data/entities/interactions.py
def get_unique_ratings(self) -> np.ndarray:
    """Return the distinct rating values in the data, sorted ascending.

    This is useful for models that operate on explicit feedback. Non-explicit
    data, or data without a rating label, yields an empty array.

    Returns:
        np.ndarray: A sorted array of unique rating values.
    """
    if self.rating_type == RatingType.EXPLICIT and self.rating_label is not None:
        distinct = self._inter_df.select(self.rating_label).unique()
        return np.sort(distinct.to_numpy().flatten())
    return np.array([])

warprec.data.entities.train_structures.interaction_structures.InteractionDataset

Bases: Dataset

A PyTorch Dataset that serves rows from a sparse matrix on-the-fly.

This avoids the massive memory allocation required by sparse_matrix.todense().

Parameters:

Name Type Description Default
sparse_matrix csr_matrix

The user-item interaction matrix in CSR format.

required
include_user_id bool

If True, also returns the index of the user.

False
Source code in warprec/data/entities/train_structures/interaction_structures.py
class InteractionDataset(Dataset):
    """A PyTorch Dataset that serves one dense row of a sparse matrix per index.

    Only the requested row is densified, so the full matrix never has to be
    materialized with `sparse_matrix.todense()`.

    Args:
        sparse_matrix (csr_matrix): The user-item interaction matrix in CSR format.
        include_user_id (bool): If True, also returns the index of the user.
    """

    def __init__(self, sparse_matrix: csr_matrix, include_user_id: bool = False):
        self.sparse_matrix = sparse_matrix
        self.include_user_id = include_user_id

    def __len__(self) -> int:
        num_rows = self.sparse_matrix.shape[0]
        return num_rows

    def __getitem__(self, idx: int) -> Tuple[Tensor, ...]:
        # Row slicing is cheap on CSR; densify just this row ([1, N]) and
        # drop the leading dimension to get a [N] float32 tensor.
        dense_row = self.sparse_matrix[idx].todense()
        row = torch.from_numpy(dense_row).to(dtype=torch.float32).squeeze(0)

        if not self.include_user_id:
            return (row,)
        return torch.tensor(idx, dtype=torch.long), row

warprec.data.entities.train_structures.interaction_structures.PointWiseDataset

Bases: Dataset

A PyTorch Dataset for (user, item, rating) triplets that generates samples on-the-fly.

It calculates the total number of samples (positives + negatives) and maps any given index idx to either a positive interaction (rating=1.0) or a newly sampled negative interaction (rating=0.0).

Parameters:

Name Type Description Default
user_ids Tensor

The Torch tensor of user ids aligned with the items.

required
item_ids Tensor

The Torch tensor of item ids aligned with the users.

required
sparse_matrix csr_matrix

The user-item interaction matrix in CSR format.

required
neg_samples int

The number of negative samples to generate for each positive interaction.

required
niid int

The total number of unique items for negative sampling.

required
side_information Optional[Tensor]

The tensor containing the side information of each interaction.

None
contexts Optional[Tensor]

The tensor containing the context information of each interaction.

None
Source code in warprec/data/entities/train_structures/interaction_structures.py
class PointWiseDataset(Dataset):
    """A PyTorch Dataset for (user, item, rating) triplets that generates samples on-the-fly.

    Each positive interaction owns ``1 + neg_samples`` consecutive indices. The
    first maps to the positive interaction (rating=1.0). The remaining ones map
    to freshly sampled negative items (rating=0.0) that the user has not
    interacted with.

    Args:
        user_ids (Tensor): The Torch tensor of user ids aligned with the items.
        item_ids (Tensor): The Torch tensor of item ids aligned with the users.
        sparse_matrix (csr_matrix): The user-item interaction matrix in CSR format.
        neg_samples (int): The number of negative samples to generate for each
            positive interaction.
        niid (int): The total number of unique items for negative sampling.
        side_information (Optional[Tensor]): The tensor containing the side
            information; it is indexed by the returned item id.
        contexts (Optional[Tensor]): The tensor containing the context
            information; it is indexed by the positive interaction position.
    """

    def __init__(
        self,
        user_ids: Tensor,
        item_ids: Tensor,
        sparse_matrix: csr_matrix,
        neg_samples: int,
        niid: int,
        side_information: Optional[Tensor] = None,
        contexts: Optional[Tensor] = None,
    ):
        # Keep a copy of positive values
        self.user_ids = user_ids
        self.item_ids = item_ids

        # CSR matrix for faster lookup
        self.sparse_matrix = sparse_matrix

        self.neg_samples = neg_samples
        self.niid = niid
        self.side_information = side_information
        self.contexts = contexts

        self.num_positives = len(self.user_ids)
        self.total_samples = self.num_positives * (1 + self.neg_samples)

    def __len__(self) -> int:
        return self.total_samples

    def _sample_negative(self, user_idx: int) -> int:
        """Draw a random item index in ``[0, niid)`` the user has not interacted with.

        Uses rejection sampling with NumPy's global RNG. After ``niid``
        rejections an exact check runs once; it raises instead of looping
        forever when no negative item exists. The check draws no random
        numbers, so successful samples are unchanged.

        Raises:
            ValueError: If the user has interacted with every item in ``[0, niid)``.
        """
        start = self.sparse_matrix.indptr[user_idx]
        end = self.sparse_matrix.indptr[user_idx + 1]
        seen_items = self.sparse_matrix.indices[start:end]
        n_seen = len(seen_items)

        rejections = 0
        while True:
            # NumPy random is generally faster than torch.randint for single scalars
            candidate = np.random.randint(0, self.niid)

            # Membership test via binary search; assumes the row's column
            # indices are sorted (canonical CSR).
            idx_ins = np.searchsorted(seen_items, candidate)
            if idx_ins < n_seen and seen_items[idx_ins] == candidate:
                rejections += 1
                if (
                    rejections == self.niid
                    and np.isin(np.arange(self.niid), seen_items).all()
                ):
                    raise ValueError(
                        f"User {user_idx} has interacted with all {self.niid} "
                        "items; no negative item can be sampled."
                    )
                continue
            return candidate

    def __getitem__(self, idx: int) -> Tuple[Tensor, ...]:
        """Return (user, item, rating), plus side information and/or context when set.

        Raises:
            ValueError: If a negative is requested for a user who has
                interacted with every item.
        """
        # Linear mapping
        # idx 0 -> (pos_idx=0, offset=0) -> Positive
        # idx 1 -> (pos_idx=0, offset=1) -> Negative
        pos_interaction_idx = idx // (1 + self.neg_samples)
        sample_offset = idx % (1 + self.neg_samples)

        user_tensor = self.user_ids[pos_interaction_idx]

        if sample_offset == 0:
            item_tensor = self.item_ids[pos_interaction_idx]
            rating_val = 1.0
        else:
            rating_val = 0.0
            # user_tensor is a 0-d tensor, we need its integer value for indexing
            negative = self._sample_negative(user_tensor.item())
            item_tensor = torch.tensor(negative, dtype=torch.long)

        # Rating
        rating_tensor = torch.tensor(rating_val, dtype=torch.float)

        # Explicitly type the list to avoid mypy errors
        ret: List[Tensor] = [user_tensor, item_tensor, rating_tensor]

        # Side Info
        if self.side_information is not None:
            ret.append(self.side_information[item_tensor])

        # Context
        if self.contexts is not None:
            ret.append(self.contexts[pos_interaction_idx])

        return tuple(ret)

warprec.data.entities.train_structures.interaction_structures.ContrastiveDataset

Bases: Dataset

A PyTorch Dataset for (user, positive_item, negative_item) triplets.

Generates negative samples on-the-fly using the sparse interaction matrix.

Parameters:

Name Type Description Default
user_ids Tensor

Tensor of user indices for positive interactions.

required
item_ids Tensor

Tensor of item indices for positive interactions.

required
sparse_matrix csr_matrix

The user-item interaction matrix in CSR format.

required
niid int

Total number of items available.

required
Source code in warprec/data/entities/train_structures/interaction_structures.py
class ContrastiveDataset(Dataset):
    """A PyTorch Dataset for (user, positive_item, negative_item) triplets.

    Generates negative samples on-the-fly using the sparse interaction matrix.

    Args:
        user_ids (Tensor): Tensor of user indices for positive interactions.
        item_ids (Tensor): Tensor of item indices for positive interactions.
        sparse_matrix (csr_matrix): The user-item interaction matrix in CSR format.
        niid (int): Total number of items available.
    """

    def __init__(
        self,
        user_ids: Tensor,
        item_ids: Tensor,
        sparse_matrix: csr_matrix,
        niid: int,
    ):
        self.user_ids = user_ids
        self.item_ids = item_ids
        self.sparse_matrix = sparse_matrix
        self.niid = niid

    def __len__(self) -> int:
        return len(self.user_ids)

    def _sample_negative(self, user_idx: int) -> int:
        """Draw a random item index in ``[0, niid)`` the user has not interacted with.

        Uses rejection sampling with NumPy's global RNG. After ``niid``
        rejections an exact check runs once; it raises instead of looping
        forever when no negative item exists. The check draws no random
        numbers, so successful samples are unchanged.

        Raises:
            ValueError: If the user has interacted with every item in ``[0, niid)``.
        """
        # Retrieve seen items for this user using CSR slicing
        start = self.sparse_matrix.indptr[user_idx]
        end = self.sparse_matrix.indptr[user_idx + 1]
        seen_items = self.sparse_matrix.indices[start:end]
        n_seen = len(seen_items)

        rejections = 0
        while True:
            # Sample a random item
            candidate = np.random.randint(0, self.niid)

            # Check if it is a true negative; binary search assumes the row's
            # column indices are sorted (canonical CSR).
            idx_ins = np.searchsorted(seen_items, candidate)
            if idx_ins < n_seen and seen_items[idx_ins] == candidate:
                rejections += 1
                if (
                    rejections == self.niid
                    and np.isin(np.arange(self.niid), seen_items).all()
                ):
                    raise ValueError(
                        f"User {user_idx} has interacted with all {self.niid} "
                        "items; no negative item can be sampled."
                    )
                continue
            return candidate

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Return (user, positive_item, negative_item) for positive interaction ``idx``.

        Raises:
            ValueError: If the user has interacted with every item.
        """
        # Retrieve positive interaction
        user_tensor = self.user_ids[idx]
        pos_item_tensor = self.item_ids[idx]

        # Negative Sampling
        negative = self._sample_negative(user_tensor.item())
        neg_item_tensor = torch.tensor(negative, dtype=torch.long)

        # Return triplet (user, pos, neg)
        return user_tensor, pos_item_tensor, neg_item_tensor

warprec.data.entities.sessions.Sessions

Handles session-based data preparation for sequential recommenders. Transforms user-item interactions into padded sequences or sliding windows.

Source code in warprec/data/entities/sessions.py
class Sessions:
    """
    Handles session-based data preparation for sequential recommenders.
    Transforms user-item interactions into padded sequences or sliding windows.

    Interactions are mapped to contiguous indices, sorted by user (and by
    timestamp when that column exists) and flattened into two aligned arrays,
    ``_flat_users`` and ``_flat_items`` ("the tape"). ``_user_offsets`` has
    ``n_users + 1`` entries so that user ``u``'s history is
    ``_flat_items[_user_offsets[u]:_user_offsets[u + 1]]``; that slice is empty
    for users without interactions.

    Args:
        data (DataFrame[Any]): Raw interaction data.
        user_mapping (dict): Mapping of original user ID -> user index.
        item_mapping (dict): Mapping of original item ID -> item index.
        sparse_matrix (csr_matrix): User-item matrix forwarded to the datasets
            built by the dataloader factories.
        user_id_label (str): Name of the user column.
        item_id_label (str): Name of the item column.
        timestamp_label (str): Name of the timestamp column; only used (for
            ordering) when it exists in ``data``.
        context_labels (Optional[List[str]]): Context columns kept (cast to
            Int64) when they exist in ``data``.

    Raises:
        ValueError: If the user or item column is missing from ``data``.
    """

    def __init__(
        self,
        data: DataFrame[Any],
        user_mapping: dict,
        item_mapping: dict,
        sparse_matrix: csr_matrix,
        user_id_label: str = "user_id",
        item_id_label: str = "item_id",
        timestamp_label: str = "timestamp",
        context_labels: Optional[List[str]] = None,
    ):
        # Fail fast on missing key columns before any processing.
        if user_id_label not in data.columns:
            raise ValueError(f"User column '{user_id_label}' not found.")
        if item_id_label not in data.columns:
            raise ValueError(f"Item column '{item_id_label}' not found.")

        # Configuration
        self._inter_df = data
        self._umap = user_mapping
        self._imap = item_mapping
        self.user_label = user_id_label
        self.item_label = item_id_label
        self.timestamp_label = timestamp_label
        self.context_labels = context_labels or []

        # Dimensions & caches
        self._niid = len(self._imap)
        self._nuid = len(self._umap)
        self._cached_user_histories: Dict[int, List[int]] = {}
        self._processed_df: Optional[DataFrame[Any]] = None

        # Tape structures, filled by _build_flat_structures()
        self._flat_items: Optional[np.ndarray] = None
        self._flat_users: Optional[np.ndarray] = None
        self._user_offsets: Optional[np.ndarray] = None
        # Next-item target positions, computed lazily on first use.
        self._valid_sample_indices: Optional[np.ndarray] = None
        self._inter_sparse = sparse_matrix

        self._build_flat_structures()

    def _get_processed_data(self) -> DataFrame[Any]:
        """Map IDs to indices, drop unmapped rows and sort (cached).

        The inner joins drop interactions whose user or item has no mapping.
        The result is sorted by user index and, when the timestamp column
        exists, by timestamp.

        Returns:
            DataFrame[Any]: Narwhals frame whose user/item columns hold Int64
            indices, plus the timestamp and context columns present in the input.
        """
        if self._processed_df is not None:
            return self._processed_df

        columns = self._inter_df.columns
        has_timestamp = self.timestamp_label in columns
        native_ns = nw.get_native_namespace(self._inter_df)

        # Lookup tables translating raw IDs into contiguous indices.
        umap_df = nw.from_dict(
            {
                self.user_label: list(self._umap.keys()),
                "__uidx__": list(self._umap.values()),
            },
            native_namespace=native_ns,
        )
        imap_df = nw.from_dict(
            {
                self.item_label: list(self._imap.keys()),
                "__iidx__": list(self._imap.values()),
            },
            native_namespace=native_ns,
        )

        selected = [
            nw.col("__uidx__").alias(self.user_label).cast(nw.Int64),
            nw.col("__iidx__").alias(self.item_label).cast(nw.Int64),
        ]
        if has_timestamp:
            selected.append(nw.col(self.timestamp_label))
        selected.extend(
            nw.col(c).cast(nw.Int64) for c in self.context_labels if c in columns
        )

        mapped_df = (
            self._inter_df.join(umap_df, on=self.user_label, how="inner")
            .join(imap_df, on=self.item_label, how="inner")
            .select(selected)
        )

        sort_cols = [self.user_label]
        if has_timestamp:
            sort_cols.append(self.timestamp_label)

        self._processed_df = mapped_df.sort(sort_cols)
        return self._processed_df

    def _build_flat_structures(self) -> None:
        """Materialize the tape and per-user offsets for O(1) history access.

        Raises:
            ValueError: If a mapped user index is outside ``[0, n_users)``.
        """
        df = self._get_processed_data()

        self._flat_users = (
            df.select(self.user_label).to_numpy().flatten().astype(np.int64, copy=False)
        )
        self._flat_items = (
            df.select(self.item_label).to_numpy().flatten().astype(np.int64, copy=False)
        )

        # Rows are grouped by user, so the prefix sum of per-user counts gives
        # every user's start offset, including users without interactions.
        counts = np.bincount(self._flat_users, minlength=self._nuid)
        if len(counts) > self._nuid:
            raise ValueError(
                f"User index {len(counts) - 1} is out of range for "
                f"{self._nuid} mapped users."
            )
        self._user_offsets = np.zeros(self._nuid + 1, dtype=np.int64)
        self._user_offsets[1:] = np.cumsum(counts)

    def _ensure_valid_sample_indices(self) -> np.ndarray:
        """Return (and cache) tape positions usable as next-item targets.

        A position is a valid target when the previous tape position belongs
        to the same user, i.e. the target has at least one item of history.

        Raises:
            ValueError: If no user has at least two interactions.
        """
        if self._valid_sample_indices is None:
            users = self._flat_users
            has_history = np.zeros(len(users), dtype=bool)
            has_history[1:] = users[1:] == users[:-1]
            self._valid_sample_indices = np.flatnonzero(has_history)

        if len(self._valid_sample_indices) == 0:
            raise ValueError(
                "No valid sequences found (min 2 interactions per user needed)."
            )
        return self._valid_sample_indices

    def _sliding_window_index(
        self, max_seq_len: int, stride: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(window_starts, window_users)`` as int64 arrays.

        Each user with at least two interactions gets windows starting at
        ``user_start + k * stride`` for
        ``k = 0 .. max(len - max_seq_len, 0) // stride``.

        Raises:
            ValueError: If no user has at least two interactions.
        """
        user_lens = np.diff(self._user_offsets)
        valid_users = np.flatnonzero(user_lens >= 2)
        if len(valid_users) == 0:
            raise ValueError("No valid sliding windows found.")

        valid_lens = user_lens[valid_users]
        valid_starts = self._user_offsets[valid_users]

        num_windows = (
            np.maximum(valid_lens - max_seq_len, 0) // stride
        ).astype(np.int64) + 1

        # Position of each window among its own user's windows (0, 1, 2, ...).
        first_window = np.cumsum(num_windows) - num_windows
        local_window_idx = np.arange(int(num_windows.sum())) - np.repeat(
            first_window, num_windows
        )

        window_starts = (
            np.repeat(valid_starts, num_windows) + local_window_idx * stride
        )
        window_users = np.repeat(valid_users, num_windows)
        return window_starts.astype(np.int64), window_users.astype(np.int64)

    @staticmethod
    def _make_loader(
        dataset: Any, batch_size: int, shuffle: bool, seed: int, **kwargs: Any
    ) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader driven by a generator seeded with ``seed``."""
        generator = torch.Generator()
        generator.manual_seed(seed)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            worker_init_fn=seed_worker,
            generator=generator,
            **kwargs,
        )

    def get_user_history_sequences(
        self, user_ids: List[int], max_seq_len: int
    ) -> Tuple[Tensor, Tensor]:
        """Retrieves padded historical sequences for inference/evaluation.

        Args:
            user_ids (List[int]): User indices. Users without interactions get
                an empty history.
            max_seq_len (int): Number of most recent items kept per user;
                ``0`` keeps none.

        Returns:
            Tuple[Tensor, Tensor]: ``(sequences, lengths)``. ``sequences`` is
            right-padded with ``n_items`` to the longest kept history in the
            batch; ``lengths`` holds each row's unpadded length.

        Raises:
            ValueError: If ``max_seq_len`` is negative.
        """
        if max_seq_len < 0:
            raise ValueError(f"max_seq_len must be non-negative, got {max_seq_len}.")

        if not self._cached_user_histories:
            starts = self._user_offsets[:-1]
            ends = self._user_offsets[1:]
            self._cached_user_histories = {
                int(u): self._flat_items[starts[u] : ends[u]].tolist()
                for u in np.flatnonzero(ends > starts)
            }

        seqs, lens = [], []
        for uid in user_ids:
            hist = self._cached_user_histories.get(int(uid), [])
            # Explicit start index: ``hist[-0:]`` would return the whole list.
            recent = hist[max(len(hist) - max_seq_len, 0) :]
            seqs.append(torch.tensor(recent, dtype=torch.long))
            lens.append(len(recent))

        if not seqs:
            # pad_sequence rejects an empty list.
            return (
                torch.empty((0, 0), dtype=torch.long),
                torch.empty(0, dtype=torch.long),
            )

        return (
            pad_sequence(seqs, batch_first=True, padding_value=self._niid),
            torch.tensor(lens, dtype=torch.long),
        )

    def get_sequential_dataloader(
        self,
        max_seq_len: int,
        neg_samples: int = 0,
        include_user_id: bool = False,
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
        **kwargs: Any,
    ) -> DataLoader:
        """Standard SASRec/RNN style dataloader (History -> Next Item).

        Args:
            max_seq_len (int): Maximum history length per sample.
            neg_samples (int): Negatives sampled per target (``0`` disables).
            include_user_id (bool): Prepend the user index to every sample.
            batch_size (int): Batch size.
            shuffle (bool): Whether the DataLoader shuffles samples.
            seed (int): Seed of the DataLoader's ``torch.Generator``.
            **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

        Returns:
            DataLoader: Loader over a ``SequentialDataset``.

        Raises:
            ValueError: If no user has at least two interactions.
        """
        dataset = SequentialDataset(
            flat_items=self._flat_items,
            flat_users=self._flat_users,
            user_offsets=self._user_offsets,
            valid_target_indices=self._ensure_valid_sample_indices(),
            sparse_matrix=self._inter_sparse,
            max_seq_len=max_seq_len,
            neg_samples=neg_samples,
            niid=self._niid,
            include_user_id=include_user_id,
        )
        return self._make_loader(dataset, batch_size, shuffle, seed, **kwargs)

    def get_same_target_sequential_dataloader(
        self,
        max_seq_len: int,
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
        low_memory: bool = False,
        **kwargs: Any,
    ) -> DataLoader:
        """Sequential dataloader that also samples a same-target positive sequence.

        Args:
            max_seq_len (int): Maximum history length per sample.
            batch_size (int): Batch size.
            shuffle (bool): Whether the DataLoader shuffles samples.
            seed (int): Seed of the DataLoader's ``torch.Generator``.
            low_memory (bool): Currently unused by this method.
            **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

        Returns:
            DataLoader: Loader over a ``SameTargetSequentialDataset``.

        Raises:
            ValueError: If no user has at least two interactions.
        """
        dataset = SameTargetSequentialDataset(
            flat_items=self._flat_items,
            flat_users=self._flat_users,
            user_offsets=self._user_offsets,
            valid_target_indices=self._ensure_valid_sample_indices(),
            max_seq_len=max_seq_len,
            niid=self._niid,
        )
        return self._make_loader(dataset, batch_size, shuffle, seed, **kwargs)

    def get_sliding_window_dataloader(
        self,
        max_seq_len: int,
        neg_samples: int,
        stride: int = 1,
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
        **kwargs: Any,
    ) -> DataLoader:
        """Sequence-to-Sequence dataloader (Sliding Windows).

        Args:
            max_seq_len (int): Window length.
            neg_samples (int): Negatives per time step (``0`` disables).
            stride (int): Distance between consecutive window starts; >= 1.
            batch_size (int): Batch size.
            shuffle (bool): Whether the DataLoader shuffles samples.
            seed (int): Seed of the DataLoader's ``torch.Generator``.
            **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

        Returns:
            DataLoader: Loader over a ``SlidingWindowDataset``.

        Raises:
            ValueError: If ``stride`` is not positive or no user has at least
                two interactions.
        """
        if stride < 1:
            raise ValueError(f"stride must be a positive integer, got {stride}.")

        window_starts, window_users = self._sliding_window_index(max_seq_len, stride)

        # NOTE(review): SlidingWindowDataset clips windows with
        # sparse_matrix.indptr, which matches this tape's offsets only when the
        # CSR rows hold exactly the tape's entries (no merged duplicates).
        dataset = SlidingWindowDataset(
            flat_items=self._flat_items,
            window_starts=window_starts,
            window_users=window_users,
            sparse_matrix=self._inter_sparse,
            max_seq_len=max_seq_len,
            neg_samples=neg_samples,
            niid=self._niid,
        )
        return self._make_loader(dataset, batch_size, shuffle, seed, **kwargs)

    def get_cloze_mask_dataloader(
        self,
        max_seq_len: int,
        mask_prob: float,
        mask_token_id: int,
        neg_samples: int,
        batch_size: int = 1024,
        shuffle: bool = True,
        seed: int = 42,
        **kwargs: Any,
    ) -> DataLoader:
        """BERT4Rec style dataloader (Masked Language Modeling).

        Each user with at least two interactions yields one sample: the last
        ``max_seq_len`` items of their history.

        Args:
            max_seq_len (int): Window length.
            mask_prob (float): Masking probability forwarded to ClozeDataset.
            mask_token_id (int): Token written at masked positions.
            neg_samples (int): Negatives per masked position.
            batch_size (int): Batch size.
            shuffle (bool): Whether the DataLoader shuffles samples.
            seed (int): Seed of the DataLoader generator and of ClozeDataset.
            **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

        Returns:
            DataLoader: Loader over a ``ClozeDataset``.

        Raises:
            ValueError: If no user has at least two interactions.
        """
        user_lens = np.diff(self._user_offsets)
        valid_users = np.flatnonzero(user_lens >= 2)
        if len(valid_users) == 0:
            raise ValueError("No valid users with >= 2 interactions found.")

        user_ends = self._user_offsets[valid_users + 1]
        # Keep only the most recent max_seq_len items of each user.
        window_starts = np.maximum(
            self._user_offsets[valid_users], user_ends - max_seq_len
        )

        dataset = ClozeDataset(
            flat_items=self._flat_items,
            window_starts=window_starts.astype(np.int64),
            window_ends=user_ends.astype(np.int64),
            window_users=valid_users.astype(np.int64),
            sparse_matrix=self._inter_sparse,
            max_seq_len=max_seq_len,
            mask_prob=mask_prob,
            mask_token_id=mask_token_id,
            neg_samples=neg_samples,
            niid=self._niid,
            seed=seed,
        )
        return self._make_loader(dataset, batch_size, shuffle, seed, **kwargs)

get_cloze_mask_dataloader(max_seq_len, mask_prob, mask_token_id, neg_samples, batch_size=1024, shuffle=True, seed=42, **kwargs)

BERT4Rec style dataloader (Masked Language Modeling).

Source code in warprec/data/entities/sessions.py
def get_cloze_mask_dataloader(
    self,
    max_seq_len: int,
    mask_prob: float,
    mask_token_id: int,
    neg_samples: int,
    batch_size: int = 1024,
    shuffle: bool = True,
    seed: int = 42,
    **kwargs: Any,
) -> DataLoader:
    """BERT4Rec style dataloader (Masked Language Modeling).

    Every user with at least two interactions contributes one window: the
    last ``max_seq_len`` items of their history on the tape.

    Args:
        max_seq_len (int): Window length.
        mask_prob (float): Masking probability forwarded to ClozeDataset.
        mask_token_id (int): Token written at masked positions.
        neg_samples (int): Negatives per masked position.
        batch_size (int): Batch size.
        shuffle (bool): Whether the DataLoader shuffles samples.
        seed (int): Seed of the DataLoader generator and of ClozeDataset.
        **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

    Returns:
        DataLoader: Loader over a ``ClozeDataset``.

    Raises:
        ValueError: If no user has at least two interactions.
    """
    offsets = self._user_offsets
    eligible = np.flatnonzero(np.diff(offsets) >= 2)
    if eligible.size == 0:
        raise ValueError("No valid users with >= 2 interactions found.")

    ends = offsets[eligible + 1]
    # The window never starts before the user's first item.
    starts = np.maximum(offsets[eligible], ends - max_seq_len)

    cloze_ds = ClozeDataset(
        flat_items=self._flat_items,
        window_starts=starts.astype(np.int64),
        window_ends=ends.astype(np.int64),
        window_users=eligible.astype(np.int64),
        sparse_matrix=self._inter_sparse,
        max_seq_len=max_seq_len,
        mask_prob=mask_prob,
        mask_token_id=mask_token_id,
        neg_samples=neg_samples,
        niid=self._niid,
        seed=seed,
    )

    shuffle_gen = torch.Generator()
    shuffle_gen.manual_seed(seed)
    return DataLoader(
        cloze_ds,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        generator=shuffle_gen,
        **kwargs,
    )

get_same_target_sequential_dataloader(max_seq_len, batch_size=1024, shuffle=True, seed=42, low_memory=False, **kwargs)

Sequential dataloader that also samples a same-target positive sequence.

Source code in warprec/data/entities/sessions.py
def get_same_target_sequential_dataloader(
    self,
    max_seq_len: int,
    batch_size: int = 1024,
    shuffle: bool = True,
    seed: int = 42,
    low_memory: bool = False,
    **kwargs: Any,
) -> DataLoader:
    """Sequential dataloader that also samples a same-target positive sequence.

    Target positions are computed once and cached on
    ``self._valid_sample_indices``. They cover every tape position except the
    first position of each user.

    Args:
        max_seq_len (int): Maximum history length per sample.
        batch_size (int): Batch size.
        shuffle (bool): Whether the DataLoader shuffles samples.
        seed (int): Seed of the DataLoader's ``torch.Generator``.
        low_memory (bool): Not read by this method.
        **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

    Returns:
        DataLoader: Loader over a ``SameTargetSequentialDataset``.

    Raises:
        ValueError: If there is no valid target position.
    """
    if self._valid_sample_indices is None:
        n_flat = len(self._flat_items)
        starts = self._user_offsets[:-1]
        is_target = np.ones(n_flat, dtype=bool)
        # A user's first item has no history, so it is never a target; starts
        # at or past the end of the tape are skipped.
        is_target[starts[starts < n_flat]] = False
        self._valid_sample_indices = np.flatnonzero(is_target)

    targets = self._valid_sample_indices
    if len(targets) == 0:
        raise ValueError(
            "No valid sequences found (min 2 interactions per user needed)."
        )

    same_target_ds = SameTargetSequentialDataset(
        flat_items=self._flat_items,
        flat_users=self._flat_users,
        user_offsets=self._user_offsets,
        valid_target_indices=targets,
        max_seq_len=max_seq_len,
        niid=self._niid,
    )

    shuffle_gen = torch.Generator()
    shuffle_gen.manual_seed(seed)
    return DataLoader(
        same_target_ds,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        generator=shuffle_gen,
        **kwargs,
    )

get_sequential_dataloader(max_seq_len, neg_samples=0, include_user_id=False, batch_size=1024, shuffle=True, seed=42, **kwargs)

Standard SASRec/RNN style dataloader (History -> Next Item).

Source code in warprec/data/entities/sessions.py
def get_sequential_dataloader(
    self,
    max_seq_len: int,
    neg_samples: int = 0,
    include_user_id: bool = False,
    batch_size: int = 1024,
    shuffle: bool = True,
    seed: int = 42,
    **kwargs: Any,
) -> DataLoader:
    """Standard SASRec/RNN style dataloader (History -> Next Item).

    Target positions are computed once and cached on
    ``self._valid_sample_indices``. They cover every tape position except the
    first position of each user, whose history would be empty.

    Args:
        max_seq_len (int): Maximum history length per sample.
        neg_samples (int): Negatives sampled per target (``0`` disables).
        include_user_id (bool): Prepend the user index to every sample.
        batch_size (int): Batch size.
        shuffle (bool): Whether the DataLoader shuffles samples.
        seed (int): Seed of the DataLoader's ``torch.Generator``.
        **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

    Returns:
        DataLoader: Loader over a ``SequentialDataset``.

    Raises:
        ValueError: If there is no valid target position.
    """
    if self._valid_sample_indices is None:
        n_flat = len(self._flat_items)
        starts = self._user_offsets[:-1]
        is_target = np.ones(n_flat, dtype=bool)
        # Starts at or past the end of the tape are skipped.
        is_target[starts[starts < n_flat]] = False
        self._valid_sample_indices = np.flatnonzero(is_target)

    targets = self._valid_sample_indices
    if len(targets) == 0:
        raise ValueError(
            "No valid sequences found (min 2 interactions per user needed)."
        )

    seq_dataset = SequentialDataset(
        flat_items=self._flat_items,
        flat_users=self._flat_users,
        user_offsets=self._user_offsets,
        valid_target_indices=targets,
        sparse_matrix=self._inter_sparse,
        max_seq_len=max_seq_len,
        neg_samples=neg_samples,
        niid=self._niid,
        include_user_id=include_user_id,
    )

    shuffle_gen = torch.Generator()
    shuffle_gen.manual_seed(seed)
    return DataLoader(
        seq_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        generator=shuffle_gen,
        **kwargs,
    )

get_sliding_window_dataloader(max_seq_len, neg_samples, stride=1, batch_size=1024, shuffle=True, seed=42, **kwargs)

Sequence-to-Sequence dataloader (Sliding Windows).

Source code in warprec/data/entities/sessions.py
def get_sliding_window_dataloader(
    self,
    max_seq_len: int,
    neg_samples: int,
    stride: int = 1,
    batch_size: int = 1024,
    shuffle: bool = True,
    seed: int = 42,
    **kwargs: Any,
) -> DataLoader:
    """Sequence-to-Sequence dataloader (Sliding Windows).

    Every user with at least two interactions is cut into windows starting at
    ``user_start + k * stride`` for ``k = 0 .. max(len - max_seq_len, 0) //
    stride``. Users not longer than ``max_seq_len`` yield a single window.

    Args:
        max_seq_len (int): Window length.
        neg_samples (int): Negatives per time step (``0`` disables).
        stride (int): Distance between consecutive window starts; must be >= 1.
        batch_size (int): Batch size.
        shuffle (bool): Whether the DataLoader shuffles samples.
        seed (int): Seed of the DataLoader's ``torch.Generator``.
        **kwargs (Any): Extra keyword arguments forwarded to DataLoader.

    Returns:
        DataLoader: Loader over a ``SlidingWindowDataset``.

    Raises:
        ValueError: If ``stride`` is not positive, or no user has at least two
            interactions.
    """
    # A zero/negative stride previously produced inf/garbage window counts.
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}.")

    user_lens = np.diff(self._user_offsets)
    valid_users = np.flatnonzero(user_lens >= 2)
    if len(valid_users) == 0:
        raise ValueError("No valid sliding windows found.")

    valid_lens = user_lens[valid_users]
    valid_starts = self._user_offsets[valid_users]

    # floor(max(len - L, 0) / stride) + 1 windows per user.
    num_windows = (np.maximum(valid_lens - max_seq_len, 0) // stride).astype(
        np.int64
    ) + 1

    # Expand per-user values to one entry per window; local_window_idx is the
    # window's position among its own user's windows (0, 1, 2, ...).
    window_user_ids = np.repeat(valid_users, num_windows)
    first_window = np.cumsum(num_windows) - num_windows
    local_window_idx = np.arange(int(num_windows.sum())) - np.repeat(
        first_window, num_windows
    )
    window_starts_flat = (
        np.repeat(valid_starts, num_windows) + local_window_idx * stride
    )

    dataset = SlidingWindowDataset(
        flat_items=self._flat_items,
        window_starts=window_starts_flat.astype(np.int64),
        window_users=window_user_ids.astype(np.int64),
        sparse_matrix=self._inter_sparse,
        max_seq_len=max_seq_len,
        neg_samples=neg_samples,
        niid=self._niid,
    )

    g = torch.Generator()
    g.manual_seed(seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        generator=g,
        **kwargs,
    )

get_user_history_sequences(user_ids, max_seq_len)

Retrieves padded historical sequences for inference/evaluation.

Source code in warprec/data/entities/sessions.py
def get_user_history_sequences(
    self, user_ids: List[int], max_seq_len: int
) -> Tuple[Tensor, Tensor]:
    """Retrieves padded historical sequences for inference/evaluation.

    Per-user histories are cached on ``self._cached_user_histories`` the
    first time they are needed.

    Args:
        user_ids (List[int]): User indices. Users without interactions get an
            empty history.
        max_seq_len (int): Number of most recent items kept per user; ``0``
            keeps none.

    Returns:
        Tuple[Tensor, Tensor]: ``(sequences, lengths)``. ``sequences`` is
        right-padded with ``self._niid`` to the longest kept history in the
        batch; ``lengths`` holds each row's unpadded length.

    Raises:
        ValueError: If ``max_seq_len`` is negative.
    """
    if max_seq_len < 0:
        raise ValueError(f"max_seq_len must be non-negative, got {max_seq_len}.")

    if not self._cached_user_histories:
        starts = self._user_offsets[:-1]
        ends = self._user_offsets[1:]
        # Only users that actually have interactions get an entry.
        self._cached_user_histories = {
            int(u): self._flat_items[starts[u] : ends[u]].tolist()
            for u in np.flatnonzero(ends > starts)
        }

    seqs, lens = [], []
    for uid in user_ids:
        hist = self._cached_user_histories.get(int(uid), [])
        # Explicit start index: ``hist[-0:]`` would return the whole list.
        recent = hist[max(len(hist) - max_seq_len, 0) :]
        seqs.append(torch.tensor(recent, dtype=torch.long))
        lens.append(len(recent))

    if not seqs:
        # pad_sequence rejects an empty list.
        return (
            torch.empty((0, 0), dtype=torch.long),
            torch.empty(0, dtype=torch.long),
        )

    return (
        pad_sequence(seqs, batch_first=True, padding_value=self._niid),
        torch.tensor(lens, dtype=torch.long),
    )

warprec.data.entities.train_structures.session_structures.SequentialDataset

Bases: Dataset

Standard Sequential dataset.

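Sampled Output: ([UserId], Sequence, Length, PosTarget, [NegTargets]). UserId is prepended only when include_user_id is True, and NegTargets is appended only when neg_samples > 0.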

Source code in warprec/data/entities/train_structures/session_structures.py
class SequentialDataset(Dataset):
    """
    Standard Sequential dataset.

    Each sample predicts the item at one tape position from the (at most
    ``max_seq_len``) items of the same user that precede it. The history is
    left-aligned and right-padded with ``niid``.

    Sampled Output: ([UserId], Sequence, Length, PosTarget, [NegTargets]).
    UserId is included only when ``include_user_id`` is True; NegTargets
    (``neg_samples`` items) only when ``neg_samples > 0``.

    Args:
        flat_items (np.ndarray): Item index of every tape position.
        flat_users (np.ndarray): User index of every tape position.
        user_offsets (np.ndarray): ``user_offsets[u]`` is the tape position of
            user ``u``'s first interaction.
        valid_target_indices (np.ndarray): Tape positions used as targets.
        sparse_matrix (csr_matrix): User-item matrix; a user's row lists the
            items excluded from negative sampling.
        max_seq_len (int): Maximum history length.
        neg_samples (int): Negatives drawn per sample (``0`` disables them).
        niid (int): Number of items; also used as the padding token.
        include_user_id (bool): Prepend the user index to every sample.
    """

    def __init__(
        self,
        flat_items: np.ndarray,
        flat_users: np.ndarray,
        user_offsets: np.ndarray,
        valid_target_indices: np.ndarray,
        sparse_matrix: csr_matrix,
        max_seq_len: int,
        neg_samples: int,
        niid: int,
        include_user_id: bool = False,
    ):
        self.flat_items = flat_items
        self.flat_users = flat_users
        self.user_offsets = user_offsets
        self.sample_indices = valid_target_indices
        self.sparse_matrix = sparse_matrix
        self.max_seq_len = max_seq_len
        self.neg_samples = neg_samples
        self.niid = niid
        self.include_user_id = include_user_id
        self.padding_token = niid

    def __len__(self) -> int:
        return len(self.sample_indices)

    def __getitem__(self, idx: int) -> Tuple[Tensor, ...]:
        """Build the ``idx``-th sample (layout described in the class docstring).

        Raises:
            ValueError: If negatives are requested but every item is either in
                the user's CSR row or equal to the positive target.
        """
        target_flat_idx = self.sample_indices[idx]
        user_idx = self.flat_users[target_flat_idx]
        pos_item = self.flat_items[target_flat_idx]

        # History ends just before the target and never crosses into the
        # previous user's segment.
        seq_start_idx = max(
            self.user_offsets[user_idx], target_flat_idx - self.max_seq_len
        )
        seq_array = self.flat_items[seq_start_idx:target_flat_idx].copy()
        seq_len = len(seq_array)

        seq_tensor = torch.full(
            (self.max_seq_len,), self.padding_token, dtype=torch.long
        )
        seq_tensor[:seq_len] = torch.from_numpy(seq_array)

        ret: List[Tensor] = []
        if self.include_user_id:
            ret.append(torch.tensor(user_idx, dtype=torch.long))
        ret.extend(
            [
                seq_tensor,
                torch.tensor(seq_len, dtype=torch.long),
                torch.tensor(pos_item, dtype=torch.long),
            ]
        )

        if self.neg_samples > 0:
            ret.append(self._sample_negatives(user_idx, pos_item))

        return tuple(ret)

    def _sample_negatives(self, user_idx: int, pos_item: int) -> Tensor:
        """Draw ``neg_samples`` items outside the user's CSR row and != ``pos_item``."""
        u_start = self.sparse_matrix.indptr[user_idx]
        u_end = self.sparse_matrix.indptr[user_idx + 1]
        seen_items = self.sparse_matrix.indices[u_start:u_end]

        # Rejection sampling would never terminate if no item is left; the exact
        # check only runs when the row is long enough for that to happen.
        if len(seen_items) + 1 >= self.niid:
            excluded = np.union1d(seen_items, [pos_item])
            excluded = excluded[(excluded >= 0) & (excluded < self.niid)]
            if len(excluded) >= self.niid:
                raise ValueError(
                    f"Cannot sample negatives for user {user_idx}: "
                    "every item has been interacted with."
                )

        neg_items: List[int] = []
        while len(neg_items) < self.neg_samples:
            cand = np.random.randint(0, self.niid)
            # Fast check on sorted CSR indices
            pos = np.searchsorted(seen_items, cand)
            if pos < len(seen_items) and seen_items[pos] == cand:
                continue
            if cand == pos_item:
                continue
            neg_items.append(cand)

        return torch.tensor(neg_items, dtype=torch.long)

warprec.data.entities.train_structures.session_structures.SlidingWindowDataset

Bases: Dataset

Dataset for Sequence-to-Sequence training.

Sampled Output: (InputSequence, NegativeSamplesMatrix) when neg_samples > 0, otherwise (InputSequence,).

Source code in warprec/data/entities/train_structures/session_structures.py
class SlidingWindowDataset(Dataset):
    """
    Dataset for Sequence-to-Sequence training.

    Each sample is one window of a user's tape segment, left-aligned and
    right-padded with ``niid`` to ``max_seq_len``.

    Sampled Output: (InputSequence, NegativeSamplesMatrix) when
    ``neg_samples > 0``, otherwise (InputSequence,). The negative matrix has
    shape ``(max_seq_len, neg_samples)``; rows past the window's real length
    stay filled with ``niid``.

    Args:
        flat_items (np.ndarray): Item indices of all users, concatenated.
        window_starts (np.ndarray): Tape index where each window begins.
        window_users (np.ndarray): User index owning each window.
        sparse_matrix (csr_matrix): User-item matrix; a user's row lists the
            items excluded from negative sampling.
        max_seq_len (int): Window length.
        neg_samples (int): Negatives per time step (``0`` disables them).
        niid (int): Number of items; also used as the padding token.
        window_ends (Optional[np.ndarray]): Exclusive tape index that each
            window may not cross (the end of its user's segment). When
            omitted, ``sparse_matrix.indptr[user + 1]`` is used instead, which
            is only correct when the CSR rows hold exactly the tape's entries
            (e.g. no merged duplicate interactions).
    """

    def __init__(
        self,
        flat_items: np.ndarray,
        window_starts: np.ndarray,
        window_users: np.ndarray,
        sparse_matrix: csr_matrix,
        max_seq_len: int,
        neg_samples: int,
        niid: int,
        window_ends: Optional[np.ndarray] = None,
    ):
        self.flat_items = flat_items
        self.window_starts = window_starts
        self.window_users = window_users
        self.sparse_matrix = sparse_matrix
        self.max_seq_len = max_seq_len
        self.neg_samples = neg_samples
        self.niid = niid
        self.padding_token = niid
        self.window_ends = window_ends

    def __len__(self) -> int:
        return len(self.window_starts)

    def __getitem__(self, idx: int) -> Tuple[Tensor, ...]:
        """Build the ``idx``-th window (layout described in the class docstring).

        Raises:
            ValueError: If negatives are requested but the user's CSR row
                already covers every item.
        """
        start_idx = self.window_starts[idx]
        user_idx = self.window_users[idx]

        # A window must not run past the end of its user's segment.
        if self.window_ends is not None:
            user_end_limit = self.window_ends[idx]
        else:
            user_end_limit = self.sparse_matrix.indptr[user_idx + 1]
        end_idx = min(start_idx + self.max_seq_len, user_end_limit)

        seq_array = self.flat_items[start_idx:end_idx]
        real_len = len(seq_array)

        pos_seq = torch.full((self.max_seq_len,), self.padding_token, dtype=torch.long)
        pos_seq[:real_len] = torch.from_numpy(seq_array)

        if self.neg_samples <= 0:
            return (pos_seq,)

        return pos_seq, self._sample_negatives(user_idx, real_len)

    def _sample_negatives(self, user_idx: int, real_len: int) -> Tensor:
        """Return a ``(max_seq_len, neg_samples)`` matrix of items not in the user's row."""
        neg_seq = torch.full(
            (self.max_seq_len, self.neg_samples),
            self.padding_token,
            dtype=torch.long,
        )

        u_start = self.sparse_matrix.indptr[user_idx]
        u_end = self.sparse_matrix.indptr[user_idx + 1]
        seen_items = self.sparse_matrix.indices[u_start:u_end]
        n_seen = len(seen_items)

        # Rejection sampling would never terminate if every item is seen.
        if real_len > 0 and n_seen >= self.niid:
            distinct = np.unique(seen_items)
            distinct = distinct[(distinct >= 0) & (distinct < self.niid)]
            if len(distinct) >= self.niid:
                raise ValueError(
                    f"Cannot sample negatives for user {user_idx}: "
                    "every item has been interacted with."
                )

        # Generate negatives for each valid time step
        for t in range(real_len):
            found = 0
            while found < self.neg_samples:
                candidates = np.random.randint(
                    0, self.niid, size=self.neg_samples - found
                )
                # An empty row has nothing to reject (and clipping to -1 would
                # index an empty array).
                if n_seen > 0:
                    # Vectorized membership test on the (sorted) CSR row.
                    pos = np.clip(np.searchsorted(seen_items, candidates), 0, n_seen - 1)
                    candidates = candidates[seen_items[pos] != candidates]

                n_valid = len(candidates)
                if n_valid > 0:
                    neg_seq[t, found : found + n_valid] = torch.from_numpy(candidates)
                    found += n_valid

        return neg_seq

warprec.data.entities.train_structures.session_structures.ClozeDataset

Bases: Dataset

Dataset for Cloze Mask training.

Sampled Output: (MaskedSeq, PosItems, NegItems, MaskedIndices)

Source code in warprec/data/entities/train_structures/session_structures.py
class ClozeDataset(Dataset):
    """
    Dataset for Cloze Mask training.

    Each sample is one window of a user's interaction sequence. A random
    subset of the window's positions is replaced by ``mask_token_id``. The
    original items at those positions become the positive targets. For each
    target, ``neg_samples`` negatives are drawn uniformly (with replacement)
    from items the user has not interacted with and that differ from the
    target.

    Sampled Output: (MaskedSeq, PosItems, NegItems, MaskedIndices)

    - MaskedSeq: ``(max_seq_len,)`` window with masked positions set to
      ``mask_token_id``; the remaining positions are padded with ``niid``.
    - PosItems: ``(max_seq_len,)`` original items at the masked positions,
      left-aligned, padded with ``niid``.
    - NegItems: ``(max_seq_len, neg_samples)`` negatives per masked position,
      left-aligned, padded with ``niid``. A row stays fully padded when the
      user has no item left that could serve as a negative.
    - MaskedIndices: ``(max_seq_len,)`` window positions that were masked,
      left-aligned, padded with ``0``.

    Args:
        flat_items: All item indices concatenated; each window is a slice.
        window_starts: Start offset (inclusive) of each window in ``flat_items``.
        window_ends: End offset (exclusive) of each window in ``flat_items``.
        window_users: Row of ``sparse_matrix`` (user index) for each window.
        sparse_matrix: User-item matrix. The column indices stored in row ``u``
            are the items excluded from ``u``'s negative sampling.
        max_seq_len: Length every output tensor is padded to. Windows are
            assumed to be non-empty and at most this long (an empty or longer
            window raises) — TODO confirm against the code building windows.
        mask_prob: Fraction of window positions to mask (at least one is masked).
        mask_token_id: Item id written at masked positions.
        neg_samples: Number of negatives per masked position (0 disables them).
        niid: Number of items; candidates are drawn from ``[0, niid)`` and
            ``niid`` itself is the padding token.
        seed: Used to build ``self.rng`` (currently not used for sampling).
    """

    def __init__(
        self,
        flat_items: np.ndarray,
        window_starts: np.ndarray,
        window_ends: np.ndarray,
        window_users: np.ndarray,
        sparse_matrix: csr_matrix,
        max_seq_len: int,
        mask_prob: float,
        mask_token_id: int,
        neg_samples: int,
        niid: int,
        seed: int = 42,
    ):
        self.flat_items = flat_items
        self.window_starts = window_starts
        self.window_ends = window_ends
        self.window_users = window_users
        # The "already seen" test binary-searches each user's row of column
        # indices, which is only correct when those indices are sorted. CSR
        # matrices do not guarantee that, so sort a copy when needed instead
        # of mutating the caller's matrix.
        if not sparse_matrix.has_sorted_indices:
            sparse_matrix = sparse_matrix.sorted_indices()
        self.sparse_matrix = sparse_matrix
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob
        self.mask_token_id = mask_token_id
        self.neg_samples = neg_samples
        self.niid = niid
        self.padding_token = niid
        # NOTE(review): ``self.rng`` is kept for backward compatibility but is
        # not used; sampling draws from the global NumPy RNG. A per-instance
        # generator would presumably be copied with identical state into every
        # DataLoader worker — confirm the intended seeding strategy first.
        self.rng = np.random.default_rng(seed)

    def __len__(self) -> int:
        return len(self.window_starts)

    @staticmethod
    def _is_seen(seen_items: np.ndarray, candidates: np.ndarray) -> np.ndarray:
        """Return a boolean mask of ``candidates`` present in sorted ``seen_items``."""
        if seen_items.size == 0:
            return np.zeros(candidates.shape, dtype=bool)
        # Clamp so candidates larger than every seen item index a valid slot;
        # the equality test then rejects them.
        pos = np.minimum(np.searchsorted(seen_items, candidates), seen_items.size - 1)
        return seen_items[pos] == candidates

    def _sample_negatives(self, user_idx: int, pos_targets: np.ndarray) -> np.ndarray:
        """Draw ``neg_samples`` negatives for each target of ``user_idx``.

        Returns an int64 array of shape ``(len(pos_targets), neg_samples)``.
        A row is left filled with ``padding_token`` when no valid negative
        exists (every item is seen by the user or equals the target).
        """
        neg_targets = np.full(
            (len(pos_targets), self.neg_samples), self.padding_token, dtype=np.int64
        )
        if self.neg_samples <= 0:
            return neg_targets

        u_start = self.sparse_matrix.indptr[user_idx]
        u_end = self.sparse_matrix.indptr[user_idx + 1]
        seen_items = self.sparse_matrix.indices[u_start:u_end]

        # If len(seen_items) + 1 < niid, at least one item is neither seen nor
        # the target, so rejection sampling terminates. Otherwise compute the
        # unseen pool explicitly to detect the "nothing to sample" case, which
        # would otherwise loop forever.
        unseen_pool = None
        if seen_items.size + 1 >= self.niid:
            unseen_pool = np.setdiff1d(np.arange(self.niid), seen_items)

        for row, true_item in enumerate(pos_targets):
            if unseen_pool is not None and not np.any(unseen_pool != true_item):
                continue  # No valid negative exists; keep the padding row.

            found = 0
            while found < self.neg_samples:
                needed = self.neg_samples - found
                candidates = np.random.randint(0, self.niid, size=needed)
                keep = (candidates != true_item) & ~self._is_seen(
                    seen_items, candidates
                )
                valid = candidates[keep]
                neg_targets[row, found : found + valid.size] = valid
                found += valid.size

        return neg_targets

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        start = self.window_starts[idx]
        end = self.window_ends[idx]
        user_idx = self.window_users[idx]

        # Copy: the slice is a view into flat_items and is masked in place below.
        seq_array = self.flat_items[start:end].copy()
        real_seq_len = len(seq_array)

        # Masking Logic: at least one position, sampled without replacement.
        num_to_mask = max(1, int(real_seq_len * self.mask_prob))
        masked_indices = np.random.choice(real_seq_len, size=num_to_mask, replace=False)

        # Fancy indexing copies, so the targets survive the masking below.
        pos_targets = seq_array[masked_indices]
        seq_array[masked_indices] = self.mask_token_id

        # Negative Sampling (Only for masked items)
        neg_targets = self._sample_negatives(user_idx, pos_targets)

        # Tensor Construction (Compacted/Dense format for targets)

        # Masked Sequence [Item, Mask, Pad...]
        masked_seq_tensor = torch.full(
            (self.max_seq_len,), self.padding_token, dtype=torch.long
        )
        masked_seq_tensor[:real_seq_len] = torch.from_numpy(seq_array)

        # Positive Items [Target1, Target2, Pad...]
        pos_items_tensor = torch.full(
            (self.max_seq_len,), self.padding_token, dtype=torch.long
        )
        pos_items_tensor[:num_to_mask] = torch.from_numpy(pos_targets)

        # Negative Items [Negs1, Negs2, Pad...]
        neg_items_tensor = torch.full(
            (self.max_seq_len, self.neg_samples), self.padding_token, dtype=torch.long
        )
        neg_items_tensor[:num_to_mask, :] = torch.from_numpy(neg_targets)

        # Masked Indices [Idx1, Idx2, 0, 0...]
        masked_indices_tensor = torch.zeros(self.max_seq_len, dtype=torch.long)
        masked_indices_tensor[:num_to_mask] = torch.from_numpy(masked_indices)

        return (
            masked_seq_tensor,
            pos_items_tensor,
            neg_items_tensor,
            masked_indices_tensor,
        )