Part 7: Understanding what the CANN is about to learn

Before training, we can look at where the GLM is going wrong. This is the signal the CANN will pick up.

Residual analysis by age band and vehicle group

import matplotlib.pyplot as plt
import polars as pl

df_diag = pl.DataFrame({
    "age_band":      X["age_band"].to_list(),
    "vehicle_group": X["vehicle_group"].to_list(),
    "y":             y,
    "mu_glm":        mu_glm,
    "exposure":      exposure_arr,
})

# Actual/expected by age band and vehicle group
ae_table = (
    df_diag
    .group_by(["age_band", "vehicle_group"])
    .agg([
        pl.sum("y").alias("observed"),
        pl.sum("mu_glm").alias("predicted"),
        pl.sum("exposure").alias("exposure"),
    ])
    .with_columns(
        (pl.col("observed") / pl.col("predicted")).alias("ae_ratio")
    )
    .sort(["age_band", "vehicle_group"])
)

# Show the worst cells
print("Cells with highest A/E ratio (GLM underpredicting):")
print(ae_table.sort("ae_ratio", descending=True).head(10))

print("\nCells with lowest A/E ratio (GLM overpredicting):")
print(ae_table.sort("ae_ratio").head(10))

What you expect to see: the cell for age band 17-21 and vehicle group 41-50 should appear near the top of the underpredicting list (A/E > 1.0). The main GLM does not know that this combination carries extra risk beyond what the age and vehicle effects imply on their own, so it predicts too few claims there. You planted this interaction in the data, and the A/E table now shows where the GLM is bleeding deviance.
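To make the A/E arithmetic concrete, here is a minimal sketch in plain Python. It assumes a handful of made-up records (the claim counts and GLM predictions below are hypothetical, not taken from the tutorial data) and only illustrates how observed and predicted claims are summed per cell before dividing.

```python
from collections import defaultdict

# Hypothetical toy records, invented for illustration only:
# (age_band, vehicle_group, observed_claims, glm_predicted_claims)
records = [
    ("17-21", "41-50", 9, 5.0),
    ("17-21", "41-50", 6, 4.0),
    ("17-21", "1-10", 3, 3.2),
    ("30-39", "41-50", 4, 4.1),
    ("30-39", "1-10", 2, 2.0),
    ("30-39", "1-10", 1, 1.1),
]

# Sum observed and predicted claims within each (age_band, vehicle_group) cell.
totals = defaultdict(lambda: [0.0, 0.0])
for age, veh, obs, pred in records:
    totals[(age, veh)][0] += obs
    totals[(age, veh)][1] += pred

# A/E ratio = total observed / total predicted, computed per cell.
ae = {cell: obs / pred for cell, (obs, pred) in totals.items()}

# List cells from most underpredicted (highest A/E) to most overpredicted.
for cell, ratio in sorted(ae.items(), key=lambda kv: kv[1], reverse=True):
    print(cell, round(ratio, 3))

# The cell where the GLM falls furthest short of the observed claims.
worst_cell = max(ae, key=ae.get)
```

In this toy data the young-driver, high-vehicle-group cell has 15 observed claims against 9 predicted, so it tops the list while the other cells sit close to 1.0. On real data, cells with very little exposure can show extreme ratios by chance, so it is worth reading the exposure column alongside the ratio before drawing conclusions.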