
Rating Metrics - API Reference

Auto-generated documentation for the rating metric classes (MAE, MSE, RMSE).

warprec.evaluation.metrics.rating.mae.MAE

Bases: RatingMetric

Mean Absolute Error (MAE) metric.

This metric computes the average absolute difference between the predictions and targets.
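Written out for N prediction-target pairs, with ŷ_i the predicted and y_i the target rating. This notation is chosen here for illustration. Whether the average runs per user or over the whole evaluation set depends on how `RatingMetric` aggregates, which this page does not show.

```latex
\mathrm{MAE} = \frac{1}{N} \sum_{i=1}^{N} \lvert \hat{y}_i - y_i \rvert
```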

Source code in warprec/evaluation/metrics/rating/mae.py
@metric_registry.register("MAE")
class MAE(RatingMetric):
    """Mean Absolute Error (MAE) metric.

    This metric computes the average absolute difference between the predictions and targets.
    """

    def _compute_element_error(self, preds: Tensor, target: Tensor) -> Tensor:
        return torch.abs(preds - target)
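The `_compute_element_error` hook suggests a template-method design. In this design, the base class accumulates element-wise errors and averages them, and each subclass supplies only the per-element error. The sketch below illustrates that pattern under that assumption. `ToyRatingMetric`, `update`, and `compute` here are hypothetical stand-ins, not WarpRec's actual `RatingMetric` API.

```python
import torch
from torch import Tensor


class ToyRatingMetric:
    """Hypothetical stand-in for a rating-metric base class (not WarpRec's API)."""

    name = "Toy"

    def __init__(self) -> None:
        # Running sum of element errors and number of elements seen so far.
        self.error_sum = torch.tensor(0.0)
        self.count = 0

    def _compute_element_error(self, preds: Tensor, target: Tensor) -> Tensor:
        raise NotImplementedError

    def update(self, preds: Tensor, target: Tensor) -> None:
        # Accumulate the subclass-defined error for this batch.
        errors = self._compute_element_error(preds, target)
        self.error_sum += errors.sum()
        self.count += errors.numel()

    def compute(self) -> dict[str, Tensor]:
        # Average error over every element seen across all batches.
        return {self.name: self.error_sum / self.count}


class ToyMAE(ToyRatingMetric):
    name = "MAE"

    def _compute_element_error(self, preds: Tensor, target: Tensor) -> Tensor:
        # Same element-wise error as the MAE class above.
        return torch.abs(preds - target)


metric = ToyMAE()
metric.update(torch.tensor([3.5, 4.0]), torch.tensor([4.0, 4.0]))
metric.update(torch.tensor([2.0]), torch.tensor([5.0]))
print(metric.compute())  # {'MAE': tensor(1.1667)}
```

With this split, adding a new error-based metric means overriding a single method. The accumulation and averaging logic stays in one place.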

warprec.evaluation.metrics.rating.mse.MSE

Bases: RatingMetric

Mean Squared Error (MSE) metric.

This metric computes the average squared difference between the predictions and targets.
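Expressed as a formula, using illustrative notation: ŷ_i is the prediction, y_i is the target, and N is the number of averaged pairs. Squaring makes every term non-negative and amplifies large misses.

```latex
\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}_i - y_i \right)^2
```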

Source code in warprec/evaluation/metrics/rating/mse.py
@metric_registry.register("MSE")
class MSE(RatingMetric):
    """Mean Squared Error (MSE) metric.

    This metric computes the average squared difference between the predictions and targets.
    """

    def _compute_element_error(self, preds: Tensor, target: Tensor) -> Tensor:
        return (preds - target) ** 2
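Because the error is squared, a single large miss weighs more heavily in MSE than in MAE. The standalone sketch below uses plain PyTorch rather than WarpRec's classes. It compares the two metrics on the same made-up ratings.

```python
import torch

preds = torch.tensor([4.0, 3.0, 5.0, 1.0])
target = torch.tensor([4.0, 3.5, 4.5, 5.0])

# Element-wise errors, matching each class's _compute_element_error.
abs_err = torch.abs(preds - target)  # [0.0, 0.5, 0.5, 4.0]
sq_err = (preds - target) ** 2       # [0.0, 0.25, 0.25, 16.0]

mae = abs_err.mean()
mse = sq_err.mean()
print(f"MAE={mae.item():.4f}  MSE={mse.item():.4f}")  # MAE=1.2500  MSE=4.1250
```

The one badly missed rating (1.0 predicted vs. 5.0 actual) contributes 4 of the 5 units of total absolute error. It contributes 16 of the 16.5 units of total squared error.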

warprec.evaluation.metrics.rating.rmse.RMSE

Bases: MSE

Root Mean Squared Error (RMSE) metric.

This metric computes the square root of the average squared difference between the predictions and targets.
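As a formula, with illustrative notation (ŷ_i predicted, y_i target, N pairs). This assumes the square root is taken after averaging, as `RMSE.compute` does with the MSE result.

```latex
\mathrm{RMSE} = \sqrt{\mathrm{MSE}} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}_i - y_i \right)^2}
```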

Source code in warprec/evaluation/metrics/rating/rmse.py
@metric_registry.register("RMSE")
class RMSE(MSE):
    """Root Mean Squared Error (RMSE) metric.

    This metric computes the square root of the average squared difference between the predictions and targets.
    """

    def compute(self):
        # Get the MSE per user
        mse = super().compute()[self.name]

        # Apply sqrt to the tensor
        rmse = torch.sqrt(mse)

        return {self.name: rmse}
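`RMSE` reuses the MSE accumulation and only post-processes the result, so its value is in the same units as the ratings. The standalone PyTorch sketch below reproduces that step on made-up data. It also shows that `torch.sqrt` is element-wise, which matters if the inherited `compute` returns a per-user tensor, as the comment in `RMSE.compute` suggests. That return shape is an assumption, not something this page confirms.

```python
import torch

preds = torch.tensor([4.0, 3.0, 5.0, 1.0])
target = torch.tensor([4.0, 3.5, 4.5, 5.0])

mse = ((preds - target) ** 2).mean()
rmse = torch.sqrt(mse)  # mirrors RMSE.compute: sqrt applied to the MSE result
print(f"RMSE={rmse.item():.4f}")  # RMSE=2.0310

# torch.sqrt is element-wise, so a per-user MSE tensor yields per-user RMSE values.
per_user_mse = torch.tensor([0.25, 4.0, 9.0])
per_user_rmse = torch.sqrt(per_user_mse)
print(per_user_rmse)  # tensor([0.5000, 2.0000, 3.0000])
```

If per-user RMSE values are later averaged, the result generally differs from an RMSE computed over all pairs pooled together. This is because the square root is concave.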