
shap_relativities

SHAP Relativities - actuarial-grade multiplicative relativities from tree models.

Converts a trained GBM's SHAP values into the same format as GLM exp(beta) relativities: a table of (feature, level, relativity) triples where the base level is 1.0 and relativities multiply together to give the model's expected prediction.
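The arithmetic behind that equivalence is worth seeing once: under a log-link objective, SHAP values are additive in log space, so exponentiating turns them into multiplicative factors. A minimal numpy sketch with illustrative numbers (not library code):

```python
import numpy as np

# Illustrative log-space quantities for a single policy.
expected_value = np.log(0.08)                          # explainer baseline
shap_per_feature = {"area": 0.25, "ncd_years": -0.40}  # one SHAP value per feature

# Additive in log space ...
log_pred = expected_value + sum(shap_per_feature.values())

# ... is multiplicative in prediction space: base rate times one factor per feature.
base_rate = np.exp(expected_value)
relativities = {f: np.exp(s) for f, s in shap_per_feature.items()}
pred = base_rate * np.prod(list(relativities.values()))

assert np.isclose(pred, np.exp(log_pred))
```

This is the same identity a GLM rater uses: exp(b0 + b1 + b2) = exp(b0) * exp(b1) * exp(b2), with per-feature SHAP values standing in for the fitted coefficients.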

Typical usage

>>> from shap_relativities import SHAPRelativities
>>> sr = SHAPRelativities(model, X, exposure=df["exposure"],
...                       categorical_features=["area", "ncd_years"])
>>> sr.fit()
>>> rels = sr.extract_relativities(
...     normalise_to="base_level",
...     base_levels={"area": "A", "ncd_years": 0},
... )

Or use the convenience wrapper for one-liners:

>>> from shap_relativities import extract_relativities
>>> rels = extract_relativities(model, X, exposure=df["exposure"],
...                             categorical_features=["area"])
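To see what normalise_to="base_level" with ci_method="clt" produces, here is a standalone sketch on simulated per-observation SHAP values. The levels and numbers are made up, and the real implementation is exposure-weighted and also accounts for base-level uncertainty; this only shows the shape of the arithmetic:

```python
import numpy as np

rng = np.random.default_rng(0)
# Simulated log-space SHAP values for one categorical rating factor.
shap_by_level = {
    "A": rng.normal(0.00, 0.05, 500),   # chosen base level
    "B": rng.normal(0.18, 0.05, 400),
}

base_mean = shap_by_level["A"].mean()
z = 1.96  # two-sided 95% normal quantile

rels = {}
for level, s in shap_by_level.items():
    diff = s.mean() - base_mean               # log-space gap to the base level
    se = s.std(ddof=1) / np.sqrt(len(s))      # CLT standard error of the level mean
    rels[level] = (np.exp(diff), np.exp(diff - z * se), np.exp(diff + z * se))

# The base level lands exactly on 1.0; other levels are factors relative to it.
```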
"""
SHAP Relativities - actuarial-grade multiplicative relativities from tree models.

Converts a trained GBM's SHAP values into the same format as GLM exp(beta)
relativities: a table of (feature, level, relativity) triples where the base
level is 1.0 and relativities multiply together to give the model's expected
prediction.

Typical usage
-------------
>>> from shap_relativities import SHAPRelativities
>>> sr = SHAPRelativities(model, X, exposure=df["exposure"],
...                       categorical_features=["area", "ncd_years"])
>>> sr.fit()
>>> rels = sr.extract_relativities(
...     normalise_to="base_level",
...     base_levels={"area": "A", "ncd_years": 0},
... )

Or use the convenience wrapper for one-liners:

>>> from shap_relativities import extract_relativities
>>> rels = extract_relativities(model, X, exposure=df["exposure"],
...                             categorical_features=["area"])
"""

from __future__ import annotations

from typing import Any

import polars as pl

from ._core import SHAPRelativities

__all__ = ["SHAPRelativities", "extract_relativities"]

__version__ = "0.2.3"


def extract_relativities(
    model: Any,
    X: Any,
    exposure: Any = None,
    categorical_features: list[str] | None = None,
    base_levels: dict[str, str | float | int] | None = None,
    ci_method: str = "clt",
) -> pl.DataFrame:
    """
    One-shot extraction of SHAP relativities from a tree model.

    Wraps SHAPRelativities.fit() and extract_relativities() for cases where
    you don't need the intermediate object.

    Args:
        model: Trained CatBoost model with a log-link objective (Poisson,
            Tweedie, or Gamma). CatBoost is the recommended choice - it handles
            categorical features natively without encoding.
        X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is
            preferred; pandas is accepted and converted internally.
        exposure: Earned policy years. If None, all observations are equally
            weighted.
        categorical_features: Features to aggregate by level. If None, all
            non-numeric columns are treated as categorical.
        base_levels: Base level for each categorical feature (gets
            relativity = 1.0).
        ci_method: "clt" (default) or "none".

    Returns:
        Polars DataFrame with columns: feature, level, relativity, lower_ci,
        upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
    """
    sr = SHAPRelativities(model, X, exposure, categorical_features)
    sr.fit()
    return sr.extract_relativities(base_levels=base_levels, ci_method=ci_method)
class SHAPRelativities:
class SHAPRelativities:
    """
    Extract multiplicative rating relativities from a tree model via SHAP.

    Workflow::

        sr = SHAPRelativities(
            model=catboost_model,
            X=df.select(["area", "ncd_years", "has_convictions"]),
            exposure=df["exposure"],
            categorical_features=["area", "ncd_years"],
        )
        sr.fit()
        rels = sr.extract_relativities(
            normalise_to="base_level",
            base_levels={"area": "A", "ncd_years": 0},
        )

    Args:
        model: A trained CatBoost model. Must use a log-link objective (Poisson,
            Tweedie, Gamma). CatBoost is the recommended default - it handles
            categoricals natively.
        X: Feature matrix. Use training data for in-sample relativities, or a
            representative holdout sample for out-of-sample. Polars DataFrames
            are preferred; pandas DataFrames are accepted and converted
            internally.
        exposure: Earned policy years (or other volume measure). Used as
            observation weights throughout. If None, all observations are
            weighted equally.
        categorical_features: Features to aggregate by level (bar-chart style).
            If None, all non-numeric columns are treated as categorical.
        continuous_features: Features to leave as per-observation points.
            If None, all numeric columns are treated as continuous.
        feature_perturbation: "tree_path_dependent" (default, fast, no
            background data needed) or "interventional" (corrects for feature
            correlation, needs background_data).
        background_data: Required only if feature_perturbation="interventional".
        n_background_samples: Number of background samples for interventional
            SHAP. Default 1000.
        annualise_exposure: If True and exposure is provided, subtract mean
            log(exposure) from the expected_value to give an annualised
            baseline. Default True.
    """

    def __init__(
        self,
        model: Any,
        X: Any,
        exposure: Any = None,
        categorical_features: list[str] | None = None,
        continuous_features: list[str] | None = None,
        background_data: Any = None,
        feature_perturbation: str = "tree_path_dependent",
        n_background_samples: int = 1000,
        annualise_exposure: bool = True,
    ) -> None:
        if not _SHAP_AVAILABLE:
            raise ImportError(
                "shap is required for SHAPRelativities. "
                "Install it with: uv add 'shap-relativities[ml]'"
            )

        self._model = model
        self._X: pl.DataFrame = _to_polars(X)
        self._background_data = (
            _to_polars(background_data) if background_data is not None else None
        )

        # Normalise exposure to a numpy array
        if exposure is None:
            self._exposure: np.ndarray | None = None
        elif isinstance(exposure, np.ndarray):
            self._exposure = exposure
        elif isinstance(exposure, pl.Series):
            self._exposure = exposure.to_numpy()
        else:
            # pd.Series or similar
            self._exposure = np.asarray(exposure)

        self._feature_perturbation = feature_perturbation
        self._n_background_samples = n_background_samples
        self._annualise_exposure = annualise_exposure

        # Classify features
        self._categorical_features = categorical_features or self._infer_categorical()
        self._continuous_features = continuous_features or self._infer_continuous()

        # Populated by fit()
        self._shap_values: np.ndarray | None = None
        self._expected_value: float | None = None
        self._is_fitted: bool = False

    def _infer_categorical(self) -> list[str]:
        numeric_types = (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64,
            pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
            pl.Float32, pl.Float64,
        )
        return [
            c for c in self._X.columns
            if not isinstance(self._X[c].dtype, numeric_types)
        ]

    def _infer_continuous(self) -> list[str]:
        numeric_types = (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64,
            pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
            pl.Float32, pl.Float64,
        )
        return [
            c for c in self._X.columns
            if isinstance(self._X[c].dtype, numeric_types)
            and c not in (self._categorical_features or [])
        ]

    def fit(self) -> "SHAPRelativities":
        """
        Compute SHAP values for all features in X.

        Must be called before extract_relativities(). Calling fit() again
        recomputes SHAP values (e.g. after changing X or the background data).

        The feature matrix is converted to pandas internally for shap's
        TreeExplainer. The conversion is a necessary bridge - shap requires
        pandas for column name handling.

        Returns:
            Self, for method chaining.
        """
        # Convert to pandas for shap - unavoidable bridge
        X_pd = _to_pandas(self._X)

        bg_data = None
        if self._feature_perturbation == "interventional":
            if self._background_data is not None:
                bg_data = _to_pandas(self._background_data)
            else:
                n_bg = min(self._n_background_samples, len(X_pd))
                bg_data = shap.sample(X_pd, n_bg)

        explainer = shap.TreeExplainer(
            self._model,
            data=bg_data,
            feature_perturbation=self._feature_perturbation,
            model_output="raw",
        )

        raw = explainer.shap_values(X_pd)

        # Some models return a list when there is a single output
        if isinstance(raw, list):
            if len(raw) == 1:
                raw = raw[0]
            else:
                raise ValueError(
                    f"Model has {len(raw)} outputs. SHAPRelativities supports "
                    "single-output models only."
                )

        self._shap_values = raw

        ev = explainer.expected_value
        if isinstance(ev, (list, np.ndarray)):
            ev = float(ev[0])
        self._expected_value = float(ev)

        self._is_fitted = True
        return self

    def _check_fitted(self) -> None:
        if not self._is_fitted:
            raise RuntimeError("Call fit() before using this method.")

    def shap_values(self) -> np.ndarray:
        """
        Raw SHAP values, shape (n_obs, n_features), in log space.

        Returns:
            Array of shape (n_obs, n_features).
        """
        self._check_fitted()
        return self._shap_values  # type: ignore[return-value]

    def baseline(self) -> float:
        """
        exp(expected_value) - the base rate in prediction space.

        If annualise_exposure=True and exposure was provided, this is adjusted
        for the average log-exposure offset so it represents an annualised rate.

        Returns:
            Base rate as a float.
        """
        self._check_fitted()
        ev = self._expected_value  # type: ignore[assignment]

        if self._annualise_exposure and self._exposure is not None:
            mean_log_exp = float(np.mean(np.log(np.clip(self._exposure, 1e-9, None))))
            ev = ev - mean_log_exp

        return float(np.exp(ev))

    def extract_relativities(
        self,
        normalise_to: str = "base_level",
        base_levels: dict[str, str | float | int] | None = None,
        ci_method: str = "clt",
        n_bootstrap: int = 200,
        ci_level: float = 0.95,
    ) -> pl.DataFrame:
        """
        Extract multiplicative relativities from SHAP values.

        Args:
            normalise_to: "base_level" (base level for each feature gets
                relativity = 1.0) or "mean" (exposure-weighted portfolio
                mean = 1.0).
            base_levels: Mapping of feature -> base level value. Required for
                categorical features when normalise_to="base_level". Continuous
                features automatically use mean normalisation regardless of
                this setting.
            ci_method: "clt" (CLT approximation, default, fast) or "none" (no
                CIs). "bootstrap" is not yet implemented.
            n_bootstrap: Ignored unless ci_method="bootstrap".
            ci_level: Two-sided confidence level. Default 0.95.

        Returns:
            Polars DataFrame with columns: feature, level, relativity,
            lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
            One row per (feature, level) combination.
        """
        self._check_fitted()

        if ci_method == "bootstrap":
            raise NotImplementedError(
                "Bootstrap CIs are not yet implemented. Use ci_method='clt'."
            )

        base_levels = base_levels or {}
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        feature_names = self._X.columns
        shap_vals = self._shap_values  # type: ignore[assignment]

        parts: list[pl.DataFrame] = []

        for i, feat in enumerate(feature_names):
            feat_vals = self._X[feat].to_numpy()
            shap_col = shap_vals[:, i]

            is_categorical = feat in self._categorical_features

            if is_categorical:
                agg = aggregate_categorical(feat, feat_vals, shap_col, weights)
            else:
                agg = aggregate_continuous(feat, feat_vals, shap_col, weights)

            # Normalisation
            if normalise_to == "base_level" and is_categorical:
                base = base_levels.get(feat)
                if base is None:
                    # Fall back to the level with the smallest mean_shap as
                    # a sensible default (closest to intercept)
                    base = agg.sort("mean_shap")["level"][0]
                    warnings.warn(
                        f"No base level specified for '{feat}'. "
                        f"Using '{base}' (lowest mean SHAP) as base.",
                        UserWarning,
                        stacklevel=2,
                    )

                if ci_method == "none":
                    base_key = str(base)
                    base_rows = agg.filter(pl.col("level") == base_key)
                    base_shap = base_rows["mean_shap"][0]
                    agg = agg.with_columns([
                        (pl.col("mean_shap") - base_shap).exp().alias("relativity"),
                        pl.lit(float("nan")).alias("lower_ci"),
                        pl.lit(float("nan")).alias("upper_ci"),
                    ])
                else:
                    agg = normalise_base_level(agg, base, ci_level=ci_level)

            else:
                # Mean normalisation for continuous features, or when
                # normalise_to="mean" for any feature
                if ci_method == "none":
                    total_weight = agg["exposure_weight"].sum()
                    portfolio_mean = float(
                        (agg["mean_shap"] * agg["exposure_weight"]).sum()
                        / total_weight
                    ) if total_weight > 0 else 0.0
                    agg = agg.with_columns([
                        (pl.col("mean_shap") - portfolio_mean).exp().alias("relativity"),
                        pl.lit(float("nan")).alias("lower_ci"),
                        pl.lit(float("nan")).alias("upper_ci"),
                    ])
                else:
                    agg = normalise_mean(agg, ci_level=ci_level)

            parts.append(agg)

        # Cast the 'level' column to Utf8 in every part before concat.
        # Categorical features produce level as Utf8; continuous features
        # produce level as Float64. pl.concat with how="diagonal" cannot
        # unify mismatched types for the same column name, so we normalise
        # here rather than requiring callers to pre-cast their feature columns.
        parts = [
            p.with_columns(pl.col("level").cast(pl.Utf8))
            if "level" in p.columns else p
            for p in parts
        ]

        result = pl.concat(parts, how="diagonal")

        # Ensure standard column order (wsq_weight is internal, not exported)
        available = [c for c in _RELATIVITY_COLUMNS if c in result.columns]
        return result.select(available)

    def extract_continuous_curve(
        self,
        feature: str,
        n_points: int = 100,
        smooth_method: str = "loess",
    ) -> pl.DataFrame:
        """
        Smoothed relativity curve for a continuous feature.

        Args:
            feature: Feature name. Must be in continuous_features.
            n_points: Number of points in the output curve (not the input
                data).
            smooth_method: "loess" (locally weighted regression, requires
                statsmodels), "isotonic" (monotone curve via isotonic
                regression), or "none" (raw per-observation relativities).

        Returns:
            Polars DataFrame with columns: feature_value, relativity,
            lower_ci, upper_ci.

        Raises:
            ValueError: If feature is not in X or smooth_method is unknown.
        """
        self._check_fitted()

        if feature not in self._X.columns:
            raise ValueError(f"Feature '{feature}' not in X.")

        feat_idx = self._X.columns.index(feature)
        feat_vals = self._X[feature].to_numpy().astype(float)
        shap_col = self._shap_values[:, feat_idx]  # type: ignore[index]
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        # Exposure-weighted mean over the actual data distribution
        portfolio_mean = np.average(shap_col, weights=weights)
        relativities = np.exp(shap_col - portfolio_mean)

        grid = np.linspace(feat_vals.min(), feat_vals.max(), n_points)

        if smooth_method == "none":
            order = np.argsort(feat_vals)
            return pl.DataFrame({
                "feature_value": feat_vals[order],
                "relativity": relativities[order],
                "lower_ci": np.full(len(feat_vals), float("nan")),
                "upper_ci": np.full(len(feat_vals), float("nan")),
            })

        elif smooth_method == "isotonic":
            from sklearn.isotonic import IsotonicRegression
            ir = IsotonicRegression(out_of_bounds="clip")
            ir.fit(feat_vals, shap_col, sample_weight=weights)
            smoothed_shap = ir.predict(grid)

            # P1-4 fix: normalise the smoothed curve so the exposure-weighted
            # geometric mean of relativities = 1.0.
            # The smooth is on the data, evaluated on a uniform grid. Subtracting
            # portfolio_mean (computed on the data distribution) would be correct
            # only if the grid were distributed like the data — it isn't.
            # Instead, compute the data-distribution-weighted mean of the smoothed
            # curve at the original data points, then use that as the reference.
            smoothed_at_data = ir.predict(feat_vals)
            weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
            smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)

            return pl.DataFrame({
                "feature_value": grid,
                "relativity": smoothed_rel,
                "lower_ci": np.full(n_points, float("nan")),
                "upper_ci": np.full(n_points, float("nan")),
            })

        elif smooth_method == "loess":
            try:
                from statsmodels.nonparametric.smoothers_lowess import lowess

                smoothed_shap = lowess(
                    shap_col, feat_vals, frac=0.3, it=3,
                    xvals=grid, is_sorted=False,
                )

                # P1-4 fix: compute the smoothed values at original data points
                # so the normalisation is data-distribution-weighted, not
                # grid-uniform. Use the same lowess parameters.
                smoothed_at_data = lowess(
                    shap_col, feat_vals, frac=0.3, it=3,
                    xvals=feat_vals, is_sorted=False,
                )
                weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
                smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)

                return pl.DataFrame({
                    "feature_value": grid,
                    "relativity": smoothed_rel,
                    "lower_ci": np.full(n_points, float("nan")),
                    "upper_ci": np.full(n_points, float("nan")),
                })
            except ImportError:
                warnings.warn(
                    "statsmodels not installed; falling back to smooth_method='none'.",
                    UserWarning,
                    stacklevel=2,
                )
                return self.extract_continuous_curve(
                    feature, n_points=n_points, smooth_method="none"
                )

        else:
            raise ValueError(
                f"Unknown smooth_method '{smooth_method}'. "
                "Choose from: 'loess', 'isotonic', 'none'."
            )

    def validate(self) -> dict[str, CheckResult]:
        """
        Run diagnostic checks on the SHAP computation.

        Checks performed:

        1. reconstruction: exp(shap.sum(1) + expected_value) should match
           model predictions within tolerance. Material failure here indicates
           the explainer was set up incorrectly.

        2. feature_coverage: every feature in X should appear in the SHAP
           output. Currently always passes given TreeExplainer's API.

        3. sparse_levels: warns if any categorical level has fewer than 30
           observations. CLT CIs will be unreliable for these levels.

        Returns:
            Dict with keys "reconstruction", "feature_coverage",
            "sparse_levels". Each value is a CheckResult(passed, value,
            message).
        """
        self._check_fitted()

        X_pd = _to_pandas(self._X)

        # Get model predictions for reconstruction check
        preds = None
        if self._model is not None:
            try:
                preds = self._model.predict(X_pd)
            except Exception:
                preds = None

        results: dict[str, CheckResult] = {}

        if preds is not None:
            results["reconstruction"] = check_reconstruction(
                self._shap_values,  # type: ignore[arg-type]
                self._expected_value,  # type: ignore[arg-type]
                preds,
                tolerance=1e-4,
            )
        else:
            results["reconstruction"] = CheckResult(
                passed=False,
                value=float("nan"),
                message="Could not obtain model predictions for reconstruction check.",
            )

        feature_names = self._X.columns
        results["feature_coverage"] = check_feature_coverage(
            feature_names, feature_names
        )

        # Check sparse levels for categorical features
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        sparse_parts: list[pl.DataFrame] = []
        for feat in self._categorical_features:
            if feat not in self._X.columns:
                continue
            feat_idx = self._X.columns.index(feat)
            agg = aggregate_categorical(
                feat,
                self._X[feat].to_numpy(),
                self._shap_values[:, feat_idx],  # type: ignore[index]
                weights,
            )
            sparse_parts.append(agg)

        if sparse_parts:
            all_agg = pl.concat(sparse_parts, how="diagonal")
            results["sparse_levels"] = check_sparse_levels(all_agg)
        else:
            results["sparse_levels"] = CheckResult(
                passed=True, value=0.0,
                message="No categorical features to check."
            )

        return results

    def plot_relativities(
        self,
        features: list[str] | None = None,
        show_ci: bool = True,
        figsize: tuple[int, int] = (12, 8),
    ) -> None:
        """
        Plot relativities as bar charts (categorical) or line charts (continuous).

        Args:
            features: Subset of features to plot. Defaults to all features.
            show_ci: Whether to show confidence intervals. Default True.
            figsize: Overall figure size.
        """
        self._check_fitted()

        from ._plotting import plot_relativities as _plot

        rels = self.extract_relativities()
        _plot(
            rels,
            categorical_features=self._categorical_features,
            continuous_features=self._continuous_features,
            features=features,
            show_ci=show_ci,
            figsize=figsize,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialisable representation of the fitted object.

        Stores SHAP values, expected value, feature names, feature
        classification, the feature values (as plain lists under X_values),
        and exposure. Does not store the original model object.

        Returns:
            Dict suitable for JSON serialisation.
        """
        self._check_fitted()
        return {
            "shap_values": self._shap_values.tolist(),  # type: ignore[union-attr]
            "expected_value": self._expected_value,
            "feature_names": self._X.columns,
            "categorical_features": self._categorical_features,
            "continuous_features": self._continuous_features,
            "X_values": {c: self._X[c].to_list() for c in self._X.columns},
            "exposure": (
                self._exposure.tolist()
                if self._exposure is not None else None
            ),
            "annualise_exposure": self._annualise_exposure,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SHAPRelativities":
        """
        Reconstruct a fitted SHAPRelativities from to_dict() output.

        The reconstructed object has no model attached: extract_relativities()
        and plot_relativities() still work, but fit() cannot be re-run, and
        validate()'s reconstruction check will report failure because no
        predictions can be obtained.

        Args:
            data: Output of to_dict().

        Returns:
            Fitted SHAPRelativities instance.
        """
        # P1-2 fix: use feature_names to control column ordering in X.
        # Without this, tools that sort JSON object keys (REST APIs, some
        # pretty-printers) reorder X_values, misaligning columns with the
        # shap_values matrix columns.
        feature_names: list[str] = data.get("feature_names", list(data["X_values"].keys()))
        X = pl.DataFrame({k: data["X_values"][k] for k in feature_names})

        exposure = (
            np.array(data["exposure"]) if data.get("exposure") is not None
            else None
        )

        # Create a minimal instance without a real model
        instance = cls.__new__(cls)
        instance._model = None
        instance._X = X
        instance._exposure = exposure
        instance._categorical_features = data.get("categorical_features", [])
        instance._continuous_features = data.get("continuous_features", [])
        instance._feature_perturbation = "tree_path_dependent"
        instance._background_data = None
        instance._n_background_samples = 1000
        instance._annualise_exposure = data.get("annualise_exposure", True)
        instance._shap_values = np.array(data["shap_values"])
        instance._expected_value = float(data["expected_value"])
        instance._is_fitted = True

        return instance
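For readers checking the annualise_exposure adjustment in baseline(), the computation reduces to the following arithmetic (illustrative numbers only):

```python
import numpy as np

expected_value = np.log(0.10)                 # log-space baseline from the explainer
exposure = np.array([1.0, 0.5, 0.25, 1.0])    # earned policy years per record

# When exposure enters the model as a log offset, expected_value absorbs the
# average log-exposure; subtracting it re-expresses the base rate per full
# policy year (same formula as baseline(), including the clip against log(0)).
mean_log_exp = float(np.mean(np.log(np.clip(exposure, 1e-9, None))))
annualised_base = float(np.exp(expected_value - mean_log_exp))
```

With average exposure below one year, mean_log_exp is negative, so the annualised base rate sits above the raw exp(expected_value).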

221            if self._background_data is not None:
222                bg_data = _to_pandas(self._background_data)
223            else:
224                n_bg = min(self._n_background_samples, len(X_pd))
225                bg_data = shap.sample(X_pd, n_bg)
226
227        explainer = shap.TreeExplainer(
228            self._model,
229            data=bg_data,
230            feature_perturbation=self._feature_perturbation,
231            model_output="raw",
232        )
233
234        raw = explainer.shap_values(X_pd)
235
236        # Some models return a list when there is a single output
237        if isinstance(raw, list):
238            if len(raw) == 1:
239                raw = raw[0]
240            else:
241                raise ValueError(
242                    f"Model has {len(raw)} outputs. SHAPRelativities supports "
243                    "single-output models only."
244                )
245
246        self._shap_values = raw
247
248        ev = explainer.expected_value
249        if isinstance(ev, (list, np.ndarray)):
250            ev = float(ev[0])
251        self._expected_value = float(ev)
252
253        self._is_fitted = True
254        return self

Compute SHAP values for all features in X.

Must be called before extract_relativities(). Calling fit() again recomputes SHAP values (e.g. after changing X or the background data).

The feature matrix is converted to pandas internally for shap's TreeExplainer. The conversion is a necessary bridge - shap requires pandas for column name handling.

Returns:

Self, for method chaining.
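With a log-link objective, the SHAP values fit() computes are additive in log space, so predictions factor multiplicatively. A minimal numpy sketch of that identity (all numbers illustrative):

```python
import numpy as np

# Toy SHAP output: 3 observations, 2 features, log space (illustrative).
expected_value = -2.0  # explainer.expected_value, i.e. log of the base rate
shap_values = np.array([
    [0.10, -0.05],
    [0.00,  0.20],
    [-0.30, 0.15],
])

# Each prediction is recovered multiplicatively from the additive
# log-space decomposition: pred_i = exp(expected_value + sum_j shap_ij).
preds = np.exp(expected_value + shap_values.sum(axis=1))
```

This is the identity the validate() reconstruction check relies on.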

def shap_values(self) -> numpy.ndarray:
260    def shap_values(self) -> np.ndarray:
261        """
262        Raw SHAP values, shape (n_obs, n_features), in log space.
263
264        Returns:
265            Array of shape (n_obs, n_features).
266        """
267        self._check_fitted()
268        return self._shap_values  # type: ignore[return-value]

Raw SHAP values, shape (n_obs, n_features), in log space.

Returns:

Array of shape (n_obs, n_features).

def baseline(self) -> float:
270    def baseline(self) -> float:
271        """
272        exp(expected_value) - the base rate in prediction space.
273
274        If annualise_exposure=True and exposure was provided, this is adjusted
275        for the average log-exposure offset so it represents an annualised rate.
276
277        Returns:
278            Base rate as a float.
279        """
280        self._check_fitted()
281        ev = self._expected_value  # type: ignore[assignment]
282
283        if self._annualise_exposure and self._exposure is not None:
284            mean_log_exp = float(np.mean(np.log(np.clip(self._exposure, 1e-9, None))))
285            ev = ev - mean_log_exp
286
287        return float(np.exp(ev))

exp(expected_value) - the base rate in prediction space.

If annualise_exposure=True and exposure was provided, this is adjusted for the average log-exposure offset so it represents an annualised rate.

Returns:

Base rate as a float.
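The adjustment is a shift in log space: subtract the mean log-exposure, then exponentiate. A numpy sketch with illustrative values:

```python
import numpy as np

# Illustrative values: log base rate from the explainer, plus exposures.
expected_value = -2.0
exposure = np.array([0.5, 1.0, 0.25])  # earned policy years

# Subtracting the mean log-exposure (clipped away from zero) turns the
# baseline into a per-policy-year (annualised) rate.
mean_log_exp = float(np.mean(np.log(np.clip(exposure, 1e-9, None))))
annualised_base = np.exp(expected_value - mean_log_exp)
```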

def extract_relativities( self, normalise_to: str = 'base_level', base_levels: dict[str, str | float | int] | None = None, ci_method: str = 'clt', n_bootstrap: int = 200, ci_level: float = 0.95) -> polars.dataframe.frame.DataFrame:
289    def extract_relativities(
290        self,
291        normalise_to: str = "base_level",
292        base_levels: dict[str, str | float | int] | None = None,
293        ci_method: str = "clt",
294        n_bootstrap: int = 200,
295        ci_level: float = 0.95,
296    ) -> pl.DataFrame:
297        """
298        Extract multiplicative relativities from SHAP values.
299
300        Args:
301            normalise_to: "base_level" (base level for each feature gets
302                relativity = 1.0) or "mean" (exposure-weighted portfolio
303                mean = 1.0).
304            base_levels: Mapping of feature -> base level value. If omitted
305                for a categorical feature when normalise_to="base_level", the
306                level with the lowest mean SHAP is used as the base, with a
307                warning. Continuous features always use mean normalisation.
308            ci_method: "clt" (CLT approximation, default, fast) or "none" (no
309                CIs). "bootstrap" is not yet implemented.
310            n_bootstrap: Ignored unless ci_method="bootstrap".
311            ci_level: Two-sided confidence level. Default 0.95.
312
313        Returns:
314            Polars DataFrame with columns: feature, level, relativity,
315            lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
316            One row per (feature, level) combination.
317        """
318        self._check_fitted()
319
320        if ci_method == "bootstrap":
321            raise NotImplementedError(
322                "Bootstrap CIs are not yet implemented. Use ci_method='clt'."
323            )
324
325        base_levels = base_levels or {}
326        weights = (
327            self._exposure if self._exposure is not None
328            else np.ones(len(self._X))
329        )
330
331        feature_names = self._X.columns
332        shap_vals = self._shap_values  # type: ignore[assignment]
333
334        parts: list[pl.DataFrame] = []
335
336        for i, feat in enumerate(feature_names):
337            feat_vals = self._X[feat].to_numpy()
338            shap_col = shap_vals[:, i]
339
340            is_categorical = feat in self._categorical_features
341
342            if is_categorical:
343                agg = aggregate_categorical(feat, feat_vals, shap_col, weights)
344            else:
345                agg = aggregate_continuous(feat, feat_vals, shap_col, weights)
346
347            # Normalisation
348            if normalise_to == "base_level" and is_categorical:
349                base = base_levels.get(feat)
350                if base is None:
351                    # Fall back to the level with the smallest mean_shap as
352                    # a sensible default (closest to intercept)
353                    base = agg.sort("mean_shap")["level"][0]
354                    warnings.warn(
355                        f"No base level specified for '{feat}'. "
356                        f"Using '{base}' (lowest mean SHAP) as base.",
357                        UserWarning,
358                        stacklevel=2,
359                    )
360
361                if ci_method == "none":
362                    base_key = str(base)
363                    base_rows = agg.filter(pl.col("level") == base_key)
364                    base_shap = base_rows["mean_shap"][0]
365                    agg = agg.with_columns([
366                        (pl.col("mean_shap") - base_shap).exp().alias("relativity"),
367                        pl.lit(float("nan")).alias("lower_ci"),
368                        pl.lit(float("nan")).alias("upper_ci"),
369                    ])
370                else:
371                    agg = normalise_base_level(agg, base, ci_level=ci_level)
372
373            else:
374                # Mean normalisation for continuous features, or when
375                # normalise_to="mean" for any feature
376                if ci_method == "none":
377                    total_weight = agg["exposure_weight"].sum()
378                    portfolio_mean = float(
379                        (agg["mean_shap"] * agg["exposure_weight"]).sum()
380                        / total_weight
381                    ) if total_weight > 0 else 0.0
382                    agg = agg.with_columns([
383                        (pl.col("mean_shap") - portfolio_mean).exp().alias("relativity"),
384                        pl.lit(float("nan")).alias("lower_ci"),
385                        pl.lit(float("nan")).alias("upper_ci"),
386                    ])
387                else:
388                    agg = normalise_mean(agg, ci_level=ci_level)
389
390            parts.append(agg)
391
392        # Cast the 'level' column to Utf8 in every part before concat.
393        # Categorical features produce level as Utf8; continuous features
394        # produce level as Float64. pl.concat with how="diagonal" cannot
395        # unify mismatched types for the same column name, so we normalise
396        # here rather than requiring callers to pre-cast their feature columns.
397        parts = [
398            p.with_columns(pl.col("level").cast(pl.Utf8))
399            if "level" in p.columns else p
400            for p in parts
401        ]
402
403        result = pl.concat(parts, how="diagonal")
404
405        # Ensure standard column order (wsq_weight is internal, not exported)
406        available = [c for c in _RELATIVITY_COLUMNS if c in result.columns]
407        return result.select(available)

Extract multiplicative relativities from SHAP values.

Arguments:
  • normalise_to: "base_level" (base level for each feature gets relativity = 1.0) or "mean" (exposure-weighted portfolio mean = 1.0).
  • base_levels: Mapping of feature -> base level value. If omitted for a categorical feature when normalise_to="base_level", the level with the lowest mean SHAP is used as the base and a warning is issued. Continuous features always use mean normalisation regardless of this setting.
  • ci_method: "clt" (CLT approximation, default, fast) or "none" (no CIs). "bootstrap" is not yet implemented.
  • n_bootstrap: Ignored unless ci_method="bootstrap".
  • ci_level: Two-sided confidence level. Default 0.95.
Returns:

Polars DataFrame with columns: feature, level, relativity, lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight. One row per (feature, level) combination.
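Base-level normalisation is a subtraction in log space followed by exponentiation; a numpy sketch with illustrative per-level aggregates:

```python
import numpy as np

# Exposure-weighted mean SHAP per level of one categorical feature
# (values illustrative).
levels = ["A", "B", "C"]
mean_shap = np.array([-0.10, 0.05, 0.30])

# Base-level normalisation: subtract the base level's mean SHAP, then
# exponentiate, so the base level lands exactly on relativity 1.0.
base_shap = mean_shap[levels.index("A")]
relativity = np.exp(mean_shap - base_shap)
```

Mean normalisation works the same way, except the reference subtracted is the exposure-weighted portfolio mean of mean_shap rather than one level's value.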

def extract_continuous_curve( self, feature: str, n_points: int = 100, smooth_method: str = 'loess') -> polars.dataframe.frame.DataFrame:
409    def extract_continuous_curve(
410        self,
411        feature: str,
412        n_points: int = 100,
413        smooth_method: str = "loess",
414    ) -> pl.DataFrame:
415        """
416        Smoothed relativity curve for a continuous feature.
417
418        Args:
419            feature: Feature name. Must be in continuous_features.
420            n_points: Number of points in the output curve. Ignored when
421                smooth_method="none", which returns one row per observation.
422            smooth_method: "loess" (locally weighted regression, requires
423                statsmodels), "isotonic" (monotone curve via isotonic
424                regression), or "none" (raw per-observation relativities).
425
426        Returns:
427            Polars DataFrame with columns: feature_value, relativity,
428            lower_ci, upper_ci.
429
430        Raises:
431            ValueError: If feature is not in X or smooth_method is unknown.
432        """
433        self._check_fitted()
434
435        if feature not in self._X.columns:
436            raise ValueError(f"Feature '{feature}' not in X.")
437
438        feat_idx = self._X.columns.index(feature)
439        feat_vals = self._X[feature].to_numpy().astype(float)
440        shap_col = self._shap_values[:, feat_idx]  # type: ignore[index]
441        weights = (
442            self._exposure if self._exposure is not None
443            else np.ones(len(self._X))
444        )
445
446        # Exposure-weighted mean over the actual data distribution
447        portfolio_mean = np.average(shap_col, weights=weights)
448        relativities = np.exp(shap_col - portfolio_mean)
449
450        grid = np.linspace(feat_vals.min(), feat_vals.max(), n_points)
451
452        if smooth_method == "none":
453            order = np.argsort(feat_vals)
454            return pl.DataFrame({
455                "feature_value": feat_vals[order],
456                "relativity": relativities[order],
457                "lower_ci": np.full(len(feat_vals), float("nan")),
458                "upper_ci": np.full(len(feat_vals), float("nan")),
459            })
460
461        elif smooth_method == "isotonic":
462            from sklearn.isotonic import IsotonicRegression
463            ir = IsotonicRegression(out_of_bounds="clip")
464            ir.fit(feat_vals, shap_col, sample_weight=weights)
465            smoothed_shap = ir.predict(grid)
466
467            # P1-4 fix: normalise the smoothed curve so the exposure-weighted
468            # geometric mean of relativities = 1.0.
469            # The smooth is on the data, evaluated on a uniform grid. Subtracting
470            # portfolio_mean (computed on the data distribution) would be correct
471            # only if the grid were distributed like the data — it isn't.
472            # Instead, compute the data-distribution-weighted mean of the smoothed
473            # curve at the original data points, then use that as the reference.
474            smoothed_at_data = ir.predict(feat_vals)
475            weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
476            smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)
477
478            return pl.DataFrame({
479                "feature_value": grid,
480                "relativity": smoothed_rel,
481                "lower_ci": np.full(n_points, float("nan")),
482                "upper_ci": np.full(n_points, float("nan")),
483            })
484
485        elif smooth_method == "loess":
486            try:
487                from statsmodels.nonparametric.smoothers_lowess import lowess
488
489                smoothed_shap = lowess(
490                    shap_col, feat_vals, frac=0.3, it=3,
491                    xvals=grid, is_sorted=False,
492                )
493
494                # P1-4 fix: compute the smoothed values at original data points
495                # so the normalisation is data-distribution-weighted, not
496                # grid-uniform. Use the same lowess parameters.
497                smoothed_at_data = lowess(
498                    shap_col, feat_vals, frac=0.3, it=3,
499                    xvals=feat_vals, is_sorted=False,
500                )
501                weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
502                smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)
503
504                return pl.DataFrame({
505                    "feature_value": grid,
506                    "relativity": smoothed_rel,
507                    "lower_ci": np.full(n_points, float("nan")),
508                    "upper_ci": np.full(n_points, float("nan")),
509                })
510            except ImportError:
511                warnings.warn(
512                    "statsmodels not installed; falling back to smooth_method='none'.",
513                    UserWarning,
514                    stacklevel=2,
515                )
516                return self.extract_continuous_curve(
517                    feature, n_points=n_points, smooth_method="none"
518                )
519
520        else:
521            raise ValueError(
522                f"Unknown smooth_method '{smooth_method}'. "
523                "Choose from: 'loess', 'isotonic', 'none'."
524            )

Smoothed relativity curve for a continuous feature.

Arguments:
  • feature: Feature name. Must be in continuous_features.
  • n_points: Number of points in the output curve. Ignored when smooth_method="none", which returns one row per observation.
  • smooth_method: "loess" (locally weighted regression, requires statsmodels), "isotonic" (monotone curve via isotonic regression), or "none" (raw per-observation relativities).
Returns:

Polars DataFrame with columns: feature_value, relativity, lower_ci, upper_ci.

Raises:
  • ValueError: If feature is not in X or smooth_method is unknown.
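The normalisation step described above (the "P1-4 fix" in the source) can be sketched in numpy. The smoother here is a stand-in (linear interpolation), not the library's LOESS or isotonic fit; the normalisation logic is the same for any smoother:

```python
import numpy as np

# Per-observation SHAP values for one continuous feature, with exposure
# weights (all values illustrative).
feat_vals = np.array([1.0, 2.0, 3.0, 4.0])
shap_col = np.array([-0.2, -0.1, 0.1, 0.3])
weights = np.array([2.0, 1.0, 1.0, 2.0])

def smooth(x: np.ndarray) -> np.ndarray:
    # Stand-in for the fitted smoother evaluated at arbitrary points.
    return np.interp(x, feat_vals, shap_col)

grid = np.linspace(1.0, 4.0, 7)

# Normalise against the smoothed values at the *data* points, exposure
# weighted, rather than the uniform grid, so the exposure-weighted
# geometric mean of relativities over the portfolio is 1.0.
ref = np.average(smooth(feat_vals), weights=weights)
relativity = np.exp(smooth(grid) - ref)
```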
def validate(self) -> dict[str, shap_relativities._validation.CheckResult]:
526    def validate(self) -> dict[str, CheckResult]:
527        """
528        Run diagnostic checks on the SHAP computation.
529
530        Checks performed:
531
532        1. reconstruction: exp(shap.sum(1) + expected_value) should match
533           model predictions within tolerance. Material failure here indicates
534           the explainer was set up incorrectly.
535
536        2. feature_coverage: every feature in X should appear in the SHAP
537           output. Currently always passes given TreeExplainer's API.
538
539        3. sparse_levels: warns if any categorical level has fewer than 30
540           observations. CLT CIs will be unreliable for these levels.
541
542        Returns:
543            Dict with keys "reconstruction", "feature_coverage",
544            "sparse_levels". Each value is a CheckResult(passed, value,
545            message).
546        """
547        self._check_fitted()
548
549        X_pd = _to_pandas(self._X)
550
551        # Get model predictions for reconstruction check
552        preds = None
553        if self._model is not None:
554            try:
555                preds = self._model.predict(X_pd)
556            except Exception:
557                preds = None
558
559        results: dict[str, CheckResult] = {}
560
561        if preds is not None:
562            results["reconstruction"] = check_reconstruction(
563                self._shap_values,  # type: ignore[arg-type]
564                self._expected_value,  # type: ignore[arg-type]
565                preds,
566                tolerance=1e-4,
567            )
568        else:
569            results["reconstruction"] = CheckResult(
570                passed=False,
571                value=float("nan"),
572                message="Could not obtain model predictions for reconstruction check.",
573            )
574
575        feature_names = self._X.columns
576        results["feature_coverage"] = check_feature_coverage(
577            feature_names, feature_names
578        )
579
580        # Check sparse levels for categorical features
581        weights = (
582            self._exposure if self._exposure is not None
583            else np.ones(len(self._X))
584        )
585
586        sparse_parts: list[pl.DataFrame] = []
587        for feat in self._categorical_features:
588            if feat not in self._X.columns:
589                continue
590            feat_idx = self._X.columns.index(feat)
591            agg = aggregate_categorical(
592                feat,
593                self._X[feat].to_numpy(),
594                self._shap_values[:, feat_idx],  # type: ignore[index]
595                weights,
596            )
597            sparse_parts.append(agg)
598
599        if sparse_parts:
600            all_agg = pl.concat(sparse_parts, how="diagonal")
601            results["sparse_levels"] = check_sparse_levels(all_agg)
602        else:
603            results["sparse_levels"] = CheckResult(
604                passed=True, value=0.0,
605                message="No categorical features to check."
606            )
607
608        return results

Run diagnostic checks on the SHAP computation.

Checks performed:

  1. reconstruction: exp(shap.sum(1) + expected_value) should match model predictions within tolerance. Material failure here indicates the explainer was set up incorrectly.

  2. feature_coverage: every feature in X should appear in the SHAP output. Currently always passes given TreeExplainer's API.

  3. sparse_levels: warns if any categorical level has fewer than 30 observations. CLT CIs will be unreliable for these levels.

Returns:

Dict with keys "reconstruction", "feature_coverage", "sparse_levels". Each value is a CheckResult(passed, value, message).
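The reconstruction check amounts to rebuilding predictions from the additive log-space decomposition and comparing them to the model's output. A standalone sketch (reconstruction_ok is illustrative, not the library's check_reconstruction):

```python
import numpy as np

def reconstruction_ok(shap_values, expected_value, preds, tolerance=1e-4):
    # Rebuild predictions from the additive log-space decomposition and
    # compare against the model's own output via max relative error.
    recon = np.exp(shap_values.sum(axis=1) + expected_value)
    max_rel_err = float(np.max(np.abs(recon - preds) / np.maximum(preds, 1e-12)))
    return max_rel_err <= tolerance, max_rel_err

# Toy values that are exactly additive, so the check passes.
shap_values = np.array([[0.10, -0.05], [0.00, 0.20]])
expected_value = -2.0
preds = np.exp(shap_values.sum(axis=1) + expected_value)
ok, err = reconstruction_ok(shap_values, expected_value, preds)
```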

def plot_relativities( self, features: list[str] | None = None, show_ci: bool = True, figsize: tuple[int, int] = (12, 8)) -> None:
610    def plot_relativities(
611        self,
612        features: list[str] | None = None,
613        show_ci: bool = True,
614        figsize: tuple[int, int] = (12, 8),
615    ) -> None:
616        """
617        Plot relativities as bar charts (categorical) or line charts (continuous).
618
619        Args:
620            features: Subset of features to plot. Defaults to all features.
621            show_ci: Whether to show confidence intervals. Default True.
622            figsize: Overall figure size.
623        """
624        self._check_fitted()
625
626        from ._plotting import plot_relativities as _plot
627
628        rels = self.extract_relativities()
629        _plot(
630            rels,
631            categorical_features=self._categorical_features,
632            continuous_features=self._continuous_features,
633            features=features,
634            show_ci=show_ci,
635            figsize=figsize,
636        )

Plot relativities as bar charts (categorical) or line charts (continuous).

Arguments:
  • features: Subset of features to plot. Defaults to all features.
  • show_ci: Whether to show confidence intervals. Default True.
  • figsize: Overall figure size.
def to_dict(self) -> dict[str, typing.Any]:
638    def to_dict(self) -> dict[str, Any]:
639        """
640        Serialisable representation of the fitted object.
641
642        Stores SHAP values, expected value, feature names, and feature
643        classification. Does not store the original model or X DataFrame.
644
645        Returns:
646            Dict suitable for JSON serialisation.
647        """
648        self._check_fitted()
649        return {
650            "shap_values": self._shap_values.tolist(),  # type: ignore[union-attr]
651            "expected_value": self._expected_value,
652            "feature_names": self._X.columns,
653            "categorical_features": self._categorical_features,
654            "continuous_features": self._continuous_features,
655            "X_values": {c: self._X[c].to_list() for c in self._X.columns},
656            "exposure": (
657                self._exposure.tolist()
658                if self._exposure is not None else None
659            ),
660            "annualise_exposure": self._annualise_exposure,
661        }

Serialisable representation of the fitted object.

Stores SHAP values, expected value, feature names, and feature classification. Does not store the original model or X DataFrame.

Returns:

Dict suitable for JSON serialisation.

@classmethod
def from_dict( cls, data: dict[str, typing.Any]) -> SHAPRelativities:
663    @classmethod
664    def from_dict(cls, data: dict[str, Any]) -> "SHAPRelativities":
665        """
666        Reconstruct a fitted SHAPRelativities from to_dict() output.
667
668        The reconstructed object has no model attached, so plotting still
669        works but fit() and validate()'s reconstruction check do not.
670
671        Args:
672            data: Output of to_dict().
673
674        Returns:
675            Fitted SHAPRelativities instance.
676        """
677        # P1-2 fix: use feature_names to control column ordering in X.
678        # Without this, tools that sort JSON object keys (REST APIs, some
679        # pretty-printers) reorder X_values, misaligning columns with the
680        # shap_values matrix columns.
681        feature_names: list[str] = data.get("feature_names", list(data["X_values"].keys()))
682        X = pl.DataFrame({k: data["X_values"][k] for k in feature_names})
683
684        exposure = (
685            np.array(data["exposure"]) if data.get("exposure") is not None
686            else None
687        )
688
689        # Create a minimal instance without a real model
690        instance = cls.__new__(cls)
691        instance._model = None
692        instance._X = X
693        instance._exposure = exposure
694        instance._categorical_features = data.get("categorical_features", [])
695        instance._continuous_features = data.get("continuous_features", [])
696        instance._feature_perturbation = "tree_path_dependent"
697        instance._background_data = None
698        instance._n_background_samples = 1000
699        instance._annualise_exposure = data.get("annualise_exposure", True)
700        instance._shap_values = np.array(data["shap_values"])
701        instance._expected_value = float(data["expected_value"])
702        instance._is_fitted = True
703
704        return instance

Reconstruct a fitted SHAPRelativities from to_dict() output.

The reconstructed object has no model attached: extraction and plotting still work, but fit() cannot be re-run and the reconstruction check in validate() cannot pass without model predictions.

Arguments:
  • data: Output of to_dict().
Returns:

Fitted SHAPRelativities instance.
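The column-ordering hazard that from_dict() guards against (the "P1-2 fix") can be reproduced with the standard json module. The payload below mimics the shape of to_dict() output; values are illustrative:

```python
import json

# Payload shaped like to_dict() output (columns only; values illustrative).
payload = {
    "feature_names": ["ncd_years", "area"],
    "X_values": {"ncd_years": [0, 5], "area": ["A", "B"]},
}

# A serialiser that sorts object keys (common in REST layers and
# pretty-printers) reorders X_values on the wire...
wire = json.dumps(payload, sort_keys=True)
data = json.loads(wire)

# ...so columns must be rebuilt in the recorded feature_names order,
# which is what keeps them aligned with the shap_values matrix columns.
ordered = {k: data["X_values"][k] for k in data["feature_names"]}
```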

def extract_relativities( model: Any, X: Any, exposure: Any = None, categorical_features: list[str] | None = None, base_levels: dict[str, str | float | int] | None = None, ci_method: str = 'clt') -> polars.dataframe.frame.DataFrame:
41def extract_relativities(
42    model: Any,
43    X: Any,
44    exposure: Any = None,
45    categorical_features: list[str] | None = None,
46    base_levels: dict[str, str | float | int] | None = None,
47    ci_method: str = "clt",
48) -> pl.DataFrame:
49    """
50    One-shot extraction of SHAP relativities from a tree model.
51
52    Wraps SHAPRelativities.fit() and extract_relativities() for cases where
53    you don't need the intermediate object.
54
55    Args:
56        model: Trained tree model with a log-link objective (Poisson,
57            Tweedie, or Gamma). CatBoost is the recommended choice: it handles
58            categorical features natively without encoding.
59        X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is
60            preferred; pandas is accepted and converted internally.
61        exposure: Earned policy years. If None, all observations are equally
62            weighted.
63        categorical_features: Features to aggregate by level. If None, all
64            non-numeric columns are treated as categorical.
65        base_levels: Base level for each categorical feature (gets
66            relativity = 1.0).
67        ci_method: "clt" (default) or "none".
68
69    Returns:
70        Polars DataFrame with columns: feature, level, relativity, lower_ci,
71        upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
72    """
73    sr = SHAPRelativities(model, X, exposure, categorical_features)
74    sr.fit()
75    return sr.extract_relativities(base_levels=base_levels, ci_method=ci_method)

One-shot extraction of SHAP relativities from a tree model.

Wraps SHAPRelativities.fit() and extract_relativities() for cases where you don't need the intermediate object.

Arguments:
  • model: Trained tree model with a log-link objective (Poisson, Tweedie, or Gamma). CatBoost is the recommended choice: it handles categorical features natively without encoding.
  • X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is preferred; pandas is accepted and converted internally.
  • exposure: Earned policy years. If None, all observations are equally weighted.
  • categorical_features: Features to aggregate by level. If None, all non-numeric columns are treated as categorical.
  • base_levels: Base level for each categorical feature (gets relativity = 1.0).
  • ci_method: "clt" (default) or "none".
Returns:

Polars DataFrame with columns: feature, level, relativity, lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight.