1""" 2SHAP Relativities - actuarial-grade multiplicative relativities from tree models. 3 4Converts a trained GBM's SHAP values into the same format as GLM exp(beta) 5relativities: a table of (feature, level, relativity) triples where the base 6level is 1.0 and relativities multiply together to give the model's expected 7prediction. 8 9Typical usage 10------------- 11>>> from shap_relativities import SHAPRelativities 12>>> sr = SHAPRelativities(model, X, exposure=df["exposure"], 13... categorical_features=["area", "ncd_years"]) 14>>> sr.fit() 15>>> rels = sr.extract_relativities( 16... normalise_to="base_level", 17... base_levels={"area": "A", "ncd_years": 0}, 18... ) 19 20For statistically valid SHAP importance CIs (v0.5.0+): 21 22>>> from shap_relativities import SHAPInference 23>>> si = SHAPInference(shap_values, y, feature_names=["age", "ncd", "area"]) 24>>> si.fit() 25>>> si.importance_table() 26 27Or use the convenience wrapper for one-liners: 28 29>>> from shap_relativities import extract_relativities 30>>> rels = extract_relativities(model, X, exposure=df["exposure"], 31... categorical_features=["area"]) 32""" 33 34from __future__ import annotations 35 36from typing import Any 37 38import polars as pl 39 40from ._core import SHAPRelativities 41from ._inference import SHAPInference 42 43__all__ = ["SHAPRelativities", "SHAPInference", "extract_relativities"] 44 45from importlib.metadata import version, PackageNotFoundError 46 47try: 48 __version__ = version("shap-relativities") 49except PackageNotFoundError: 50 __version__ = "0.0.0" # not installed 51 52 53def extract_relativities( 54 model: Any, 55 X: Any, 56 exposure: Any = None, 57 categorical_features: list[str] | None = None, 58 base_levels: dict[str, str | float | int] | None = None, 59 ci_method: str = "clt", 60) -> pl.DataFrame: 61 """ 62 One-shot extraction of SHAP relativities from a tree model. 63 64 Wraps SHAPRelativities.fit() and extract_relativities() for cases where 65 you don't need the intermediate object. 66 67 Args: 68 model: Trained CatBoost model with a log-link objective (Poisson, 69 Tweedie, or Gamma). CatBoost is the recommended choice - it handles 70 categorical features natively without encoding. 71 X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is 72 preferred; pandas is accepted and converted internally. 73 exposure: Earned policy years. If None, all observations are equally 74 weighted. 75 categorical_features: Features to aggregate by level. If None, all 76 non-numeric columns are treated as categorical. 77 base_levels: Base level for each categorical feature (gets 78 relativity = 1.0). 79 ci_method: "clt" (default) or "none". 80 81 Returns: 82 Polars DataFrame with columns: feature, level, relativity, lower_ci, 83 upper_ci, mean_shap, shap_std, n_obs, exposure_weight. 84 """ 85 sr = SHAPRelativities(model, X, exposure, categorical_features) 86 sr.fit() 87 return sr.extract_relativities(base_levels=base_levels, ci_method=ci_method)
class SHAPRelativities:
    """
    Extract multiplicative rating relativities from a tree model via SHAP.

    Workflow::

        sr = SHAPRelativities(
            model=catboost_model,
            X=df.select(["area", "ncd_years", "has_convictions"]),
            exposure=df["exposure"],
            categorical_features=["area", "ncd_years"],
        )
        sr.fit()
        rels = sr.extract_relativities(
            normalise_to="base_level",
            base_levels={"area": "A", "ncd_years": 0},
        )

    Args:
        model: A trained CatBoost model. Must use a log-link objective (Poisson,
            Tweedie, Gamma). CatBoost is the recommended default - it handles
            categoricals natively.
        X: Feature matrix. Use training data for in-sample relativities, or a
            representative holdout sample for out-of-sample. Polars DataFrames
            are preferred; pandas DataFrames are accepted and converted
            internally.
        exposure: Earned policy years (or other volume measure). Used as
            observation weights throughout. If None, all observations are
            weighted equally.
        categorical_features: Features to aggregate by level (bar-chart style).
            If None, all non-numeric columns are treated as categorical.
        continuous_features: Features to leave as per-observation points.
            If None, all numeric columns are treated as continuous.
        feature_perturbation: "tree_path_dependent" (default, fast, no
            background data needed) or "interventional" (corrects for feature
            correlation, needs background_data).
        background_data: Required only if feature_perturbation="interventional".
        n_background_samples: Number of background samples for interventional
            SHAP. Default 1000.
        annualise_exposure: If True and exposure is provided, subtract mean
            log(exposure) from the expected_value to give an annualised
            baseline. Default True.
        verbose: If True (default), show a progress indicator during fit().
            Set to False to suppress all output — useful in batch jobs,
            notebooks with tqdm already present, or when calling fit()
            many times. If tqdm is not installed, falls back to a plain
            elapsed-time message.
    """

    def __init__(
        self,
        model: Any,
        X: Any,
        exposure: Any = None,
        categorical_features: list[str] | None = None,
        continuous_features: list[str] | None = None,
        background_data: Any = None,
        feature_perturbation: str = "tree_path_dependent",
        n_background_samples: int = 1000,
        annualise_exposure: bool = True,
        verbose: bool = True,
    ) -> None:
        if not _SHAP_AVAILABLE:
            raise ImportError(
                "shap is required for SHAPRelativities. "
                "Install it with: uv add 'shap-relativities[ml]'"
            )

        self._model = model
        self._X: pl.DataFrame = _to_polars(X)
        self._background_data = (
            _to_polars(background_data) if background_data is not None else None
        )

        # Normalise exposure to a numpy array
        if exposure is None:
            self._exposure: np.ndarray | None = None
        elif isinstance(exposure, np.ndarray):
            self._exposure = exposure
        elif isinstance(exposure, pl.Series):
            self._exposure = exposure.to_numpy()
        else:
            # pd.Series or similar
            self._exposure = np.asarray(exposure)

        # Validate exposure length matches X
        if self._exposure is not None and len(self._exposure) != len(self._X):
            raise ValueError(
                f"exposure length ({len(self._exposure)}) does not match "
                f"X length ({len(self._X)}). Both must have the same number of rows."
            )

        self._feature_perturbation = feature_perturbation
        self._n_background_samples = n_background_samples
        self._annualise_exposure = annualise_exposure
        self._verbose = verbose

        # Classify features
        self._categorical_features = categorical_features or self._infer_categorical()
        self._continuous_features = continuous_features or self._infer_continuous()

        # Populated by fit()
        self._shap_values: np.ndarray | None = None
        self._expected_value: float | None = None
        self._is_fitted: bool = False

    def _infer_categorical(self) -> list[str]:
        numeric_types = (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64,
            pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
            pl.Float32, pl.Float64,
        )
        return [
            c for c in self._X.columns
            if not isinstance(self._X[c].dtype, numeric_types)
        ]

    def _infer_continuous(self) -> list[str]:
        numeric_types = (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64,
            pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
            pl.Float32, pl.Float64,
        )
        return [
            c for c in self._X.columns
            if isinstance(self._X[c].dtype, numeric_types)
            and c not in (self._categorical_features or [])
        ]

    def fit(self, verbose: bool | None = None) -> "SHAPRelativities":
        """
        Compute SHAP values for all features in X.

        Must be called before extract_relativities(). Calling fit() again
        recomputes SHAP values (e.g. after changing X or the background data).

        The feature matrix is converted to pandas internally for shap's
        TreeExplainer. The conversion is a necessary bridge - shap requires
        pandas for column name handling.

        For large datasets (50k+ policies) SHAP computation can take 30-60
        seconds. A progress indicator is shown by default. To suppress it,
        pass verbose=False here or set verbose=False in __init__.

        Args:
            verbose: Override the instance-level verbose setting for this call
                only. If None (default), the instance-level setting is used.

        Returns:
            Self, for method chaining.
        """
        show_progress = self._verbose if verbose is None else verbose

        # Convert to pandas for shap - unavoidable bridge
        X_pd = _to_pandas(self._X)

        bg_data = None
        if self._feature_perturbation == "interventional":
            if self._background_data is not None:
                bg_data = _to_pandas(self._background_data)
            else:
                n_bg = min(self._n_background_samples, len(X_pd))
                bg_data = shap.sample(X_pd, n_bg)

        explainer = shap.TreeExplainer(
            self._model,
            data=bg_data,
            feature_perturbation=self._feature_perturbation,
            model_output="raw",
        )

        def _compute_shap():
            return explainer.shap_values(X_pd)

        raw = _run_with_spinner(_compute_shap, n_obs=len(X_pd), verbose=show_progress)

        # Some models return a list when there is a single output
        if isinstance(raw, list):
            if len(raw) == 1:
                raw = raw[0]
            else:
                raise ValueError(
                    f"Model has {len(raw)} outputs. SHAPRelativities supports "
                    "single-output models only."
                )

        self._shap_values = raw

        ev = explainer.expected_value
        if isinstance(ev, (list, np.ndarray)):
            ev = float(ev[0])
        self._expected_value = float(ev)

        self._is_fitted = True
        return self

    def _check_fitted(self) -> None:
        if not self._is_fitted:
            raise RuntimeError("Call fit() before using this method.")

    def shap_values(self) -> np.ndarray:
        """
        Raw SHAP values, shape (n_obs, n_features), in log space.

        Returns:
            Array of shape (n_obs, n_features).
        """
        self._check_fitted()
        return self._shap_values  # type: ignore[return-value]

    def baseline(self) -> float:
        """
        exp(expected_value) - the base rate in prediction space.

        If annualise_exposure=True and exposure was provided, this is adjusted
        for the average log-exposure offset so it represents an annualised rate.

        Returns:
            Base rate as a float.
        """
        self._check_fitted()
        ev = self._expected_value  # type: ignore[assignment]

        if self._annualise_exposure and self._exposure is not None:
            mean_log_exp = float(np.mean(np.log(np.clip(self._exposure, 1e-9, None))))
            ev = ev - mean_log_exp

        return float(np.exp(ev))

    def extract_relativities(
        self,
        normalise_to: str = "base_level",
        base_levels: dict[str, str | float | int] | None = None,
        ci_method: str = "clt",
        n_bootstrap: int = 200,
        ci_level: float = 0.95,
    ) -> pl.DataFrame:
        """
        Extract multiplicative relativities from SHAP values.

        Args:
            normalise_to: "base_level" (base level for each feature gets
                relativity = 1.0) or "mean" (exposure-weighted portfolio
                mean = 1.0).
            base_levels: Mapping of feature -> base level value. Required for
                categorical features when normalise_to="base_level". Continuous
                features automatically use mean normalisation regardless of
                this setting.
            ci_method: "clt" (CLT approximation, default, fast) or "none" (no
                CIs). "bootstrap" is not yet implemented.
            n_bootstrap: Ignored unless ci_method="bootstrap".
            ci_level: Two-sided confidence level. Default 0.95.

        Returns:
            Polars DataFrame with columns: feature, level, relativity,
            lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
            One row per (feature, level) combination.
        """
        self._check_fitted()

        _VALID_CI_METHODS = {"clt", "bootstrap", "none"}
        if ci_method not in _VALID_CI_METHODS:
            raise ValueError(
                f"Unknown ci_method {ci_method!r}. "
                f"Valid options are: {sorted(_VALID_CI_METHODS)}."
            )

        if ci_method == "bootstrap":
            raise NotImplementedError(
                "Bootstrap CIs are not yet implemented. Use ci_method='clt'."
            )

        base_levels = base_levels or {}
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        feature_names = self._X.columns
        shap_vals = self._shap_values  # type: ignore[assignment]

        parts: list[pl.DataFrame] = []

        for i, feat in enumerate(feature_names):
            feat_vals = self._X[feat].to_numpy()
            shap_col = shap_vals[:, i]

            is_categorical = feat in self._categorical_features

            if is_categorical:
                agg = aggregate_categorical(feat, feat_vals, shap_col, weights)
            else:
                agg = aggregate_continuous(feat, feat_vals, shap_col, weights)

            # Normalisation
            if normalise_to == "base_level" and is_categorical:
                base = base_levels.get(feat)
                if base is None:
                    # Fall back to the level with the smallest mean_shap as
                    # a sensible default (closest to intercept)
                    base = agg.sort("mean_shap")["level"][0]
                    warnings.warn(
                        f"No base level specified for '{feat}'. "
                        f"Using '{base}' (lowest mean SHAP) as base.",
                        UserWarning,
                        stacklevel=2,
                    )

                if ci_method == "none":
                    base_key = str(base)
                    base_rows = agg.filter(pl.col("level") == base_key)
                    if base_rows.height == 0:
                        raise ValueError(
                            f"Base level {base!r} not found among the levels "
                            f"of feature '{feat}'."
                        )
                    base_shap = base_rows["mean_shap"][0]
                    agg = agg.with_columns([
                        (pl.col("mean_shap") - base_shap).exp().alias("relativity"),
                        pl.lit(float("nan")).alias("lower_ci"),
                        pl.lit(float("nan")).alias("upper_ci"),
                    ])
                else:
                    agg = normalise_base_level(agg, base, ci_level=ci_level)

            else:
                # Mean normalisation for continuous features, or when
                # normalise_to="mean" for any feature
                if ci_method == "none":
                    total_weight = agg["exposure_weight"].sum()
                    portfolio_mean = float(
                        (agg["mean_shap"] * agg["exposure_weight"]).sum()
                        / total_weight
                    ) if total_weight > 0 else 0.0
                    agg = agg.with_columns([
                        (pl.col("mean_shap") - portfolio_mean).exp().alias("relativity"),
                        pl.lit(float("nan")).alias("lower_ci"),
                        pl.lit(float("nan")).alias("upper_ci"),
                    ])
                else:
                    agg = normalise_mean(agg, ci_level=ci_level)

            parts.append(agg)

        # Cast the 'level' column to Utf8 in every part before concat.
        # Categorical features produce level as Utf8; continuous features
        # produce level as Float64. pl.concat with how="diagonal" cannot
        # unify mismatched types for the same column name, so we normalise
        # here rather than requiring callers to pre-cast their feature columns.
        parts = [
            p.with_columns(pl.col("level").cast(pl.String))
            if "level" in p.columns else p
            for p in parts
        ]

        result = pl.concat(parts, how="diagonal")

        # Ensure standard column order (wsq_weight is internal, not exported)
        available = [c for c in _RELATIVITY_COLUMNS if c in result.columns]
        return result.select(available)

    def extract_continuous_curve(
        self,
        feature: str,
        n_points: int = 100,
        smooth_method: str = "loess",
    ) -> pl.DataFrame:
        """
        Smoothed relativity curve for a continuous feature.

        Args:
            feature: Feature name. Must be in continuous_features.
            n_points: Number of points in the output curve (not the input
                data).
            smooth_method: "loess" (locally weighted regression, requires
                statsmodels), "isotonic" (monotone curve via isotonic
                regression), or "none" (raw per-observation relativities).

        Returns:
            Polars DataFrame with columns: feature_value, relativity,
            lower_ci, upper_ci.

        Raises:
            ValueError: If feature is not in X or smooth_method is unknown.
        """
        self._check_fitted()

        if feature not in self._X.columns:
            raise ValueError(f"Feature '{feature}' not in X.")

        feat_idx = self._X.columns.index(feature)
        feat_vals = self._X[feature].to_numpy().astype(float)
        shap_col = self._shap_values[:, feat_idx]  # type: ignore[index]
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        # Exposure-weighted mean over the actual data distribution
        portfolio_mean = np.average(shap_col, weights=weights)
        relativities = np.exp(shap_col - portfolio_mean)

        grid = np.linspace(feat_vals.min(), feat_vals.max(), n_points)

        if smooth_method == "none":
            order = np.argsort(feat_vals)
            return pl.DataFrame({
                "feature_value": feat_vals[order],
                "relativity": relativities[order],
                "lower_ci": np.full(len(feat_vals), float("nan")),
                "upper_ci": np.full(len(feat_vals), float("nan")),
            })

        elif smooth_method == "isotonic":
            from sklearn.isotonic import IsotonicRegression
            ir = IsotonicRegression(out_of_bounds="clip")
            ir.fit(feat_vals, shap_col, sample_weight=weights)
            smoothed_shap = ir.predict(grid)

            # P1-4 fix: normalise the smoothed curve so the exposure-weighted
            # geometric mean of relativities = 1.0.
            # The smooth is on the data, evaluated on a uniform grid. Subtracting
            # portfolio_mean (computed on the data distribution) would be correct
            # only if the grid were distributed like the data — it isn't.
            # Instead, compute the data-distribution-weighted mean of the smoothed
            # curve at the original data points, then use that as the reference.
            smoothed_at_data = ir.predict(feat_vals)
            weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
            smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)

            return pl.DataFrame({
                "feature_value": grid,
                "relativity": smoothed_rel,
                "lower_ci": np.full(n_points, float("nan")),
                "upper_ci": np.full(n_points, float("nan")),
            })

        elif smooth_method == "loess":
            try:
                from statsmodels.nonparametric.smoothers_lowess import lowess

                smoothed_shap = lowess(
                    shap_col, feat_vals, frac=0.3, it=3,
                    xvals=grid, is_sorted=False,
                )

                # P1-4 fix: compute the smoothed values at original data points
                # so the normalisation is data-distribution-weighted, not
                # grid-uniform. Use the same lowess parameters.
                smoothed_at_data = lowess(
                    shap_col, feat_vals, frac=0.3, it=3,
                    xvals=feat_vals, is_sorted=False,
                )
                weighted_mean_smoothed = np.average(smoothed_at_data, weights=weights)
                smoothed_rel = np.exp(smoothed_shap - weighted_mean_smoothed)

                return pl.DataFrame({
                    "feature_value": grid,
                    "relativity": smoothed_rel,
                    "lower_ci": np.full(n_points, float("nan")),
                    "upper_ci": np.full(n_points, float("nan")),
                })
            except ImportError:
                warnings.warn(
                    "statsmodels not installed; falling back to smooth_method='none'.",
                    UserWarning,
                    stacklevel=2,
                )
                return self.extract_continuous_curve(
                    feature, n_points=n_points, smooth_method="none"
                )

        else:
            raise ValueError(
                f"Unknown smooth_method '{smooth_method}'. "
                "Choose from: 'loess', 'isotonic', 'none'."
            )

    def validate(self) -> dict[str, CheckResult]:
        """
        Run diagnostic checks on the SHAP computation.

        Checks performed:

        1. reconstruction: exp(shap.sum(1) + expected_value) should match
           model predictions within tolerance. Material failure here indicates
           the explainer was set up incorrectly.

        2. feature_coverage: every feature in X should appear in the SHAP
           output. Currently always passes given TreeExplainer's API.

        3. sparse_levels: warns if any categorical level has fewer than 30
           observations. CLT CIs will be unreliable for these levels.

        Returns:
            Dict with keys "reconstruction", "feature_coverage",
            "sparse_levels". Each value is a CheckResult(passed, value,
            message).
        """
        self._check_fitted()

        X_pd = _to_pandas(self._X)

        # Get model predictions for reconstruction check
        preds = None
        if self._model is not None:
            try:
                preds = self._model.predict(X_pd)
            except Exception:
                preds = None

        results: dict[str, CheckResult] = {}

        if preds is not None:
            results["reconstruction"] = check_reconstruction(
                self._shap_values,  # type: ignore[arg-type]
                self._expected_value,  # type: ignore[arg-type]
                preds,
                tolerance=1e-4,
            )
        else:
            results["reconstruction"] = CheckResult(
                passed=False,
                value=float("nan"),
                message="Could not obtain model predictions for reconstruction check.",
            )

        feature_names = self._X.columns
        results["feature_coverage"] = check_feature_coverage(
            feature_names, feature_names
        )

        # Check sparse levels for categorical features
        weights = (
            self._exposure if self._exposure is not None
            else np.ones(len(self._X))
        )

        sparse_parts: list[pl.DataFrame] = []
        for feat in self._categorical_features:
            if feat not in self._X.columns:
                continue
            feat_idx = self._X.columns.index(feat)
            agg = aggregate_categorical(
                feat,
                self._X[feat].to_numpy(),
                self._shap_values[:, feat_idx],  # type: ignore[index]
                weights,
            )
            sparse_parts.append(agg)

        if sparse_parts:
            all_agg = pl.concat(sparse_parts, how="diagonal")
            results["sparse_levels"] = check_sparse_levels(all_agg)
        else:
            results["sparse_levels"] = CheckResult(
                passed=True, value=0.0,
                message="No categorical features to check."
            )

        return results

    def plot_relativities(
        self,
        features: list[str] | None = None,
        show_ci: bool = True,
        figsize: tuple[int, int] = (12, 8),
    ) -> None:
        """
        Plot relativities as bar charts (categorical) or line charts (continuous).

        Args:
            features: Subset of features to plot. Defaults to all features.
            show_ci: Whether to show confidence intervals. Default True.
            figsize: Overall figure size.
        """
        self._check_fitted()

        from ._plotting import plot_relativities as _plot

        rels = self.extract_relativities()
        _plot(
            rels,
            categorical_features=self._categorical_features,
            continuous_features=self._continuous_features,
            features=features,
            show_ci=show_ci,
            figsize=figsize,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialisable representation of the fitted object.

        Stores SHAP values, expected value, feature names, feature
        classification, the feature values from X, and the exposure vector.
        Does not store the original model.

        Returns:
            Dict suitable for JSON serialisation.
        """
        self._check_fitted()
        return {
            "shap_values": self._shap_values.tolist(),  # type: ignore[union-attr]
            "expected_value": self._expected_value,
            "feature_names": self._X.columns,
            "categorical_features": self._categorical_features,
            "continuous_features": self._continuous_features,
            "X_values": {c: self._X[c].to_list() for c in self._X.columns},
            "exposure": (
                self._exposure.tolist()
                if self._exposure is not None else None
            ),
            "annualise_exposure": self._annualise_exposure,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SHAPRelativities":
        """
        Reconstruct a fitted SHAPRelativities from to_dict() output.

        The reconstructed object has no model attached, so validate() and
        plot_relativities() still work but fit() cannot be re-run.

        Args:
            data: Output of to_dict().

        Returns:
            Fitted SHAPRelativities instance.
        """
        # P1-2 fix: use feature_names to control column ordering in X.
        # Without this, tools that sort JSON object keys (REST APIs, some
        # pretty-printers) reorder X_values, misaligning columns with the
        # shap_values matrix columns.
        feature_names: list[str] = data.get("feature_names", list(data["X_values"].keys()))
        X = pl.DataFrame({k: data["X_values"][k] for k in feature_names})

        exposure = (
            np.array(data["exposure"]) if data.get("exposure") is not None
            else None
        )

        # Create a minimal instance without a real model
        instance = cls.__new__(cls)
        instance._model = None
        instance._X = X
        instance._exposure = exposure
        instance._categorical_features = data.get("categorical_features", [])
        instance._continuous_features = data.get("continuous_features", [])
        instance._feature_perturbation = "tree_path_dependent"
        instance._background_data = None
        instance._n_background_samples = 1000
        instance._annualise_exposure = data.get("annualise_exposure", True)
        instance._verbose = False  # no-op for deserialised instances
        instance._shap_values = np.array(data["shap_values"])
        instance._expected_value = float(data["expected_value"])
        instance._is_fitted = True

        return instance
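Continuing with the `sr` rebuilt above, extract_continuous_curve() turns the per-observation SHAP values for a continuous feature into a relativity curve. With only three data points the loess and isotonic smooths are degenerate, so this sketch uses the raw method:

# Raw (unsmoothed) relativity curve for the continuous feature from the
# payload above; smooth_method="loess" needs statsmodels and more data.
curve = sr.extract_continuous_curve("ncd_years", smooth_method="none")
print(curve)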
class SHAPInference:
    """
    Asymptotically valid confidence intervals for global SHAP feature importance.

    Implements the de-biased U-statistic estimator from Whitehouse, Sawarni,
    Syrgkanis (2026), arXiv:2602.10532. Provides CIs for theta_p = E[|phi_a(X)|^p]
    for any p >= 1 for each feature.

    The most common use cases are:

        p=1: mean absolute SHAP (standard SHAP importance bar chart)
        p=2: mean squared SHAP (variance-like, cleaner theory)

    IMPORTANT: Valid inference requires interventional SHAP values, not the
    default path-dependent TreeSHAP. Use SHAPRelativities with
    feature_perturbation='interventional' before calling SHAPInference, or
    pass interventional SHAP values directly.

    Args:
        shap_values: np.ndarray of shape (n_obs, n_features). SHAP values,
            one column per feature. Must be interventional SHAP for theoretical
            validity. Path-dependent SHAP will produce point estimates and
            intervals, but coverage guarantees do not hold.
        y: np.ndarray of shape (n_obs,). Observed outcomes (claim counts or
            claim amounts). Required for the alpha nuisance correction.
        feature_names: List[str] of length n_features. Column names.
        p: float >= 1. Power for the importance measure. Default 2.0 (mean
            squared SHAP). p=1 gives mean absolute SHAP (the standard bar
            chart metric) but requires smoothing. p=2 has cleaner asymptotic
            theory.
        n_folds: int >= 2. Number of cross-fitting folds. Default 5. More folds
            reduce bias from the cross-fitting but increase compute.
        nuisance_estimator: str or sklearn estimator. Used for mu_hat (E[Y|X])
            and gamma_hat. Default 'gradient_boosting' uses
            HistGradientBoostingRegressor.
        alpha_estimator: str or sklearn estimator. Used for alpha_hat. Defaults
            to the same as nuisance_estimator.
        beta_n: float or None. Smoothing parameter for p < 2. If None, computed
            as n^{(2-p)/(2*(p+1))} (assumes delta=1). Only used when p < 2.
        ci_level: float. Two-sided confidence level. Default 0.95.
        n_jobs: int. Placeholder for future parallelism. Currently unused;
            features are estimated sequentially.
        random_state: int or None. Controls fold splitting for reproducibility.

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> shap_vals = rng.normal(size=(500, 3))
    >>> y = rng.poisson(1.0, size=500).astype(float)
    >>> si = SHAPInference(shap_vals, y, feature_names=["a", "b", "c"], p=2)
    >>> si.fit()
    SHAPInference(n_obs=500, n_features=3, p=2.0, n_folds=5, status=fitted)
    >>> tbl = si.importance_table()
    """

    def __init__(
        self,
        shap_values: np.ndarray,
        y: np.ndarray,
        feature_names: list[str],
        p: float = 2.0,
        n_folds: int = 5,
        nuisance_estimator: str | Any = "gradient_boosting",
        alpha_estimator: str | Any = "gradient_boosting",
        beta_n: float | None = None,
        ci_level: float = 0.95,
        n_jobs: int = 1,
        random_state: int | None = None,
    ) -> None:
        # --- Input validation ---
        shap_values = np.asarray(shap_values, dtype=float)
        y = np.asarray(y, dtype=float)

        if shap_values.ndim != 2:
            raise ValueError(
                f"shap_values must be a 2D array (n_obs, n_features), "
                f"got shape {shap_values.shape}"
            )
        if y.ndim != 1:
            raise ValueError(f"y must be a 1D array, got shape {y.shape}")
        if shap_values.shape[0] != len(y):
            raise ValueError(
                f"shap_values and y must have the same number of observations. "
                f"Got shap_values.shape[0]={shap_values.shape[0]} and len(y)={len(y)}."
            )
        if len(feature_names) != shap_values.shape[1]:
            raise ValueError(
                f"len(feature_names)={len(feature_names)} must equal "
                f"shap_values.shape[1]={shap_values.shape[1]}."
            )
        if p < 1.0:
            raise ValueError(f"p must be >= 1. Got p={p}.")
        if n_folds < 2:
            raise ValueError(f"n_folds must be >= 2. Got n_folds={n_folds}.")
        if not (0.0 < ci_level < 1.0):
            raise ValueError(f"ci_level must be in (0, 1). Got ci_level={ci_level}.")
        if len(feature_names) != len(set(feature_names)):
            raise ValueError("feature_names must be unique.")

        self.shap_values = shap_values
        self.y = y
        self.feature_names = list(feature_names)
        self.p = float(p)
        self.n_folds = n_folds
        self.nuisance_estimator = nuisance_estimator
        self.alpha_estimator = alpha_estimator
        self.beta_n = beta_n
        self.ci_level = ci_level
        self.n_jobs = n_jobs
        self.random_state = random_state

        # Fitted attributes — populated by fit()
        self._theta_hat: np.ndarray | None = None  # shape (n_features,)
        self._se: np.ndarray | None = None  # shape (n_features,)
        self._rho: np.ndarray | None = None  # shape (n_obs, n_features)
        self._is_fitted: bool = False

    def __repr__(self) -> str:
        n, d = self.shap_values.shape
        status = "fitted" if self._is_fitted else "not fitted"
        return (
            f"SHAPInference(n_obs={n}, n_features={d}, "
            f"p={self.p}, n_folds={self.n_folds}, status={status})"
        )

    def fit(self) -> "SHAPInference":
        """
        Estimate nuisance functions via cross-fitting and compute the de-biased
        theta_hat_p for each feature.

        The algorithm:

        1. Split observations into n_folds folds.
        2. For each fold, train mu_hat, gamma_hat, alpha_hat on the complement.
        3. Evaluate nuisances on the held-out fold.
        4. Assemble full-data nuisance predictions.
        5. Compute the influence function rho_a for each feature.
        6. theta_hat = mean(rho_a), SE = sqrt(var(rho_a) / n).

        Returns:
            self, for method chaining.
        """
        if not _SKLEARN_AVAILABLE:
            raise ImportError(
                "scikit-learn >= 1.3 is required for SHAPInference. "
                "Install with: pip install shap-relativities[ml]"
            )

        n, d = self.shap_values.shape

        # Warn if p < 2: smoothing is active
        effective_beta_n = self.beta_n
        if self.p < 2.0:
            if effective_beta_n is None:
                effective_beta_n = _default_beta_n(n, self.p)
            warnings.warn(
                f"p={self.p} < 2: using smoothed estimator phi_{{p,beta}} "
                f"with beta_n={effective_beta_n:.3f}. "
                "Coverage is asymptotically valid but may be approximate for "
                "features with many near-zero SHAP values. "
                "Consider p=2 for cleaner guarantees.",
                UserWarning,
                stacklevel=2,
            )
        else:
            effective_beta_n = 0.0  # unused but keeps the type consistent

        nu_est = _make_nuisance_estimator(self.nuisance_estimator)
        al_est = _make_nuisance_estimator(self.alpha_estimator)

        kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=self.random_state)

        theta_hats = np.zeros(d)
        ses = np.zeros(d)
        rhos = np.zeros((n, d))

        for j in range(d):
            theta_j, se_j, rho_j = _fit_single_feature(
                phi_col=self.shap_values[:, j],
                y=self.y,
                shap_matrix=self.shap_values,
                p=self.p,
                beta_n=effective_beta_n,
                kf=kf,
                nuisance_estimator=nu_est,
                alpha_estimator=al_est,
            )
            theta_hats[j] = theta_j
            ses[j] = se_j
            rhos[:, j] = rho_j

        self._theta_hat = theta_hats
        self._se = ses
        self._rho = rhos
        self._is_fitted = True

        return self

    def _check_fitted(self) -> None:
        if not self._is_fitted:
            raise RuntimeError("Call .fit() before accessing results.")

    def importance_table(self) -> pl.DataFrame:
        """
        Return feature importance estimates with confidence intervals.

        All theta_hat values are theoretically non-negative (they estimate
        E[|phi|^p]), but may be slightly negative for features with very
        small true importance — this is expected sampling variability.

        Returns:
            Polars DataFrame with columns:
                feature: Feature name
                theta_hat: Point estimate of E[|phi_a(X)|^p]
                theta_lower: Lower CI bound
                theta_upper: Upper CI bound
                sigma_hat: sqrt(Var[rho_a]) — asymptotic std dev
                se: Standard error = sigma_hat / sqrt(n)
                rank: Rank by theta_hat (1 = most important)
                rank_lower: Conservative rank (using theta_lower)
                rank_upper: Optimistic rank (using theta_upper)
                p_value_nonzero: Two-sided p-value for H0: theta_p = 0
        """
        self._check_fitted()
        assert self._theta_hat is not None
        assert self._se is not None
        assert self._rho is not None

        n = self.shap_values.shape[0]
        z = float(stats.norm.ppf((1 + self.ci_level) / 2))

        theta = self._theta_hat
        se = self._se
        lower = theta - z * se
        upper = theta + z * se

        # sigma_hat = SE * sqrt(n) = sqrt(Var[rho])
        sigma_hat = se * np.sqrt(n)

        # Ranks (1 = most important)
        rank = _dense_rank_descending(theta)
        rank_lower = _dense_rank_descending(lower)  # conservative: lower bound
        rank_upper = _dense_rank_descending(upper)  # optimistic: upper bound

        # p-value: two-sided test H0: theta_p = 0
        # Under H0, z_stat = theta_hat / SE ~ N(0,1) asymptotically
        with np.errstate(divide="ignore", invalid="ignore"):
            z_stat = np.where(se > 0, theta / se, np.inf)
        p_values = 2.0 * (1.0 - stats.norm.cdf(np.abs(z_stat)))

        return pl.DataFrame({
            "feature": self.feature_names,
            "theta_hat": theta.tolist(),
            "theta_lower": lower.tolist(),
            "theta_upper": upper.tolist(),
            "sigma_hat": sigma_hat.tolist(),
            "se": se.tolist(),
            # _dense_rank_descending already returns list[int]; no .tolist() needed
            "rank": rank,
            "rank_lower": rank_lower,
            "rank_upper": rank_upper,
            "p_value_nonzero": p_values.tolist(),
        }).sort("rank")

    def ranking_ci(self, feature_a: str, feature_b: str) -> dict[str, float]:
        """
        Test whether feature_a has strictly higher importance than feature_b.

        Uses the joint asymptotic distribution of (theta_hat_a, theta_hat_b).
        The covariance is estimated from influence function cross-products,
        which accounts for the fact that both features' rho vectors are
        computed from the same observations.

        H0: theta_a = theta_b
        H1: theta_a > theta_b (one-sided)

        Args:
            feature_a: Name of the first feature.
            feature_b: Name of the second feature.

        Returns:
            dict with:
                diff: theta_hat_a - theta_hat_b
                se_diff: Standard error of the difference
                z_stat: Standardised test statistic
                p_value: One-sided p-value for H1: theta_a > theta_b
                ci_lower: Lower bound on (theta_a - theta_b) at self.ci_level
                ci_upper: Upper bound on (theta_a - theta_b) at self.ci_level
        """
        self._check_fitted()
        assert self._theta_hat is not None
        assert self._rho is not None

        if feature_a not in self.feature_names:
            raise ValueError(f"feature_a='{feature_a}' not in feature_names.")
        if feature_b not in self.feature_names:
            raise ValueError(f"feature_b='{feature_b}' not in feature_names.")

        j_a = self.feature_names.index(feature_a)
        j_b = self.feature_names.index(feature_b)

        n = self.shap_values.shape[0]
        theta_a = float(self._theta_hat[j_a])
        theta_b = float(self._theta_hat[j_b])

        rho_a = self._rho[:, j_a]
        rho_b = self._rho[:, j_b]

        # Var(theta_hat_a - theta_hat_b) = Var(mean(rho_a - rho_b))
        #                                = Var(rho_a - rho_b) / n
        diff_rho = rho_a - rho_b
        var_diff = float(np.var(diff_rho, ddof=1)) / n
        se_diff = float(np.sqrt(max(var_diff, 0.0)))

        diff = theta_a - theta_b
        # Guard: when both feature arguments are identical, diff==0 and SE==0.
        # 0/0 is indeterminate; the correct result is z_stat=0, p_value=1.
        if diff == 0.0 and se_diff == 0.0:
            z_stat = 0.0
            p_value = 1.0
        else:
            z_stat = diff / se_diff if se_diff > 0 else float("inf")
            # One-sided p-value for H1: theta_a > theta_b
            p_value = float(1.0 - stats.norm.cdf(z_stat))

        # Two-sided CI on the difference
        z_ci = float(stats.norm.ppf((1 + self.ci_level) / 2))
        ci_lower = diff - z_ci * se_diff
        ci_upper = diff + z_ci * se_diff

        return {
            "diff": diff,
            "se_diff": se_diff,
            "z_stat": z_stat,
            "p_value": p_value,
            "ci_lower": ci_lower,
            "ci_upper": ci_upper,
        }

    def plot_importance(
        self,
        top_n: int | None = None,
        ax: Any | None = None,
        sort: bool = True,
    ) -> Any:
        """
        Bar chart of theta_hat with CI error bars.

        Styled for insurance governance presentations: clean background,
        coloured bars by significance (dark teal = CI excludes zero, grey =
        CI includes zero), error bars at self.ci_level.

        Args:
            top_n: Show only top N features by theta_hat. None shows all.
            ax: matplotlib Axes. If None, creates a new figure.
            sort: Sort by theta_hat descending. Default True.

        Returns:
            The matplotlib Axes object.
        """
        self._check_fitted()

        try:
            import matplotlib.pyplot as plt
        except ImportError as e:
            raise ImportError(
                "matplotlib is required for plot_importance(). "
                "Install with: pip install shap-relativities[plot]"
            ) from e

        tbl = self.importance_table()
        if sort:
            tbl = tbl.sort("theta_hat", descending=True)
        if top_n is not None:
            tbl = tbl.head(top_n)

        features = tbl["feature"].to_list()
        theta = tbl["theta_hat"].to_numpy()
        lower = tbl["theta_lower"].to_numpy()
        upper = tbl["theta_upper"].to_numpy()

        err_low = theta - lower
        err_high = upper - theta

        # Colour by significance: CI excludes zero => dark teal, else grey
        significant = lower > 0
        colours = ["#1a6b7c" if s else "#9e9e9e" for s in significant]

        if ax is None:
            fig, ax = plt.subplots(figsize=(max(6, len(features) * 0.6 + 2), 5))

        y_pos = np.arange(len(features))
        ax.barh(
            y_pos,
            theta,
            xerr=[err_low, err_high],
            color=colours,
            capsize=4,
            edgecolor="white",
            linewidth=0.5,
            error_kw={"elinewidth": 1.2, "capthick": 1.2, "ecolor": "#555555"},
        )

        ax.set_yticks(y_pos)
        ax.set_yticklabels(features, fontsize=10)
        ax.invert_yaxis()
        ax.axvline(0, color="black", linewidth=0.8, linestyle="--", alpha=0.5)

        p_label = f"E[|φ(X)|^{{{self.p:.4g}}}]"
        ax.set_xlabel(p_label, fontsize=11)
        ax.set_title(
            f"Global SHAP Feature Importance ({int(self.ci_level * 100)}% CI)",
            fontsize=12,
        )

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_facecolor("#f8f8f8")
        if hasattr(ax, "figure") and ax.figure is not None:
            ax.figure.set_facecolor("white")

        # Legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor="#1a6b7c", label="CI excludes 0"),
            Patch(facecolor="#9e9e9e", label="CI includes 0"),
        ]
        ax.legend(handles=legend_elements, fontsize=9, loc="lower right")

        return ax

    @property
    def influence_matrix(self) -> np.ndarray:
        """
        Influence function matrix, shape (n_obs, n_features).

        rho[i, j] is observation i's contribution to theta_hat_j.
        Observations with large |rho[i, j]| are high-leverage for feature j's
        importance estimate — useful for identifying influential policies in
        governance reviews.
        """
        self._check_fitted()
        assert self._rho is not None
        return self._rho.copy()
Asymptotically valid confidence intervals for global SHAP feature importance.
Implements the de-biased U-statistic estimator of Whitehouse, Sawarni, and Syrgkanis (2026, arXiv:2602.10532). For each feature, it provides a CI for theta_p = E[|phi_a(X)|^p], for any p >= 1.
The most common use cases are:
- p=1: mean absolute SHAP (the standard SHAP importance bar chart)
- p=2: mean squared SHAP (variance-like, with cleaner theory)
IMPORTANT: Valid inference requires interventional SHAP values, not the default path-dependent TreeSHAP. Use SHAPRelativities with feature_perturbation='interventional' before calling SHAPInference, or pass interventional SHAP values directly.
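For illustration, here is one way to obtain interventional SHAP values with the shap package directly. This is a minimal sketch: the names model, X_bg, and X are assumptions for the example, not part of this library's API.

    # A minimal sketch. `model` is a fitted tree model, `X_bg` a background
    # sample, and `X` the frame to score - all hypothetical names.
    import shap

    explainer = shap.TreeExplainer(
        model,
        data=X_bg,                              # background data is needed for
        feature_perturbation="interventional",  # the interventional algorithm
    )
    shap_vals = explainer.shap_values(X)        # array, shape (n_obs, n_features)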
Arguments:
- shap_values: np.ndarray of shape (n_obs, n_features). SHAP values, one column per feature. Must be interventional SHAP for theoretical validity. Path-dependent SHAP will produce point estimates and intervals, but coverage guarantees do not hold.
- y: np.ndarray of shape (n_obs,). Observed outcomes (claim counts or claim amounts). Required for the alpha nuisance correction.
- feature_names: list[str] of length n_features. Column names.
- p: float >= 1. Power for importance measure. Default 2.0 (mean squared SHAP). p=1 gives mean absolute SHAP (the standard bar chart metric) but requires smoothing. p=2 has cleaner asymptotic theory.
- n_folds: int >= 2. Number of cross-fitting folds. Default 5. More folds give each nuisance model a larger training complement (less nuisance-estimation bias) at the cost of extra compute.
- nuisance_estimator: str or sklearn estimator. Used for mu_hat (E[Y|X]) and gamma_hat. Default 'gradient_boosting' uses HistGradientBoostingRegressor.
- alpha_estimator: str or sklearn estimator. Used for alpha_hat. Defaults to the same as nuisance_estimator. Either argument accepts a custom sklearn-style regressor; see the sketch after this list.
- beta_n: float or None. Smoothing parameter for p < 2. If None, computed as n^{(2-p)/(2*(p+1))} (assumes delta=1). Only used when p < 2.
- ci_level: float. Two-sided confidence level. Default 0.95.
- n_jobs: int. Placeholder for future parallelism. Currently unused; features are estimated sequentially.
- random_state: int or None. Controls fold splitting for reproducibility.
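Because nuisance_estimator and alpha_estimator accept sklearn estimators directly, a custom learner can be swapped in. A minimal sketch, assuming only that the estimator exposes the usual sklearn fit/predict API (the random-forest choice here is illustrative, not a recommendation):

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor

    rng = np.random.default_rng(0)
    shap_vals = rng.normal(size=(500, 3))          # stand-in SHAP matrix
    y = rng.poisson(1.0, size=500).astype(float)

    rf = RandomForestRegressor(n_estimators=200, random_state=0)
    si = SHAPInference(shap_vals, y, feature_names=["a", "b", "c"],
                       nuisance_estimator=rf, alpha_estimator=rf, p=2.0)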
Examples
>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> shap_vals = rng.normal(size=(500, 3))
>>> y = rng.poisson(1.0, size=500).astype(float)
>>> si = SHAPInference(shap_vals, y, feature_names=["a", "b", "c"], p=2)
>>> si.fit()
SHAPInference(n_obs=500, n_features=3, p=2.0, n_folds=5, status=fitted)
>>> tbl = si.importance_table()
    def __init__(
        self,
        shap_values: np.ndarray,
        y: np.ndarray,
        feature_names: list[str],
        p: float = 2.0,
        n_folds: int = 5,
        nuisance_estimator: str | Any = "gradient_boosting",
        alpha_estimator: str | Any = "gradient_boosting",
        beta_n: float | None = None,
        ci_level: float = 0.95,
        n_jobs: int = 1,
        random_state: int | None = None,
    ) -> None:
        # --- Input validation ---
        shap_values = np.asarray(shap_values, dtype=float)
        y = np.asarray(y, dtype=float)

        if shap_values.ndim != 2:
            raise ValueError(
                f"shap_values must be 2D array (n_obs, n_features), "
                f"got shape {shap_values.shape}"
            )
        if y.ndim != 1:
            raise ValueError(f"y must be 1D array, got shape {y.shape}")
        if shap_values.shape[0] != len(y):
            raise ValueError(
                f"shap_values and y must have the same number of observations. "
                f"Got shap_values.shape[0]={shap_values.shape[0]} and len(y)={len(y)}."
            )
        if len(feature_names) != shap_values.shape[1]:
            raise ValueError(
                f"len(feature_names)={len(feature_names)} must equal "
                f"shap_values.shape[1]={shap_values.shape[1]}."
            )
        if p < 1.0:
            raise ValueError(f"p must be >= 1. Got p={p}.")
        if n_folds < 2:
            raise ValueError(f"n_folds must be >= 2. Got n_folds={n_folds}.")
        if not (0.0 < ci_level < 1.0):
            raise ValueError(f"ci_level must be in (0, 1). Got ci_level={ci_level}.")
        if len(feature_names) != len(set(feature_names)):
            raise ValueError("feature_names must be unique.")

        self.shap_values = shap_values
        self.y = y
        self.feature_names = list(feature_names)
        self.p = float(p)
        self.n_folds = n_folds
        self.nuisance_estimator = nuisance_estimator
        self.alpha_estimator = alpha_estimator
        self.beta_n = beta_n
        self.ci_level = ci_level
        self.n_jobs = n_jobs
        self.random_state = random_state

        # Fitted attributes — populated by fit()
        self._theta_hat: np.ndarray | None = None  # shape (n_features,)
        self._se: np.ndarray | None = None  # shape (n_features,)
        self._rho: np.ndarray | None = None  # shape (n_obs, n_features)
        self._is_fitted: bool = False

    def __repr__(self) -> str:
        n, d = self.shap_values.shape
        status = "fitted" if self._is_fitted else "not fitted"
        return (
            f"SHAPInference(n_obs={n}, n_features={d}, "
            f"p={self.p}, n_folds={self.n_folds}, status={status})"
        )
    def fit(self) -> "SHAPInference":
        """
        Estimate nuisance functions via cross-fitting and compute de-biased
        theta_hat_p for each feature.

        The algorithm:
        1. Split observations into n_folds folds.
        2. For each fold, train mu_hat, gamma_hat, alpha_hat on the complement.
        3. Evaluate nuisances on the held-out fold.
        4. Assemble full-data nuisance predictions.
        5. Compute the influence function rho_a for each feature.
        6. theta_hat = mean(rho_a), SE = sqrt(var(rho_a) / n).

        Returns:
            self, for method chaining.
        """
        if not _SKLEARN_AVAILABLE:
            raise ImportError(
                "scikit-learn >= 1.3 is required for SHAPInference. "
                "Install with: pip install shap-relativities[ml]"
            )

        n, d = self.shap_values.shape

        # Warn if p < 2: smoothing is active
        effective_beta_n = self.beta_n
        if self.p < 2.0:
            if effective_beta_n is None:
                effective_beta_n = _default_beta_n(n, self.p)
            warnings.warn(
                f"p={self.p} < 2: using smoothed estimator phi_{{p,beta}} "
                f"with beta_n={effective_beta_n:.3f}. "
                "Coverage is asymptotically valid but may be approximate for "
                "features with many near-zero SHAP values. "
                "Consider p=2 for cleaner guarantees.",
                UserWarning,
                stacklevel=2,
            )
        else:
            effective_beta_n = 0.0  # unused but keeps type consistent

        nu_est = _make_nuisance_estimator(self.nuisance_estimator)
        al_est = _make_nuisance_estimator(self.alpha_estimator)

        kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=self.random_state)

        theta_hats = np.zeros(d)
        ses = np.zeros(d)
        rhos = np.zeros((n, d))

        for j in range(d):
            theta_j, se_j, rho_j = _fit_single_feature(
                phi_col=self.shap_values[:, j],
                y=self.y,
                shap_matrix=self.shap_values,
                p=self.p,
                beta_n=effective_beta_n,
                kf=kf,
                nuisance_estimator=nu_est,
                alpha_estimator=al_est,
            )
            theta_hats[j] = theta_j
            ses[j] = se_j
            rhos[:, j] = rho_j

        self._theta_hat = theta_hats
        self._se = ses
        self._rho = rhos
        self._is_fitted = True

        return self
Estimate nuisance functions via cross-fitting and compute de-biased theta_hat_p for each feature.
The algorithm:
- Split observations into n_folds folds.
- For each fold, train mu_hat, gamma_hat, alpha_hat on the complement.
- Evaluate nuisances on the held-out fold.
- Assemble full-data nuisance predictions.
- Compute the influence function rho_a for each feature.
- theta_hat = mean(rho_a), SE = sqrt(var(rho_a) / n).
Returns:
self, for method chaining.
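Steps 1-4 follow the standard cross-fitting pattern: every observation is scored by nuisance models that never saw it during training. A minimal self-contained sketch of that pattern; the actual nuisance targets (mu_hat, gamma_hat, alpha_hat) live in private helpers and are not reproduced here:

    import numpy as np
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.model_selection import KFold

    def cross_fit_oof(X, y, n_folds=5, random_state=0):
        # Out-of-fold predictions: train on the complement of each fold,
        # predict on the held-out fold.
        kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
        oof = np.empty_like(y, dtype=float)
        for train_idx, test_idx in kf.split(X):
            est = HistGradientBoostingRegressor()
            est.fit(X[train_idx], y[train_idx])
            oof[test_idx] = est.predict(X[test_idx])
        return oof

Step 6 then reduces to theta_hat = rho.mean() and SE = rho.std(ddof=1) / sqrt(n) once the influence function values are assembled.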
    def _check_fitted(self) -> None:
        if not self._is_fitted:
            raise RuntimeError("Call .fit() before accessing results.")

    def importance_table(self) -> pl.DataFrame:
        """
        Return feature importance estimates with confidence intervals.

        All theta_hat values are theoretically non-negative (they estimate
        E[|phi|^p]), but may be slightly negative for features with very
        small true importance — this is expected sampling variability.

        Returns:
            Polars DataFrame with columns:
                feature: Feature name
                theta_hat: Point estimate of E[|phi_a(X)|^p]
                theta_lower: Lower CI bound
                theta_upper: Upper CI bound
                sigma_hat: sqrt(Var[rho_a]) — asymptotic std dev
                se: Standard error = sigma_hat / sqrt(n)
                rank: Rank by theta_hat (1 = most important)
                rank_lower: Conservative rank (using theta_lower)
                rank_upper: Optimistic rank (using theta_upper)
                p_value_nonzero: Two-sided p-value for H0: theta_p = 0
        """
        self._check_fitted()
        assert self._theta_hat is not None
        assert self._se is not None
        assert self._rho is not None

        n = self.shap_values.shape[0]
        z = float(stats.norm.ppf((1 + self.ci_level) / 2))

        theta = self._theta_hat
        se = self._se
        lower = theta - z * se
        upper = theta + z * se

        # sigma_hat = SE * sqrt(n) = sqrt(Var[rho])
        sigma_hat = se * np.sqrt(n)

        # Ranks (1 = most important)
        rank = _dense_rank_descending(theta)
        rank_lower = _dense_rank_descending(lower)  # conservative: lower bound
        rank_upper = _dense_rank_descending(upper)  # optimistic: upper bound

        # p-value: two-sided test H0: theta_p = 0
        # Under H0, z_stat = theta_hat / SE ~ N(0,1) asymptotically
        with np.errstate(divide="ignore", invalid="ignore"):
            z_stat = np.where(se > 0, theta / se, np.inf)
        p_values = 2.0 * (1.0 - stats.norm.cdf(np.abs(z_stat)))

        return pl.DataFrame({
            "feature": self.feature_names,
            "theta_hat": theta.tolist(),
            "theta_lower": lower.tolist(),
            "theta_upper": upper.tolist(),
            "sigma_hat": sigma_hat.tolist(),
            "se": se.tolist(),
            # _dense_rank_descending already returns list[int]; no .tolist() needed
            "rank": rank,
            "rank_lower": rank_lower,
            "rank_upper": rank_upper,
            "p_value_nonzero": p_values.tolist(),
        }).sort("rank")
Return feature importance estimates with confidence intervals.
All theta_hat values are theoretically non-negative (they estimate E[|phi|^p]), but may be slightly negative for features with very small true importance — this is expected sampling variability.
Returns:
Polars DataFrame with columns:
- feature: Feature name
- theta_hat: Point estimate of E[|phi_a(X)|^p]
- theta_lower: Lower CI bound
- theta_upper: Upper CI bound
- sigma_hat: sqrt(Var[rho_a]), the asymptotic standard deviation
- se: Standard error = sigma_hat / sqrt(n)
- rank: Rank by theta_hat (1 = most important)
- rank_lower: Conservative rank (using theta_lower)
- rank_upper: Optimistic rank (using theta_upper)
- p_value_nonzero: Two-sided p-value for H0: theta_p = 0
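Since the result is an ordinary Polars DataFrame, downstream filtering needs no special API. For example, keeping only features whose CI excludes zero (si is a fitted SHAPInference):

    import polars as pl

    tbl = si.importance_table()
    significant = tbl.filter(pl.col("theta_lower") > 0)   # CI excludes zero
    top5 = tbl.sort("theta_hat", descending=True).head(5)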
    def ranking_ci(self, feature_a: str, feature_b: str) -> dict[str, float]:
        """
        Test whether feature_a has strictly higher importance than feature_b.

        Uses the joint asymptotic distribution of (theta_hat_a, theta_hat_b).
        The covariance is estimated from influence function cross-products,
        which accounts for the fact that both features' rho vectors are
        computed from the same observations.

        H0: theta_a = theta_b
        H1: theta_a > theta_b (one-sided)

        Args:
            feature_a: Name of the first feature.
            feature_b: Name of the second feature.

        Returns:
            dict with:
                diff: theta_hat_a - theta_hat_b
                se_diff: Standard error of the difference
                z_stat: Standardised test statistic
                p_value: One-sided p-value for H1: theta_a > theta_b
                ci_lower: Lower bound on (theta_a - theta_b) at self.ci_level
                ci_upper: Upper bound on (theta_a - theta_b) at self.ci_level
        """
        self._check_fitted()
        assert self._theta_hat is not None
        assert self._rho is not None

        if feature_a not in self.feature_names:
            raise ValueError(f"feature_a='{feature_a}' not in feature_names.")
        if feature_b not in self.feature_names:
            raise ValueError(f"feature_b='{feature_b}' not in feature_names.")

        j_a = self.feature_names.index(feature_a)
        j_b = self.feature_names.index(feature_b)

        n = self.shap_values.shape[0]
        theta_a = float(self._theta_hat[j_a])
        theta_b = float(self._theta_hat[j_b])

        rho_a = self._rho[:, j_a]
        rho_b = self._rho[:, j_b]

        # Var(theta_hat_a - theta_hat_b) = Var(mean(rho_a - rho_b))
        #                                = Var(rho_a - rho_b) / n
        diff_rho = rho_a - rho_b
        var_diff = float(np.var(diff_rho, ddof=1)) / n
        se_diff = float(np.sqrt(max(var_diff, 0.0)))

        diff = theta_a - theta_b
        # Guard: when both feature arguments are identical, diff==0 and SE==0.
        # 0/0 is indeterminate; the correct result is z_stat=0, p_value=1.
        if diff == 0.0 and se_diff == 0.0:
            z_stat = 0.0
            p_value = 1.0
        else:
            z_stat = diff / se_diff if se_diff > 0 else float("inf")
            # One-sided p-value for H1: theta_a > theta_b
            p_value = float(1.0 - stats.norm.cdf(z_stat))

        # Two-sided CI on the difference
        z_ci = float(stats.norm.ppf((1 + self.ci_level) / 2))
        ci_lower = diff - z_ci * se_diff
        ci_upper = diff + z_ci * se_diff

        return {
            "diff": diff,
            "se_diff": se_diff,
            "z_stat": z_stat,
            "p_value": p_value,
            "ci_lower": ci_lower,
            "ci_upper": ci_upper,
        }
Test whether feature_a has strictly higher importance than feature_b.
Uses the joint asymptotic distribution of (theta_hat_a, theta_hat_b). The covariance is estimated from influence function cross-products, which accounts for the fact that both features' rho vectors are computed from the same observations.
H0: theta_a = theta_b
H1: theta_a > theta_b (one-sided)
Arguments:
- feature_a: Name of the first feature.
- feature_b: Name of the second feature.
Returns:
dict with:
- diff: theta_hat_a - theta_hat_b
- se_diff: Standard error of the difference
- z_stat: Standardised test statistic
- p_value: One-sided p-value for H1: theta_a > theta_b
- ci_lower: Lower bound on (theta_a - theta_b) at self.ci_level
- ci_upper: Upper bound on (theta_a - theta_b) at self.ci_level
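A usage sketch with hypothetical feature names, e.g. a governance-style check that one rating factor demonstrably outranks another:

    res = si.ranking_ci("ncd_years", "area")   # si is a fitted SHAPInference
    if res["p_value"] < 0.05:
        print(f"ncd_years > area: diff={res['diff']:.4f}, "
              f"CI [{res['ci_lower']:.4f}, {res['ci_upper']:.4f}]")
    else:
        print("Ranking not significant at the 5% level.")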
    def plot_importance(
        self,
        top_n: int | None = None,
        ax: Any | None = None,
        sort: bool = True,
    ) -> Any:
        """
        Bar chart of theta_hat with CI error bars.

        Styled for insurance governance presentations: clean background,
        coloured bars by significance (dark teal = CI excludes zero, grey =
        CI includes zero), error bars at self.ci_level.

        Args:
            top_n: Show only top N features by theta_hat. None shows all.
            ax: matplotlib Axes. If None, creates a new figure.
            sort: Sort by theta_hat descending. Default True.

        Returns:
            The matplotlib Axes object.
        """
        self._check_fitted()

        try:
            import matplotlib.pyplot as plt
        except ImportError as e:
            raise ImportError(
                "matplotlib is required for plot_importance(). "
                "Install with: pip install shap-relativities[plot]"
            ) from e

        tbl = self.importance_table()
        if sort:
            tbl = tbl.sort("theta_hat", descending=True)
        if top_n is not None:
            tbl = tbl.head(top_n)

        features = tbl["feature"].to_list()
        theta = tbl["theta_hat"].to_numpy()
        lower = tbl["theta_lower"].to_numpy()
        upper = tbl["theta_upper"].to_numpy()

        err_low = theta - lower
        err_high = upper - theta

        # Colour by significance: CI excludes zero => dark teal, else grey
        significant = lower > 0
        colours = ["#1a6b7c" if s else "#9e9e9e" for s in significant]

        if ax is None:
            fig, ax = plt.subplots(figsize=(max(6, len(features) * 0.6 + 2), 5))

        y_pos = np.arange(len(features))
        ax.barh(
            y_pos,
            theta,
            xerr=[err_low, err_high],
            color=colours,
            capsize=4,
            edgecolor="white",
            linewidth=0.5,
            error_kw={"elinewidth": 1.2, "capthick": 1.2, "ecolor": "#555555"},
        )

        ax.set_yticks(y_pos)
        ax.set_yticklabels(features, fontsize=10)
        ax.invert_yaxis()
        ax.axvline(0, color="black", linewidth=0.8, linestyle="--", alpha=0.5)

        p_label = f"E[|φ(X)|^{{{self.p:.4g}}}]"
        ax.set_xlabel(p_label, fontsize=11)
        ax.set_title(
            f"Global SHAP Feature Importance ({int(self.ci_level * 100)}% CI)",
            fontsize=12,
        )

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_facecolor("#f8f8f8")
        if hasattr(ax, "figure") and ax.figure is not None:
            ax.figure.set_facecolor("white")

        # Legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor="#1a6b7c", label="CI excludes 0"),
            Patch(facecolor="#9e9e9e", label="CI includes 0"),
        ]
        ax.legend(handles=legend_elements, fontsize=9, loc="lower right")

        return ax
Bar chart of theta_hat with CI error bars.
Styled for insurance governance presentations: clean background, coloured bars by significance (dark teal = CI excludes zero, grey = CI includes zero), error bars at self.ci_level.
Arguments:
- top_n: Show only top N features by theta_hat. None shows all.
- ax: matplotlib Axes. If None, creates a new figure.
- sort: Sort by theta_hat descending. Default True.
Returns:
The matplotlib Axes object.
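A short usage sketch (the output filename is hypothetical):

    import matplotlib.pyplot as plt

    ax = si.plot_importance(top_n=10)          # si is a fitted SHAPInference
    ax.figure.tight_layout()
    ax.figure.savefig("shap_importance.png", dpi=200)
    plt.close(ax.figure)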
    @property
    def influence_matrix(self) -> np.ndarray:
        """
        Influence function matrix, shape (n_obs, n_features).

        rho[i, j] is observation i's contribution to theta_hat_j.
        Observations with large |rho[i, j]| are high-leverage for feature j's
        importance estimate — useful for identifying influential policies in
        governance reviews.
        """
        self._check_fitted()
        assert self._rho is not None
        return self._rho.copy()
Influence function matrix, shape (n_obs, n_features).
rho[i, j] is observation i's contribution to theta_hat_j. Observations with large |rho[i, j]| are high-leverage for feature j's importance estimate — useful for identifying influential policies in governance reviews.
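A sketch of the governance use mentioned above: rank observations by |rho| for one feature to surface the policies that most move its importance estimate (the feature name is hypothetical):

    import numpy as np

    rho = si.influence_matrix                  # (n_obs, n_features), a copy
    j = si.feature_names.index("area")         # hypothetical feature name
    leverage = np.abs(rho[:, j])
    top_idx = np.argsort(leverage)[::-1][:10]  # ten highest-leverage rows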
def extract_relativities(
    model: Any,
    X: Any,
    exposure: Any = None,
    categorical_features: list[str] | None = None,
    base_levels: dict[str, str | float | int] | None = None,
    ci_method: str = "clt",
) -> pl.DataFrame:
    """
    One-shot extraction of SHAP relativities from a tree model.

    Wraps SHAPRelativities.fit() and extract_relativities() for cases where
    you don't need the intermediate object.

    Args:
        model: Trained CatBoost model with a log-link objective (Poisson,
            Tweedie, or Gamma). CatBoost is the recommended choice - it handles
            categorical features natively without encoding.
        X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is
            preferred; pandas is accepted and converted internally.
        exposure: Earned policy years. If None, all observations are equally
            weighted.
        categorical_features: Features to aggregate by level. If None, all
            non-numeric columns are treated as categorical.
        base_levels: Base level for each categorical feature (gets
            relativity = 1.0).
        ci_method: "clt" (default) or "none".

    Returns:
        Polars DataFrame with columns: feature, level, relativity, lower_ci,
        upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
    """
    sr = SHAPRelativities(model, X, exposure, categorical_features)
    sr.fit()
    return sr.extract_relativities(base_levels=base_levels, ci_method=ci_method)
One-shot extraction of SHAP relativities from a tree model.
Wraps SHAPRelativities.fit() and extract_relativities() for cases where you don't need the intermediate object.
Arguments:
- model: Trained CatBoost model with a log-link objective (Poisson, Tweedie, or Gamma). CatBoost is the recommended choice - it handles categorical features natively without encoding.
- X: Feature matrix. Accepts a Polars or pandas DataFrame. Polars is preferred; pandas is accepted and converted internally.
- exposure: Earned policy years. If None, all observations are equally weighted.
- categorical_features: Features to aggregate by level. If None, all non-numeric columns are treated as categorical.
- base_levels: Base level for each categorical feature (gets relativity = 1.0).
- ci_method: "clt" (default) or "none".
Returns:
Polars DataFrame with columns: feature, level, relativity, lower_ci, upper_ci, mean_shap, shap_std, n_obs, exposure_weight.
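To make the log-link arithmetic concrete: exponentiated differences in mean SHAP behave like GLM relativities, with the base level landing at 1.0. A minimal sketch with hypothetical per-level mean SHAP values; this illustrates the identity, and is an assumption about how the relativities are assembled rather than a reproduction of the library's internals:

    import numpy as np

    mean_shap = {"A": 0.00, "B": 0.18, "C": 0.35}   # hypothetical level means
    base = "A"                                      # chosen base level
    rels = {lvl: float(np.exp(m - mean_shap[base]))
            for lvl, m in mean_shap.items()}
    # rels == {'A': 1.0, 'B': ~1.197, 'C': ~1.419}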