
shap_relativities

SHAP Relativities - actuarial-grade multiplicative relativities from tree models.

Converts a trained GBM's SHAP values into the same format as GLM exp(beta) relativities: a table of (feature, level, relativity) triples where the base level is 1.0 and relativities multiply together to give the model's expected prediction.
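
In a log-link model the SHAP decomposition is additive in log space, so exponentiating turns it into a multiplicative one: prediction = exp(expected_value + sum of SHAP values) = baseline times the product of per-feature relativities. A minimal numeric sketch of that identity (all values invented for illustration):

>>> import numpy as np
>>> expected_value = -2.0                      # explainer baseline, log space
>>> shap_row = np.array([0.10, -0.25, 0.05])   # one policy's per-feature SHAP
>>> baseline = np.exp(expected_value)
>>> relativities = np.exp(shap_row)            # per-feature multipliers
>>> assert np.isclose(np.exp(expected_value + shap_row.sum()),
...                   baseline * relativities.prod())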

Typical usage

>>> from shap_relativities import SHAPRelativities
>>> sr = SHAPRelativities(model, X, exposure=df["exposure"],
...                       categorical_features=["area", "ncd_years"])
>>> sr.fit()
>>> rels = sr.extract_relativities(
...     normalise_to="base_level",
...     base_levels={"area": "A", "ncd_years": 0},
... )

For statistically valid SHAP importance CIs (v0.5.0+):

>>> from shap_relativities import SHAPInference
>>> si = SHAPInference(shap_values, y, feature_names=["age", "ncd", "area"])
>>> si.fit()
>>> si.importance_table()

Or use the convenience wrapper for one-liners:

>>> from shap_relativities import extract_relativities
>>> rels = extract_relativities(model, X, exposure=df["exposure"],
...                             categorical_features=["area"])
from __future__ import annotations

from typing import Any

import polars as pl

from ._core import SHAPRelativities
from ._inference import SHAPInference

__all__ = ["SHAPRelativities", "SHAPInference", "extract_relativities"]

from importlib.metadata import version, PackageNotFoundError

try:
    __version__ = version("shap-relativities")
except PackageNotFoundError:
    __version__ = "0.0.0"  # not installed


def extract_relativities(
    model: Any,
    X: Any,
    exposure: Any = None,
    categorical_features: list[str] | None = None,
    base_levels: dict[str, str | float | int] | None = None,
    ci_method: str = "clt",
) -> pl.DataFrame:
    """
    One-shot extraction of SHAP relativities from a tree model.

    Wraps SHAPRelativities.fit() and extract_relativities() for cases where
    you don't need the intermediate object.

    Args:
        model: Trained CatBoost model with a log-link objective (Poisson,
            Tweedie, or Gamma). CatBoost is the recommended choice - it handles
            categorical features natively without encoding.
        X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is
            preferred; pandas is accepted and converted internally.
        exposure: Earned policy years. If None, all observations are equally
            weighted.
        categorical_features: Features to aggregate by level. If None, all
            non-numeric columns are treated as categorical.
        base_levels: Base level for each categorical feature (gets
            relativity = 1.0).
        ci_method: "clt" (default) or "none". "bootstrap" is not yet
            implemented.

    Returns:
        Polars DataFrame with columns: feature, level, relativity, lower_ci,
        upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
    """
    sr = SHAPRelativities(model, X, exposure, categorical_features)
    sr.fit()
    return sr.extract_relativities(base_levels=base_levels, ci_method=ci_method)
class SHAPRelativities:
    """
    Extract multiplicative rating relativities from a tree model via SHAP.

    Workflow::

        sr = SHAPRelativities(
            model=catboost_model,
            X=df.select(["area", "ncd_years", "has_convictions"]),
            exposure=df["exposure"],
            categorical_features=["area", "ncd_years"],
        )
        sr.fit()
        rels = sr.extract_relativities(
            normalise_to="base_level",
            base_levels={"area": "A", "ncd_years": 0},
        )

    Args:
        model: A trained CatBoost model. Must use a log-link objective (Poisson,
            Tweedie, Gamma). CatBoost is the recommended default - it handles
            categoricals natively.
        X: Feature matrix. Use training data for in-sample relativities, or a
            representative holdout sample for out-of-sample. Polars DataFrames
            are preferred; pandas DataFrames are accepted and converted
            internally.
        exposure: Earned policy years (or other volume measure). Used as
            observation weights throughout. If None, all observations are
            weighted equally.
        categorical_features: Features to aggregate by level (bar-chart style).
            If None, all non-numeric columns are treated as categorical.
        continuous_features: Features to leave as per-observation points.
            If None, all numeric columns are treated as continuous.
        feature_perturbation: "tree_path_dependent" (default, fast, no
            background data needed) or "interventional" (corrects for feature
            correlation, needs background_data).
        background_data: Required only if feature_perturbation="interventional".
        n_background_samples: Number of background samples for interventional
            SHAP. Default 1000.
        annualise_exposure: If True and exposure is provided, subtract mean
            log(exposure) from the expected_value to give an annualised
            baseline. Default True.
        verbose: If True (default), show a progress indicator during fit().
            Set to False to suppress all output - useful in batch jobs,
            notebooks with tqdm already present, or when calling fit()
            many times. If tqdm is not installed, falls back to a plain
            elapsed-time message.
    """

    def __init__(
        self,
        model: Any,
        X: Any,
        exposure: Any = None,
        categorical_features: list[str] | None = None,
        continuous_features: list[str] | None = None,
        background_data: Any = None,
        feature_perturbation: str = "tree_path_dependent",
        n_background_samples: int = 1000,
        annualise_exposure: bool = True,
        verbose: bool = True,
    ) -> None:
        if not _SHAP_AVAILABLE:
            raise ImportError(
                "shap is required for SHAPRelativities. "
                "Install it with: uv add 'shap-relativities[ml]'"
            )

        self._model = model
        self._X: pl.DataFrame = _to_polars(X)
        self._background_data = (
            _to_polars(background_data) if background_data is not None else None
        )

        # Normalise exposure to a numpy array
        if exposure is None:
            self._exposure: np.ndarray | None = None
        elif isinstance(exposure, np.ndarray):
            self._exposure = exposure
        elif isinstance(exposure, pl.Series):
            self._exposure = exposure.to_numpy()
        else:
            # pd.Series or similar
            self._exposure = np.asarray(exposure)

        # Validate exposure length matches X
        if self._exposure is not None and len(self._exposure) != len(self._X):
            raise ValueError(
                f"exposure length ({len(self._exposure)}) does not match "
                f"X length ({len(self._X)}). Both must have the same number of rows."
            )

        self._feature_perturbation = feature_perturbation
        self._n_background_samples = n_background_samples
        self._annualise_exposure = annualise_exposure
        self._verbose = verbose

        # Classify features
        self._categorical_features = categorical_features or self._infer_categorical()
        self._continuous_features = continuous_features or self._infer_continuous()

        # Populated by fit()
        self._shap_values: np.ndarray | None = None
        self._expected_value: float | None = None
        self._is_fitted: bool = False

    def _infer_categorical(self) -> list[str]:
        numeric_types = (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64,
            pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
            pl.Float32, pl.Float64,
        )
        return [
            c for c in self._X.columns
            if not isinstance(self._X[c].dtype, numeric_types)
        ]

    def _infer_continuous(self) -> list[str]:
        numeric_types = (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64,
            pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
            pl.Float32, pl.Float64,
        )
        return [
            c for c in self._X.columns
            if isinstance(self._X[c].dtype, numeric_types)
            and c not in (self._categorical_features or [])
        ]

    def fit(self, verbose: bool | None = None) -> "SHAPRelativities":
        """
        Compute SHAP values for all features in X.

        Must be called before extract_relativities(). Calling fit() again
        recomputes SHAP values (e.g. after changing X or the background data).

        The feature matrix is converted to pandas internally for shap's
        TreeExplainer. The conversion is a necessary bridge - shap requires
        pandas for column name handling.

        For large datasets (50k+ policies) SHAP computation can take 30-60
        seconds. A progress indicator is shown by default. To suppress it,
        pass verbose=False here or set verbose=False in __init__.

        Args:
            verbose: Override the instance-level verbose setting for this call
                only. If None (default), the instance-level setting is used.

        Returns:
            Self, for method chaining.
        """
        show_progress = self._verbose if verbose is None else verbose

        # Convert to pandas for shap - unavoidable bridge
        X_pd = _to_pandas(self._X)

        bg_data = None
        if self._feature_perturbation == "interventional":
            if self._background_data is not None:
                bg_data = _to_pandas(self._background_data)
            else:
                n_bg = min(self._n_background_samples, len(X_pd))
                bg_data = shap.sample(X_pd, n_bg)

        explainer = shap.TreeExplainer(
            self._model,
            data=bg_data,
            feature_perturbation=self._feature_perturbation,
            model_output="raw",
        )

        def _compute_shap():
            return explainer.shap_values(X_pd)

        raw = _run_with_spinner(_compute_shap, n_obs=len(X_pd), verbose=show_progress)

        # Some models return a list when there is a single output
        if isinstance(raw, list):
            if len(raw) == 1:
                raw = raw[0]
            else:
                raise ValueError(
                    f"Model has {len(raw)} outputs. SHAPRelativities supports "
                    "single-output models only."
                )

        self._shap_values = raw

        ev = explainer.expected_value
        if isinstance(ev, (list, np.ndarray)):
            ev = float(ev[0])
        self._expected_value = float(ev)

        self._is_fitted = True
        return self

    def _check_fitted(self) -> None:
        if not self._is_fitted:
            raise RuntimeError("Call fit() before using this method.")

    def shap_values(self) -> np.ndarray:
        """
        Raw SHAP values, shape (n_obs, n_features), in log space.

        Returns:
            Array of shape (n_obs, n_features).
        """
        self._check_fitted()
        return self._shap_values  # type: ignore[return-value]

    def baseline(self) -> float:
        """
        exp(expected_value) - the base rate in prediction space.

        If annualise_exposure=True and exposure was provided, this is adjusted
        for the average log-exposure offset so it represents an annualised rate.

        Returns:
            Base rate as a float.
        """
        self._check_fitted()
        ev = self._expected_value  # type: ignore[assignment]

        if self._annualise_exposure and self._exposure is not None:
            mean_log_exp = float(np.mean(np.log(np.clip(self._exposure, 1e-9, None))))
            ev = ev - mean_log_exp

        return float(np.exp(ev))

    def extract_relativities(
        self,
        normalise_to: str = "base_level",
        base_levels: dict[str, str | float | int] | None = None,
        ci_method: str = "clt",
        n_bootstrap: int = 200,
        ci_level: float = 0.95,
    ) -> pl.DataFrame:
        """
        Extract multiplicative relativities from SHAP values.

        Args:
            normalise_to: "base_level" (base level for each feature gets
                relativity = 1.0) or "mean" (exposure-weighted portfolio
                mean = 1.0).
            base_levels: Mapping of feature -> base level value. Required for
                categorical features when normalise_to="base_level". Continuous
                features automatically use mean normalisation regardless of
                this setting.
            ci_method: "clt" (CLT approximation, default, fast) or "none" (no
                CIs). "bootstrap" is not yet implemented.
            n_bootstrap: Ignored unless ci_method="bootstrap".
            ci_level: Two-sided confidence level. Default 0.95.

        Returns:
            Polars DataFrame with columns: feature, level, relativity,
            lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
            One row per (feature, level) combination.
        """
        self._check_fitted()

        _VALID_CI_METHODS = {"clt", "bootstrap", "none"}
        if ci_method not in _VALID_CI_METHODS:
            raise ValueError(
                f"Unknown ci_method {ci_method!r}. "
                f"Valid options are: {sorted(_VALID_CI_METHODS)}."
            )

        if ci_method == "bootstrap":
            raise NotImplementedError(
                "Bootstrap CIs are not yet implemented. Use ci_method='clt'."
            )

        base_levels = base_levels or {}
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        feature_names = self._X.columns
        shap_vals = self._shap_values  # type: ignore[assignment]

        parts: list[pl.DataFrame] = []

        for i, feat in enumerate(feature_names):
            feat_vals = self._X[feat].to_numpy()
            shap_col = shap_vals[:, i]

            is_categorical = feat in self._categorical_features

            if is_categorical:
                agg = aggregate_categorical(feat, feat_vals, shap_col, weights)
            else:
                agg = aggregate_continuous(feat, feat_vals, shap_col, weights)

            # Normalisation
            if normalise_to == "base_level" and is_categorical:
                base = base_levels.get(feat)
                if base is None:
                    # Fall back to the level with the smallest mean_shap as
                    # a sensible default (closest to intercept)
                    base = agg.sort("mean_shap")["level"][0]
                    warnings.warn(
                        f"No base level specified for '{feat}'. "
                        f"Using '{base}' (lowest mean SHAP) as base.",
                        UserWarning,
                        stacklevel=2,
                    )

                if ci_method == "none":
                    base_key = str(base)
                    base_rows = agg.filter(pl.col("level") == base_key)
                    base_shap = base_rows["mean_shap"][0]
                    agg = agg.with_columns([
                        (pl.col("mean_shap") - base_shap).exp().alias("relativity"),
                        pl.lit(float("nan")).alias("lower_ci"),
                        pl.lit(float("nan")).alias("upper_ci"),
                    ])
                else:
                    agg = normalise_base_level(agg, base, ci_level=ci_level)

            else:
                # Mean normalisation for continuous features, or when
                # normalise_to="mean" for any feature
                if ci_method == "none":
                    total_weight = agg["exposure_weight"].sum()
                    portfolio_mean = float(
                        (agg["mean_shap"] * agg["exposure_weight"]).sum()
                        / total_weight
                    ) if total_weight > 0 else 0.0
                    agg = agg.with_columns([
                        (pl.col("mean_shap") - portfolio_mean).exp().alias("relativity"),
                        pl.lit(float("nan")).alias("lower_ci"),
                        pl.lit(float("nan")).alias("upper_ci"),
                    ])
                else:
                    agg = normalise_mean(agg, ci_level=ci_level)

            parts.append(agg)

        # Cast the 'level' column to Utf8 in every part before concat.
        # Categorical features produce level as Utf8; continuous features
        # produce level as Float64. pl.concat with how="diagonal" cannot
        # unify mismatched types for the same column name, so we normalise
        # here rather than requiring callers to pre-cast their feature columns.
        parts = [
            p.with_columns(pl.col("level").cast(pl.String))
            if "level" in p.columns else p
            for p in parts
        ]

        result = pl.concat(parts, how="diagonal")

        # Ensure standard column order (wsq_weight is internal, not exported)
        available = [c for c in _RELATIVITY_COLUMNS if c in result.columns]
        return result.select(available)

    def extract_continuous_curve(
        self,
        feature: str,
        n_points: int = 100,
        smooth_method: str = "loess",
    ) -> pl.DataFrame:
        """
        Smoothed relativity curve for a continuous feature.

        Args:
            feature: Feature name. Must be in continuous_features.
            n_points: Number of points in the output curve (not the input
                data).
            smooth_method: "loess" (locally weighted regression, requires
                statsmodels), "isotonic" (monotone curve via isotonic
                regression), or "none" (raw per-observation relativities).

        Returns:
            Polars DataFrame with columns: feature_value, relativity,
            lower_ci, upper_ci.

        Raises:
            ValueError: If feature is not in X or smooth_method is unknown.
        """
        self._check_fitted()

        if feature not in self._X.columns:
            raise ValueError(f"Feature '{feature}' not in X.")

        feat_idx = self._X.columns.index(feature)
        feat_vals = self._X[feature].to_numpy().astype(float)
        shap_col = self._shap_values[:, feat_idx]  # type: ignore[index]
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        # Exposure-weighted mean over the actual data distribution
        portfolio_mean = np.average(shap_col, weights=weights)
        relativities = np.exp(shap_col - portfolio_mean)

        grid = np.linspace(feat_vals.min(), feat_vals.max(), n_points)

        if smooth_method == "none":
            order = np.argsort(feat_vals)
            return pl.DataFrame({
                "feature_value": feat_vals[order],
                "relativity": relativities[order],
                "lower_ci": np.full(len(feat_vals), float("nan")),
                "upper_ci": np.full(len(feat_vals), float("nan")),
            })

        elif smooth_method == "isotonic":
            from sklearn.isotonic import IsotonicRegression
            ir = IsotonicRegression(out_of_bounds="clip")
            ir.fit(feat_vals, shap_col, sample_weight=weights)
            smoothed_shap = ir.predict(grid)

            # P1-4 fix: normalise the smoothed curve so the exposure-weighted
            # geometric mean of relativities = 1.0.
            # The smooth is on the data, evaluated on a uniform grid. Subtracting
            # portfolio_mean (computed on the data distribution) would be correct
            # only if the grid were distributed like the data - it isn't.
            # Instead, compute the data-distribution-weighted mean of the smoothed
            # curve at the original data points, then use that as the reference.
            smoothed_at_data = ir.predict(feat_vals)
            weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
            smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)

            return pl.DataFrame({
                "feature_value": grid,
                "relativity": smoothed_rel,
                "lower_ci": np.full(n_points, float("nan")),
                "upper_ci": np.full(n_points, float("nan")),
            })

        elif smooth_method == "loess":
            try:
                from statsmodels.nonparametric.smoothers_lowess import lowess

                smoothed_shap = lowess(
                    shap_col, feat_vals, frac=0.3, it=3,
                    xvals=grid, is_sorted=False,
                )

                # P1-4 fix: compute the smoothed values at original data points
                # so the normalisation is data-distribution-weighted, not
                # grid-uniform. Use the same lowess parameters.
                smoothed_at_data = lowess(
                    shap_col, feat_vals, frac=0.3, it=3,
                    xvals=feat_vals, is_sorted=False,
                )
                weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
                smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)

                return pl.DataFrame({
                    "feature_value": grid,
                    "relativity": smoothed_rel,
                    "lower_ci": np.full(n_points, float("nan")),
                    "upper_ci": np.full(n_points, float("nan")),
                })
            except ImportError:
                warnings.warn(
                    "statsmodels not installed; falling back to smooth_method='none'.",
                    UserWarning,
                    stacklevel=2,
                )
                return self.extract_continuous_curve(
                    feature, n_points=n_points, smooth_method="none"
                )

        else:
            raise ValueError(
                f"Unknown smooth_method '{smooth_method}'. "
                "Choose from: 'loess', 'isotonic', 'none'."
            )

    def validate(self) -> dict[str, CheckResult]:
        """
        Run diagnostic checks on the SHAP computation.

        Checks performed:

        1. reconstruction: exp(shap.sum(1) + expected_value) should match
           model predictions within tolerance. Material failure here indicates
           the explainer was set up incorrectly.

        2. feature_coverage: every feature in X should appear in the SHAP
           output. Currently always passes given TreeExplainer's API.

        3. sparse_levels: warns if any categorical level has fewer than 30
           observations. CLT CIs will be unreliable for these levels.

        Returns:
            Dict with keys "reconstruction", "feature_coverage",
            "sparse_levels". Each value is a CheckResult(passed, value,
            message).
        """
        self._check_fitted()

        X_pd = _to_pandas(self._X)

        # Get model predictions for reconstruction check
        preds = None
        if self._model is not None:
            try:
                preds = self._model.predict(X_pd)
            except Exception:
                preds = None

        results: dict[str, CheckResult] = {}

        if preds is not None:
            results["reconstruction"] = check_reconstruction(
                self._shap_values,  # type: ignore[arg-type]
                self._expected_value,  # type: ignore[arg-type]
                preds,
                tolerance=1e-4,
            )
        else:
            results["reconstruction"] = CheckResult(
                passed=False,
                value=float("nan"),
                message="Could not obtain model predictions for reconstruction check.",
            )

        feature_names = self._X.columns
        results["feature_coverage"] = check_feature_coverage(
            feature_names, feature_names
        )

        # Check sparse levels for categorical features
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        sparse_parts: list[pl.DataFrame] = []
        for feat in self._categorical_features:
            if feat not in self._X.columns:
                continue
            feat_idx = self._X.columns.index(feat)
            agg = aggregate_categorical(
                feat,
                self._X[feat].to_numpy(),
                self._shap_values[:, feat_idx],  # type: ignore[index]
                weights,
            )
            sparse_parts.append(agg)

        if sparse_parts:
            all_agg = pl.concat(sparse_parts, how="diagonal")
            results["sparse_levels"] = check_sparse_levels(all_agg)
        else:
            results["sparse_levels"] = CheckResult(
                passed=True, value=0.0,
                message="No categorical features to check."
            )

        return results

    def plot_relativities(
        self,
        features: list[str] | None = None,
        show_ci: bool = True,
        figsize: tuple[int, int] = (12, 8),
    ) -> None:
        """
        Plot relativities as bar charts (categorical) or line charts (continuous).

        Args:
            features: Subset of features to plot. Defaults to all features.
            show_ci: Whether to show confidence intervals. Default True.
            figsize: Overall figure size.
        """
        self._check_fitted()

        from ._plotting import plot_relativities as _plot

        rels = self.extract_relativities()
        _plot(
            rels,
            categorical_features=self._categorical_features,
            continuous_features=self._continuous_features,
            features=features,
            show_ci=show_ci,
            figsize=figsize,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialisable representation of the fitted object.

        Stores SHAP values, expected value, feature names, and feature
        classification. Does not store the original model or X DataFrame.

        Returns:
            Dict suitable for JSON serialisation.
        """
        self._check_fitted()
        return {
            "shap_values": self._shap_values.tolist(),  # type: ignore[union-attr]
            "expected_value": self._expected_value,
            "feature_names": self._X.columns,
            "categorical_features": self._categorical_features,
            "continuous_features": self._continuous_features,
            "X_values": {c: self._X[c].to_list() for c in self._X.columns},
            "exposure": (
                self._exposure.tolist()
                if self._exposure is not None else None
            ),
            "annualise_exposure": self._annualise_exposure,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SHAPRelativities":
        """
        Reconstruct a fitted SHAPRelativities from to_dict() output.

        The reconstructed object has no model attached, so validate() and
        plot_relativities() still work but fit() cannot be re-run.

        Args:
            data: Output of to_dict().

        Returns:
            Fitted SHAPRelativities instance.
        """
        # P1-2 fix: use feature_names to control column ordering in X.
        # Without this, tools that sort JSON object keys (REST APIs, some
        # pretty-printers) reorder X_values, misaligning columns with the
        # shap_values matrix columns.
        feature_names: list[str] = data.get("feature_names", list(data["X_values"].keys()))
        X = pl.DataFrame({k: data["X_values"][k] for k in feature_names})

        exposure = (
            np.array(data["exposure"]) if data.get("exposure") is not None
            else None
        )

        # Create a minimal instance without a real model
        instance = cls.__new__(cls)
        instance._model = None
        instance._X = X
        instance._exposure = exposure
        instance._categorical_features = data.get("categorical_features", [])
        instance._continuous_features = data.get("continuous_features", [])
        instance._feature_perturbation = "tree_path_dependent"
        instance._background_data = None
        instance._n_background_samples = 1000
        instance._annualise_exposure = data.get("annualise_exposure", True)
        instance._verbose = False  # no-op for deserialized instances
        instance._shap_values = np.array(data["shap_values"])
        instance._expected_value = float(data["expected_value"])
        instance._is_fitted = True

        return instance

Extract multiplicative rating relativities from a tree model via SHAP.

Workflow:

sr = SHAPRelativities(
    model=catboost_model,
    X=df.select(["area", "ncd_years", "has_convictions"]),
    exposure=df["exposure"],
    categorical_features=["area", "ncd_years"],
)
sr.fit()
rels = sr.extract_relativities(
    normalise_to="base_level",
    base_levels={"area": "A", "ncd_years": 0},
)
Arguments:
  • model: A trained CatBoost model. Must use a log-link objective (Poisson, Tweedie, Gamma). CatBoost is the recommended default - it handles categoricals natively.
  • X: Feature matrix. Use training data for in-sample relativities, or a representative holdout sample for out-of-sample. Polars DataFrames are preferred; pandas DataFrames are accepted and converted internally.
  • exposure: Earned policy years (or other volume measure). Used as observation weights throughout. If None, all observations are weighted equally.
  • categorical_features: Features to aggregate by level (bar-chart style). If None, all non-numeric columns are treated as categorical.
  • continuous_features: Features to leave as per-observation points. If None, all numeric columns are treated as continuous.
  • feature_perturbation: "tree_path_dependent" (default, fast, no background data needed) or "interventional" (corrects for feature correlation, needs background_data); a hypothetical interventional setup is sketched after this list.
  • background_data: Required only if feature_perturbation="interventional".
  • n_background_samples: Number of background samples for interventional SHAP. Default 1000.
  • annualise_exposure: If True and exposure is provided, subtract mean log(exposure) from the expected_value to give an annualised baseline. Default True.
  • verbose: If True (default), show a progress indicator during fit(). Set to False to suppress all output — useful in batch jobs, notebooks with tqdm already present, or when calling fit() many times. If tqdm is not installed, falls back to a plain elapsed-time message.
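
A hypothetical interventional setup (catboost_model, df, and the column names are placeholders; only the keyword arguments are part of the API):

background = df.sample(n=1000, seed=42)   # explicit Polars background sample

sr = SHAPRelativities(
    model=catboost_model,
    X=df.select(["area", "ncd_years", "driver_age"]),
    exposure=df["exposure"],
    categorical_features=["area", "ncd_years"],
    feature_perturbation="interventional",
    background_data=background,
)
sr.fit()
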
SHAPRelativities(model: Any, X: Any, exposure: Any = None, categorical_features: list[str] | None = None, continuous_features: list[str] | None = None, background_data: Any = None, feature_perturbation: str = 'tree_path_dependent', n_background_samples: int = 1000, annualise_exposure: bool = True, verbose: bool = True)
def fit(self, verbose: bool | None = None) -> SHAPRelativities:

Compute SHAP values for all features in X.

Must be called before extract_relativities(). Calling fit() again recomputes SHAP values (e.g. after changing X or the background data).

The feature matrix is converted to pandas internally for shap's TreeExplainer. The conversion is a necessary bridge - shap requires pandas for column name handling.

For large datasets (50k+ policies) SHAP computation can take 30-60 seconds. A progress indicator is shown by default. To suppress it, pass verbose=False here or set verbose=False in __init__.

Arguments:
  • verbose: Override the instance-level verbose setting for this call only. If None (default), the instance-level setting is used.
Returns:

Self, for method chaining.
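
For example, to silence a single recompute without touching the instance-level setting:

sr.fit()                 # uses the verbose setting from __init__
sr.fit(verbose=False)    # one-off quiet recompute, e.g. inside a batch loop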

def shap_values(self) -> numpy.ndarray:

Raw SHAP values, shape (n_obs, n_features), in log space.

Returns:

Array of shape (n_obs, n_features).

def baseline(self) -> float:

exp(expected_value) - the base rate in prediction space.

If annualise_exposure=True and exposure was provided, this is adjusted for the average log-exposure offset so it represents an annualised rate.

Returns:

Base rate as a float.
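
A worked example with invented numbers: if every policy earned half a year of exposure, the annualised base rate scales the raw exp(expected_value) up by exp(-log 0.5):

>>> import numpy as np
>>> expected_value = -2.0                   # log-space baseline
>>> exposure = np.array([0.5, 0.5, 0.5])    # each policy earned half a year
>>> mean_log_exp = np.log(exposure).mean()  # log(0.5), about -0.693
>>> round(float(np.exp(expected_value - mean_log_exp)), 3)
0.271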

def extract_relativities(self, normalise_to: str = 'base_level', base_levels: dict[str, str | float | int] | None = None, ci_method: str = 'clt', n_bootstrap: int = 200, ci_level: float = 0.95) -> polars.dataframe.frame.DataFrame:

Extract multiplicative relativities from SHAP values.

Arguments:
  • normalise_to: "base_level" (base level for each feature gets relativity = 1.0) or "mean" (exposure-weighted portfolio mean = 1.0).
  • base_levels: Mapping of feature -> base level value. Required for categorical features when normalise_to="base_level". Continuous features automatically use mean normalisation regardless of this setting.
  • ci_method: "clt" (CLT approximation, default, fast) or "none" (no CIs). "bootstrap" is not yet implemented.
  • n_bootstrap: Ignored unless ci_method="bootstrap".
  • ci_level: Two-sided confidence level. Default 0.95.
Returns:

Polars DataFrame with columns: feature, level, relativity, lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight. One row per (feature, level) combination.
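
Under base-level normalisation each relativity is exp(mean_shap_level - mean_shap_base), so the base level lands exactly at 1.0. A toy check with invented per-level SHAP means:

>>> import numpy as np
>>> mean_shap = {"A": -0.12, "B": 0.03, "C": 0.20}   # invented values
>>> base = mean_shap["A"]
>>> {lvl: round(float(np.exp(v - base)), 3) for lvl, v in mean_shap.items()}
{'A': 1.0, 'B': 1.162, 'C': 1.377}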

def extract_continuous_curve(self, feature: str, n_points: int = 100, smooth_method: str = 'loess') -> polars.dataframe.frame.DataFrame:

Smoothed relativity curve for a continuous feature.

Arguments:
  • feature: Feature name. Must be in continuous_features.
  • n_points: Number of points in the output curve (not the input data).
  • smooth_method: "loess" (locally weighted regression, requires statsmodels), "isotonic" (monotone curve via isotonic regression), or "none" (raw per-observation relativities).
Returns:

Polars DataFrame with columns: feature_value, relativity, lower_ci, upper_ci.

Raises:
  • ValueError: If feature is not in X or smooth_method is unknown.
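
A typical call, assuming a fitted instance sr with an illustrative continuous column "driver_age":

curve = sr.extract_continuous_curve("driver_age", n_points=50, smooth_method="isotonic")
# feature_value is a uniform grid over the observed range; relativity is
# normalised so the exposure-weighted geometric mean over the data is 1.0.
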
def validate(self) -> dict[str, shap_relativities._validation.CheckResult]:

Run diagnostic checks on the SHAP computation.

Checks performed:

  1. reconstruction: exp(shap.sum(1) + expected_value) should match model predictions within tolerance. Material failure here indicates the explainer was set up incorrectly.

  2. feature_coverage: every feature in X should appear in the SHAP output. Currently always passes given TreeExplainer's API.

  3. sparse_levels: warns if any categorical level has fewer than 30 observations. CLT CIs will be unreliable for these levels.

Returns:

Dict with keys "reconstruction", "feature_coverage", "sparse_levels". Each value is a CheckResult(passed, value, message).
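
For example (a sketch; sr is a fitted SHAPRelativities, and the CheckResult fields follow the description above):

>>> checks = sr.validate()
>>> failing = [name for name, r in checks.items() if not r.passed]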

def plot_relativities( self, features: list[str] | None = None, show_ci: bool = True, figsize: tuple[int, int] = (12, 8)) -> None:
711    def plot_relativities(
712        self,
713        features: list[str] | None = None,
714        show_ci: bool = True,
715        figsize: tuple[int, int] = (12, 8),
716    ) -> None:
717        """
718        Plot relativities as bar charts (categorical) or line charts (continuous).
719
720        Args:
721            features: Subset of features to plot. Defaults to all features.
722            show_ci: Whether to show confidence intervals. Default True.
723            figsize: Overall figure size.
724        """
725        self._check_fitted()
726
727        from ._plotting import plot_relativities as _plot
728
729        rels = self.extract_relativities()
730        _plot(
731            rels,
732            categorical_features=self._categorical_features,
733            continuous_features=self._continuous_features,
734            features=features,
735            show_ci=show_ci,
736            figsize=figsize,
737        )

Plot relativities as bar charts (categorical) or line charts (continuous).

Arguments:
  • features: Subset of features to plot. Defaults to all features.
  • show_ci: Whether to show confidence intervals. Default True.
  • figsize: Overall figure size.
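
A one-line sketch (assumes a fitted sr; the feature name follows the module-level example):

>>> sr.plot_relativities(features=["area"], figsize=(10, 6))
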
def to_dict(self) -> dict[str, typing.Any]:
739    def to_dict(self) -> dict[str, Any]:
740        """
741        Serialisable representation of the fitted object.
742
743        Stores SHAP values, expected value, feature names, and feature
744        classification. Does not store the original model or X DataFrame.
745
746        Returns:
747            Dict suitable for JSON serialisation.
748        """
749        self._check_fitted()
750        return {
751            "shap_values": self._shap_values.tolist(),  # type: ignore[union-attr]
752            "expected_value": self._expected_value,
753            "feature_names": self._X.columns,
754            "categorical_features": self._categorical_features,
755            "continuous_features": self._continuous_features,
756            "X_values": {c: self._X[c].to_list() for c in self._X.columns},
757            "exposure": (
758                self._exposure.tolist()
759                if self._exposure is not None else None
760            ),
761            "annualise_exposure": self._annualise_exposure,
762        }

Serialisable representation of the fitted object.

Stores SHAP values, expected value, feature names, and feature classification. Does not store the original model or X DataFrame.

Returns:

Dict suitable for JSON serialisation.

@classmethod
def from_dict( cls, data: dict[str, typing.Any]) -> SHAPRelativities:
764    @classmethod
765    def from_dict(cls, data: dict[str, Any]) -> "SHAPRelativities":
766        """
767        Reconstruct a fitted SHAPRelativities from to_dict() output.
768
769        The reconstructed object has no model attached, so validate() and
770        plot_relativities() still work but fit() cannot be re-run.
771
772        Args:
773            data: Output of to_dict().
774
775        Returns:
776            Fitted SHAPRelativities instance.
777        """
778        # P1-2 fix: use feature_names to control column ordering in X.
779        # Without this, tools that sort JSON object keys (REST APIs, some
780        # pretty-printers) reorder X_values, misaligning columns with the
781        # shap_values matrix columns.
782        feature_names: list[str] = data.get("feature_names", list(data["X_values"].keys()))
783        X = pl.DataFrame({k: data["X_values"][k] for k in feature_names})
784
785        exposure = (
786            np.array(data["exposure"]) if data.get("exposure") is not None
787            else None
788        )
789
790        # Create a minimal instance without a real model
791        instance = cls.__new__(cls)
792        instance._model = None
793        instance._X = X
794        instance._exposure = exposure
795        instance._categorical_features = data.get("categorical_features", [])
796        instance._continuous_features = data.get("continuous_features", [])
797        instance._feature_perturbation = "tree_path_dependent"
798        instance._background_data = None
799        instance._n_background_samples = 1000
800        instance._annualise_exposure = data.get("annualise_exposure", True)
801        instance._verbose = False  # no-op for deserialised instances
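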
802        instance._shap_values = np.array(data["shap_values"])
803        instance._expected_value = float(data["expected_value"])
804        instance._is_fitted = True
805
806        return instance

Reconstruct a fitted SHAPRelativities from to_dict() output.

The reconstructed object has no model attached, so validate() and plot_relativities() still work but fit() cannot be re-run.

Arguments:
  • data: Output of to_dict().
Returns:

Fitted SHAPRelativities instance.
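
A persistence round trip, as a sketch (the file name is illustrative; json works because to_dict() returns JSON-serialisable values):

>>> import json
>>> with open("shap_rels.json", "w") as f:
...     json.dump(sr.to_dict(), f)
>>> with open("shap_rels.json") as f:
...     restored = SHAPRelativities.from_dict(json.load(f))
>>> checks = restored.validate()  # runs without the original model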

class SHAPInference:
270class SHAPInference:
271    """
272    Asymptotically valid confidence intervals for global SHAP feature importance.
273
274    Implements the de-biased U-statistic estimator from Whitehouse, Sawarni,
275    Syrgkanis (2026), arXiv:2602.10532. Provides CIs for theta_p = E[|phi_a(X)|^p]
276    for any p >= 1 for each feature.
277
278    The most common use cases are:
279      p=1: mean absolute SHAP (standard SHAP importance bar chart)
280      p=2: mean squared SHAP (variance-like, cleaner theory)
281
282    IMPORTANT: Valid inference requires interventional SHAP values, not the
283    default path-dependent TreeSHAP. Use SHAPRelativities with
284    feature_perturbation='interventional' before calling SHAPInference, or
285    pass interventional SHAP values directly.
286
287    Args:
288        shap_values: np.ndarray of shape (n_obs, n_features). SHAP values,
289            one column per feature. Must be interventional SHAP for theoretical
290            validity. Path-dependent SHAP will produce point estimates and
291            intervals, but coverage guarantees do not hold.
292        y: np.ndarray of shape (n_obs,). Observed outcomes (claim counts or
293            claim amounts). Required for the alpha nuisance correction.
294        feature_names: List[str] of length n_features. Column names.
295        p: float >= 1. Power for importance measure. Default 2.0 (mean squared
296            SHAP). p=1 gives mean absolute SHAP (the standard bar chart metric)
297            but requires smoothing. p=2 has cleaner asymptotic theory.
298        n_folds: int >= 2. Number of cross-fitting folds. Default 5. More folds
299            reduce bias from the cross-fitting but increase compute.
300        nuisance_estimator: str or sklearn estimator. Used for mu_hat (E[Y|X])
301            and gamma_hat. Default 'gradient_boosting' uses
302            HistGradientBoostingRegressor.
303        alpha_estimator: str or sklearn estimator. Used for alpha_hat. Defaults
304            to same as nuisance_estimator.
305        beta_n: float or None. Smoothing parameter for p < 2. If None, computed
306            as n^{(2-p)/(2*(p+1))} (assumes delta=1). Only used when p < 2.
307        ci_level: float. Two-sided confidence level. Default 0.95.
308        n_jobs: int. Placeholder for future parallelism. Currently unused;
309            features are estimated sequentially.
310        random_state: int or None. Controls fold splitting for reproducibility.
311
312    Examples
313    --------
314    >>> import numpy as np
315    >>> rng = np.random.default_rng(0)
316    >>> shap_vals = rng.normal(size=(500, 3))
317    >>> y = rng.poisson(1.0, size=500).astype(float)
318    >>> si = SHAPInference(shap_vals, y, feature_names=["a", "b", "c"], p=2)
319    >>> si.fit()
320    SHAPInference(n_obs=500, n_features=3, p=2.0, n_folds=5, status=fitted)
321    >>> tbl = si.importance_table()
322    """
323
324    def __init__(
325        self,
326        shap_values: np.ndarray,
327        y: np.ndarray,
328        feature_names: list[str],
329        p: float = 2.0,
330        n_folds: int = 5,
331        nuisance_estimator: str | Any = "gradient_boosting",
332        alpha_estimator: str | Any = "gradient_boosting",
333        beta_n: float | None = None,
334        ci_level: float = 0.95,
335        n_jobs: int = 1,
336        random_state: int | None = None,
337    ) -> None:
338        # --- Input validation ---
339        shap_values = np.asarray(shap_values, dtype=float)
340        y = np.asarray(y, dtype=float)
341
342        if shap_values.ndim != 2:
343            raise ValueError(
344                f"shap_values must be 2D array (n_obs, n_features), "
345                f"got shape {shap_values.shape}"
346            )
347        if y.ndim != 1:
348            raise ValueError(f"y must be 1D array, got shape {y.shape}")
349        if shap_values.shape[0] != len(y):
350            raise ValueError(
351                f"shap_values and y must have the same number of observations. "
352                f"Got shap_values.shape[0]={shap_values.shape[0]} and len(y)={len(y)}."
353            )
354        if len(feature_names) != shap_values.shape[1]:
355            raise ValueError(
356                f"len(feature_names)={len(feature_names)} must equal "
357                f"shap_values.shape[1]={shap_values.shape[1]}."
358            )
359        if p < 1.0:
360            raise ValueError(f"p must be >= 1. Got p={p}.")
361        if n_folds < 2:
362            raise ValueError(f"n_folds must be >= 2. Got n_folds={n_folds}.")
363        if not (0.0 < ci_level < 1.0):
364            raise ValueError(f"ci_level must be in (0, 1). Got ci_level={ci_level}.")
365        if len(feature_names) != len(set(feature_names)):
366            raise ValueError("feature_names must be unique.")
367
368        self.shap_values = shap_values
369        self.y = y
370        self.feature_names = list(feature_names)
371        self.p = float(p)
372        self.n_folds = n_folds
373        self.nuisance_estimator = nuisance_estimator
374        self.alpha_estimator = alpha_estimator
375        self.beta_n = beta_n
376        self.ci_level = ci_level
377        self.n_jobs = n_jobs
378        self.random_state = random_state
379
380        # Fitted attributes — populated by fit()
381        self._theta_hat: np.ndarray | None = None   # shape (n_features,)
382        self._se: np.ndarray | None = None          # shape (n_features,)
383        self._rho: np.ndarray | None = None         # shape (n_obs, n_features)
384        self._is_fitted: bool = False
385
386    def __repr__(self) -> str:
387        n, d = self.shap_values.shape
388        status = "fitted" if self._is_fitted else "not fitted"
389        return (
390            f"SHAPInference(n_obs={n}, n_features={d}, "
391            f"p={self.p}, n_folds={self.n_folds}, status={status})"
392        )
393
394    def fit(self) -> "SHAPInference":
395        """
396        Estimate nuisance functions via cross-fitting and compute de-biased
397        theta_hat_p for each feature.
398
399        The algorithm:
400        1. Split observations into n_folds folds.
401        2. For each fold, train mu_hat, gamma_hat, alpha_hat on the complement.
402        3. Evaluate nuisances on the held-out fold.
403        4. Assemble full-data nuisance predictions.
404        5. Compute the influence function rho_a for each feature.
405        6. theta_hat = mean(rho_a), SE = sqrt(var(rho_a) / n).
406
407        Returns:
408            self, for method chaining.
409        """
410        if not _SKLEARN_AVAILABLE:
411            raise ImportError(
412                "scikit-learn >= 1.3 is required for SHAPInference. "
413                "Install with: pip install shap-relativities[ml]"
414            )
415
416        n, d = self.shap_values.shape
417
418        # Warn if p < 2: smoothing is active
419        effective_beta_n = self.beta_n
420        if self.p < 2.0:
421            if effective_beta_n is None:
422                effective_beta_n = _default_beta_n(n, self.p)
423            warnings.warn(
424                f"p={self.p} < 2: using smoothed estimator phi_{{p,beta}} "
425                f"with beta_n={effective_beta_n:.3f}. "
426                "Coverage is asymptotically valid but may be approximate for "
427                "features with many near-zero SHAP values. "
428                "Consider p=2 for cleaner guarantees.",
429                UserWarning,
430                stacklevel=2,
431            )
432        else:
433            effective_beta_n = 0.0  # unused but keeps type consistent
434
435        nu_est = _make_nuisance_estimator(self.nuisance_estimator)
436        al_est = _make_nuisance_estimator(self.alpha_estimator)
437
438        kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=self.random_state)
439
440        theta_hats = np.zeros(d)
441        ses = np.zeros(d)
442        rhos = np.zeros((n, d))
443
444        for j in range(d):
445            theta_j, se_j, rho_j = _fit_single_feature(
446                phi_col=self.shap_values[:, j],
447                y=self.y,
448                shap_matrix=self.shap_values,
449                p=self.p,
450                beta_n=effective_beta_n,
451                kf=kf,
452                nuisance_estimator=nu_est,
453                alpha_estimator=al_est,
454            )
455            theta_hats[j] = theta_j
456            ses[j] = se_j
457            rhos[:, j] = rho_j
458
459        self._theta_hat = theta_hats
460        self._se = ses
461        self._rho = rhos
462        self._is_fitted = True
463
464        return self
465
466    def _check_fitted(self) -> None:
467        if not self._is_fitted:
468            raise RuntimeError("Call .fit() before accessing results.")
469
470    def importance_table(self) -> pl.DataFrame:
471        """
472        Return feature importance estimates with confidence intervals.
473
474        All theta_hat values are theoretically non-negative (they estimate
475        E[|phi|^p]), but may be slightly negative for features with very
476        small true importance — this is expected sampling variability.
477
478        Returns:
479            Polars DataFrame with columns:
480              feature:          Feature name
481              theta_hat:        Point estimate of E[|phi_a(X)|^p]
482              theta_lower:      Lower CI bound
483              theta_upper:      Upper CI bound
484              sigma_hat:        sqrt(Var[rho_a]) — asymptotic std dev
485              se:               Standard error = sigma_hat / sqrt(n)
486              rank:             Rank by theta_hat (1 = most important)
487              rank_lower:       Conservative rank (using theta_lower)
488              rank_upper:       Optimistic rank (using theta_upper)
489              p_value_nonzero:  Two-sided p-value for H0: theta_p = 0
490        """
491        self._check_fitted()
492        assert self._theta_hat is not None
493        assert self._se is not None
494        assert self._rho is not None
495
496        n = self.shap_values.shape[0]
497        z = float(stats.norm.ppf((1 + self.ci_level) / 2))
498
499        theta = self._theta_hat
500        se = self._se
501        lower = theta - z * se
502        upper = theta + z * se
503
504        # sigma_hat = SE * sqrt(n) = sqrt(Var[rho])
505        sigma_hat = se * np.sqrt(n)
506
507        # Ranks (1 = most important)
508        rank = _dense_rank_descending(theta)
509        rank_lower = _dense_rank_descending(lower)  # conservative: lower bound
510        rank_upper = _dense_rank_descending(upper)  # optimistic: upper bound
511
512        # p-value: two-sided test H0: theta_p = 0
513        # Under H0, z_stat = theta_hat / SE ~ N(0,1) asymptotically
514        with np.errstate(divide="ignore", invalid="ignore"):
515            z_stat = np.where(se > 0, theta / se, np.inf)
516        p_values = 2.0 * (1.0 - stats.norm.cdf(np.abs(z_stat)))
517
518        return pl.DataFrame({
519            "feature": self.feature_names,
520            "theta_hat": theta.tolist(),
521            "theta_lower": lower.tolist(),
522            "theta_upper": upper.tolist(),
523            "sigma_hat": sigma_hat.tolist(),
524            "se": se.tolist(),
525            # _dense_rank_descending already returns list[int]; no .tolist() needed
526            "rank": rank,
527            "rank_lower": rank_lower,
528            "rank_upper": rank_upper,
529            "p_value_nonzero": p_values.tolist(),
530        }).sort("rank")
531
532    def ranking_ci(self, feature_a: str, feature_b: str) -> dict[str, float]:
533        """
534        Test whether feature_a has strictly higher importance than feature_b.
535
536        Uses the joint asymptotic distribution of (theta_hat_a, theta_hat_b).
537        The covariance is estimated from influence function cross-products,
538        which accounts for the fact that both features' rho vectors are
539        computed from the same observations.
540
541        H0: theta_a = theta_b
542        H1: theta_a > theta_b  (one-sided)
543
544        Args:
545            feature_a: Name of the first feature.
546            feature_b: Name of the second feature.
547
548        Returns:
549            dict with:
550              diff:      theta_hat_a - theta_hat_b
551              se_diff:   Standard error of the difference
552              z_stat:    Standardised test statistic
553              p_value:   One-sided p-value for H1: theta_a > theta_b
554              ci_lower:  Lower bound on (theta_a - theta_b) at self.ci_level
555              ci_upper:  Upper bound on (theta_a - theta_b) at self.ci_level
556        """
557        self._check_fitted()
558        assert self._theta_hat is not None
559        assert self._rho is not None
560
561        if feature_a not in self.feature_names:
562            raise ValueError(f"feature_a='{feature_a}' not in feature_names.")
563        if feature_b not in self.feature_names:
564            raise ValueError(f"feature_b='{feature_b}' not in feature_names.")
565
566        j_a = self.feature_names.index(feature_a)
567        j_b = self.feature_names.index(feature_b)
568
569        n = self.shap_values.shape[0]
570        theta_a = float(self._theta_hat[j_a])
571        theta_b = float(self._theta_hat[j_b])
572
573        rho_a = self._rho[:, j_a]
574        rho_b = self._rho[:, j_b]
575
576        # Var(theta_hat_a - theta_hat_b) = Var(mean(rho_a - rho_b))
577        #                                = Var(rho_a - rho_b) / n
578        diff_rho = rho_a - rho_b
579        var_diff = float(np.var(diff_rho, ddof=1)) / n
580        se_diff = float(np.sqrt(max(var_diff, 0.0)))
581
582        diff = theta_a - theta_b
583        # Guard: when both feature arguments are identical, diff==0 and SE==0.
584        # 0/0 is indeterminate; the correct result is z_stat=0, p_value=1.
585        if diff == 0.0 and se_diff == 0.0:
586            z_stat = 0.0
587            p_value = 1.0
588        else:
589            z_stat = diff / se_diff if se_diff > 0 else float("inf")
590            # One-sided p-value for H1: theta_a > theta_b
591            p_value = float(1.0 - stats.norm.cdf(z_stat))
592
593        # Two-sided CI on the difference
594        z_ci = float(stats.norm.ppf((1 + self.ci_level) / 2))
595        ci_lower = diff - z_ci * se_diff
596        ci_upper = diff + z_ci * se_diff
597
598        return {
599            "diff": diff,
600            "se_diff": se_diff,
601            "z_stat": z_stat,
602            "p_value": p_value,
603            "ci_lower": ci_lower,
604            "ci_upper": ci_upper,
605        }
606
607    def plot_importance(
608        self,
609        top_n: int | None = None,
610        ax: Any | None = None,
611        sort: bool = True,
612    ) -> Any:
613        """
614        Bar chart of theta_hat with CI error bars.
615
616        Styled for insurance governance presentations: clean background,
617        coloured bars by significance (dark blue = CI excludes zero, grey =
618        CI includes zero), error bars at self.ci_level.
619
620        Args:
621            top_n: Show only top N features by theta_hat. None shows all.
622            ax: matplotlib Axes. If None, creates a new figure.
623            sort: Sort by theta_hat descending. Default True.
624
625        Returns:
626            The matplotlib Axes object.
627        """
628        self._check_fitted()
629
630        try:
631            import matplotlib.pyplot as plt
632        except ImportError as e:
633            raise ImportError(
634                "matplotlib is required for plot_importance(). "
635                "Install with: pip install shap-relativities[plot]"
636            ) from e
637
638        tbl = self.importance_table()
639        if sort:
640            tbl = tbl.sort("theta_hat", descending=True)
641        if top_n is not None:
642            tbl = tbl.head(top_n)
643
644        features = tbl["feature"].to_list()
645        theta = tbl["theta_hat"].to_numpy()
646        lower = tbl["theta_lower"].to_numpy()
647        upper = tbl["theta_upper"].to_numpy()
648
649        err_low = theta - lower
650        err_high = upper - theta
651
652        # Colour by significance: CI excludes zero => dark teal, else grey
653        significant = lower > 0
654        colours = ["#1a6b7c" if s else "#9e9e9e" for s in significant]
655
656        if ax is None:
657            fig, ax = plt.subplots(figsize=(max(6, len(features) * 0.6 + 2), 5))
658
659        y_pos = np.arange(len(features))
660        ax.barh(
661            y_pos,
662            theta,
663            xerr=[err_low, err_high],
664            color=colours,
665            capsize=4,
666            edgecolor="white",
667            linewidth=0.5,
668            error_kw={"elinewidth": 1.2, "capthick": 1.2, "ecolor": "#555555"},
669        )
670
671        ax.set_yticks(y_pos)
672        ax.set_yticklabels(features, fontsize=10)
673        ax.invert_yaxis()
674        ax.axvline(0, color="black", linewidth=0.8, linestyle="--", alpha=0.5)
675
676        p_label = f"E[|φ(X)|^{{{self.p:.4g}}}]"
677        ax.set_xlabel(p_label, fontsize=11)
678        ax.set_title(
679            f"Global SHAP Feature Importance ({int(self.ci_level * 100)}% CI)",
680            fontsize=12,
681        )
682
683        ax.spines["top"].set_visible(False)
684        ax.spines["right"].set_visible(False)
685        ax.set_facecolor("#f8f8f8")
686        if hasattr(ax, "figure") and ax.figure is not None:
687            ax.figure.set_facecolor("white")
688
689        # Legend
690        from matplotlib.patches import Patch
691        legend_elements = [
692            Patch(facecolor="#1a6b7c", label="CI excludes 0"),
693            Patch(facecolor="#9e9e9e", label="CI includes 0"),
694        ]
695        ax.legend(handles=legend_elements, fontsize=9, loc="lower right")
696
697        return ax
698
699    @property
700    def influence_matrix(self) -> np.ndarray:
701        """
702        Influence function matrix, shape (n_obs, n_features).
703
704        rho[i, j] is observation i's contribution to theta_hat_j.
705        Observations with large |rho[i, j]| are high-leverage for feature j's
706        importance estimate — useful for identifying influential policies in
707        governance reviews.
708        """
709        self._check_fitted()
710        assert self._rho is not None
711        return self._rho.copy()

Asymptotically valid confidence intervals for global SHAP feature importance.

Implements the de-biased U-statistic estimator from Whitehouse, Sawarni, Syrgkanis (2026), arXiv:2602.10532. Provides CIs for theta_p = E[|phi_a(X)|^p] for any p >= 1 for each feature.

The most common use cases are:

  • p=1: mean absolute SHAP (standard SHAP importance bar chart)
  • p=2: mean squared SHAP (variance-like, cleaner theory)

IMPORTANT: Valid inference requires interventional SHAP values, not the default path-dependent TreeSHAP. Use SHAPRelativities with feature_perturbation='interventional' before calling SHAPInference, or pass interventional SHAP values directly.

Arguments:
  • shap_values: np.ndarray of shape (n_obs, n_features). SHAP values, one column per feature. Must be interventional SHAP for theoretical validity. Path-dependent SHAP will produce point estimates and intervals, but coverage guarantees do not hold.
  • y: np.ndarray of shape (n_obs,). Observed outcomes (claim counts or claim amounts). Required for the alpha nuisance correction.
  • feature_names: List[str] of length n_features. Column names.
  • p: float >= 1. Power for importance measure. Default 2.0 (mean squared SHAP). p=1 gives mean absolute SHAP (the standard bar chart metric) but requires smoothing. p=2 has cleaner asymptotic theory.
  • n_folds: int >= 2. Number of cross-fitting folds. Default 5. More folds reduce bias from the cross-fitting but increase compute.
  • nuisance_estimator: str or sklearn estimator. Used for mu_hat (E[Y|X]) and gamma_hat. Default 'gradient_boosting' uses HistGradientBoostingRegressor.
  • alpha_estimator: str or sklearn estimator. Used for alpha_hat. Defaults to same as nuisance_estimator.
  • beta_n: float or None. Smoothing parameter for p < 2. If None, computed as n^{(2-p)/(2*(p+1))} (assumes delta=1). Only used when p < 2.
  • ci_level: float. Two-sided confidence level. Default 0.95.
  • n_jobs: int. Placeholder for future parallelism. Currently unused; features are estimated sequentially.
  • random_state: int or None. Controls fold splitting for reproducibility.

Examples

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> shap_vals = rng.normal(size=(500, 3))
>>> y = rng.poisson(1.0, size=500).astype(float)
>>> si = SHAPInference(shap_vals, y, feature_names=["a", "b", "c"], p=2)
>>> si.fit()
SHAPInference(n_obs=500, n_features=3, p=2.0, n_folds=5, status=fitted)
>>> tbl = si.importance_table()
SHAPInference( shap_values: numpy.ndarray, y: numpy.ndarray, feature_names: list[str], p: float = 2.0, n_folds: int = 5, nuisance_estimator: str | typing.Any = 'gradient_boosting', alpha_estimator: str | typing.Any = 'gradient_boosting', beta_n: float | None = None, ci_level: float = 0.95, n_jobs: int = 1, random_state: int | None = None)

Public attributes set by the constructor: shap_values, y, feature_names, p, n_folds, nuisance_estimator, alpha_estimator, beta_n, ci_level, n_jobs, random_state.

def fit(self) -> SHAPInference:

Estimate nuisance functions via cross-fitting and compute de-biased theta_hat_p for each feature.

The algorithm:

  1. Split observations into n_folds folds.
  2. For each fold, train mu_hat, gamma_hat, alpha_hat on the complement.
  3. Evaluate nuisances on the held-out fold.
  4. Assemble full-data nuisance predictions.
  5. Compute the influence function rho_a for each feature.
  6. theta_hat = mean(rho_a), SE = sqrt(var(rho_a) / n).
Returns:

self, for method chaining.
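
As an illustration of the cross-fitting pattern in steps 1-4, here is a simplified sketch for a single nuisance function. It is not the package's internal code; the estimator mirrors the 'gradient_boosting' default.

import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import KFold

def cross_fit_mu(X, y, n_folds=5, random_state=None):
    # Out-of-fold estimate of mu(x) = E[Y|X=x]: each observation is
    # predicted by a model trained only on the complementary folds, so
    # the nuisance prediction is independent of the row it scores.
    mu_hat = np.empty(len(y), dtype=float)
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
    for train_idx, test_idx in kf.split(X):
        est = HistGradientBoostingRegressor()
        est.fit(X[train_idx], y[train_idx])
        mu_hat[test_idx] = est.predict(X[test_idx])
    return mu_hat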

def importance_table(self) -> polars.dataframe.frame.DataFrame:

Return feature importance estimates with confidence intervals.

All theta_hat values are theoretically non-negative (they estimate E[|phi|^p]), but may be slightly negative for features with very small true importance — this is expected sampling variability.

Returns:

Polars DataFrame with columns:
  • feature: Feature name
  • theta_hat: Point estimate of E[|phi_a(X)|^p]
  • theta_lower: Lower CI bound
  • theta_upper: Upper CI bound
  • sigma_hat: sqrt(Var[rho_a]) — asymptotic std dev
  • se: Standard error = sigma_hat / sqrt(n)
  • rank: Rank by theta_hat (1 = most important)
  • rank_lower: Conservative rank (using theta_lower)
  • rank_upper: Optimistic rank (using theta_upper)
  • p_value_nonzero: Two-sided p-value for H0: theta_p = 0
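
For instance, to pull out the features whose CI excludes zero (a sketch; si is a fitted SHAPInference):

>>> import polars as pl
>>> tbl = si.importance_table()
>>> significant = tbl.filter(pl.col("theta_lower") > 0)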

def ranking_ci(self, feature_a: str, feature_b: str) -> dict[str, float]:

Test whether feature_a has strictly higher importance than feature_b.

Uses the joint asymptotic distribution of (theta_hat_a, theta_hat_b). The covariance is estimated from influence function cross-products, which accounts for the fact that both features' rho vectors are computed from the same observations.

H0: theta_a = theta_b
H1: theta_a > theta_b (one-sided)

Arguments:
  • feature_a: Name of the first feature.
  • feature_b: Name of the second feature.
Returns:

dict with:
  • diff: theta_hat_a - theta_hat_b
  • se_diff: Standard error of the difference
  • z_stat: Standardised test statistic
  • p_value: One-sided p-value for H1: theta_a > theta_b
  • ci_lower: Lower bound on (theta_a - theta_b) at self.ci_level
  • ci_upper: Upper bound on (theta_a - theta_b) at self.ci_level
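
A usage sketch, reusing the feature names from the class example:

>>> res = si.ranking_ci("a", "b")
>>> confident = res["p_value"] < 0.05  # one-sided evidence that "a" outranks "b"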

def plot_importance( self, top_n: int | None = None, ax: typing.Any | None = None, sort: bool = True) -> Any:

Bar chart of theta_hat with CI error bars.

Styled for insurance governance presentations: clean background, coloured bars by significance (dark teal = CI excludes zero, grey = CI includes zero), error bars at self.ci_level.

Arguments:
  • top_n: Show only top N features by theta_hat. None shows all.
  • ax: matplotlib Axes. If None, creates a new figure.
  • sort: Sort by theta_hat descending. Default True.
Returns:

The matplotlib Axes object.
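
A sketch for saving the chart (the file name and dpi are illustrative):

>>> ax = si.plot_importance(top_n=10)
>>> ax.figure.savefig("shap_importance.png", dpi=200)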

influence_matrix: numpy.ndarray

Influence function matrix, shape (n_obs, n_features).

rho[i, j] is observation i's contribution to theta_hat_j. Observations with large |rho[i, j]| are high-leverage for feature j's importance estimate — useful for identifying influential policies in governance reviews.
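
A leverage check for one feature, as a sketch (feature name from the class example):

>>> import numpy as np
>>> rho = si.influence_matrix
>>> j = si.feature_names.index("a")
>>> top10 = np.argsort(-np.abs(rho[:, j]))[:10]  # highest-leverage observations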

def extract_relativities( model: Any, X: Any, exposure: Any = None, categorical_features: list[str] | None = None, base_levels: dict[str, str | float | int] | None = None, ci_method: str = 'clt') -> polars.dataframe.frame.DataFrame:
54def extract_relativities(
55    model: Any,
56    X: Any,
57    exposure: Any = None,
58    categorical_features: list[str] | None = None,
59    base_levels: dict[str, str | float | int] | None = None,
60    ci_method: str = "clt",
61) -> pl.DataFrame:
62    """
63    One-shot extraction of SHAP relativities from a tree model.
64
65    Wraps SHAPRelativities.fit() and extract_relativities() for cases where
66    you don't need the intermediate object.
67
68    Args:
69        model: Trained CatBoost model with a log-link objective (Poisson,
70            Tweedie, or Gamma). CatBoost is the recommended choice - it handles
71            categorical features natively without encoding.
72        X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is
73            preferred; pandas is accepted and converted internally.
74        exposure: Earned policy years. If None, all observations are equally
75            weighted.
76        categorical_features: Features to aggregate by level. If None, all
77            non-numeric columns are treated as categorical.
78        base_levels: Base level for each categorical feature (gets
79            relativity = 1.0).
80        ci_method: "clt" (default) or "none".
81
82    Returns:
83        Polars DataFrame with columns: feature, level, relativity, lower_ci,
84        upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
85    """
86    sr = SHAPRelativities(model, X, exposure, categorical_features)
87    sr.fit()
88    return sr.extract_relativities(base_levels=base_levels, ci_method=ci_method)

One-shot extraction of SHAP relativities from a tree model.

Wraps SHAPRelativities.fit() and extract_relativities() for cases where you don't need the intermediate object.

Arguments:
  • model: Trained CatBoost model with a log-link objective (Poisson, Tweedie, or Gamma). CatBoost is the recommended choice - it handles categorical features natively without encoding.
  • X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is preferred; pandas is accepted and converted internally.
  • exposure: Earned policy years. If None, all observations are equally weighted.
  • categorical_features: Features to aggregate by level. If None, all non-numeric columns are treated as categorical.
  • base_levels: Base level for each categorical feature (gets relativity = 1.0).
  • ci_method: "clt" (default) or "none".
Returns:

Polars DataFrame with columns: feature, level, relativity, lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
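
A sketch showing the ci_method switch and simple post-filtering (model and X as in the Args above; the feature name is illustrative):

>>> import polars as pl
>>> rels = extract_relativities(model, X, ci_method="none")
>>> area_rels = rels.filter(pl.col("feature") == "area")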