"""
EM algorithm for BLIM parameter estimation.
Estimates slip (β) and guess (η) parameters from observed response
patterns using Expectation-Maximization. Supports both global
(homogeneous) and per-item (heterogeneous) parameterization.
The algorithm iterates between:
- E-step: compute posterior P(state | response pattern) for each pattern.
- M-step: re-estimate β, η, and state prior π from sufficient statistics.
References:
Doignon, J.-P., & Falmagne, J.-C. (1999).
Knowledge Spaces, Chapter 12. Springer-Verlag.
Heller, J., & Wickelmaier, F. (2013).
Minimum discrepancy estimation in probabilistic knowledge structures.
Electronic Notes in Discrete Mathematics, 42, 49-56.
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
import numpy as np
from scipy import stats
from knowledgespaces.structures.knowledge_structure import KnowledgeStructure
@dataclass
class ResponseMatrix:
    """Observed response patterns from a group of respondents.

    Parameters
    ----------
    items : list[str]
        Item labels (columns), must match the structure's domain.
    patterns : np.ndarray
        Binary matrix of shape (n_respondents, n_items).
        patterns[r, q] = 1 if respondent r answered item q correctly.
    counts : np.ndarray or None
        Optional frequency for each unique pattern. If None, each
        row in patterns is one respondent (count=1 each).
    """

    items: list[str]
    patterns: np.ndarray
    counts: np.ndarray | None = None

    def __post_init__(self) -> None:
        """Validate labels, pattern matrix, and counts; raise ValueError on bad input."""
        # Item uniqueness — report the first duplicate label encountered.
        seen: set[str] = set()
        for item in self.items:
            if item in seen:
                raise ValueError(f"Duplicate item label: '{item}'.")
            seen.add(item)
        if self.patterns.ndim != 2:
            raise ValueError(f"patterns must be 2D, got {self.patterns.ndim}D.")
        if self.patterns.shape[1] != len(self.items):
            raise ValueError(
                f"patterns has {self.patterns.shape[1]} columns but {len(self.items)} items."
            )
        if not np.isin(self.patterns, [0, 1]).all():
            raise ValueError("patterns must contain only 0 and 1.")
        if self.counts is not None:
            # Coerce list-like counts to an ndarray so downstream .sum()
            # calls work; ndarray inputs keep their original dtype.
            if not isinstance(self.counts, np.ndarray):
                self.counts = np.asarray(self.counts, dtype=np.float64)
            if len(self.counts) != self.patterns.shape[0]:
                raise ValueError(
                    f"counts length {len(self.counts)} != patterns rows {self.patterns.shape[0]}."
                )
            if np.any(self.counts < 0):
                raise ValueError("counts must be non-negative.")
            if self.counts.sum() == 0:
                raise ValueError("Total count must be positive (at least one respondent).")

    @property
    def n_patterns(self) -> int:
        """Number of (not necessarily unique) pattern rows."""
        return self.patterns.shape[0]

    @property
    def n_items(self) -> int:
        """Number of items (columns)."""
        return len(self.items)

    @property
    def effective_counts(self) -> np.ndarray:
        """Counts for each pattern (ones if not provided)."""
        if self.counts is not None:
            return self.counts
        return np.ones(self.n_patterns)

    @property
    def n_respondents(self) -> float:
        """Total respondent count (sum of effective counts)."""
        return float(self.effective_counts.sum())
@dataclass(frozen=True)
class GoodnessOfFit:
    """Goodness-of-fit statistics for a BLIM estimate.

    Follows the approach of Heller & Wickelmaier (2013) and the R
    ``pks`` package (Heller & Wickelmaier, 2013, J. Stat. Softw.).
    The primary statistic is the likelihood ratio G2 (deviance),
    tested against a chi-squared distribution.

    Attributes
    ----------
    G2 : float
        Likelihood ratio statistic: 2 * sum_r N_r ln(N_r / E_r).
    df : int
        Degrees of freedom: ``max(min(2^Q - 1, N) - npar, 0)``.
        The ``min(2^Q - 1, N)`` cap follows the ``pks`` convention:
        when the total sample size N is smaller than the number of
        possible response patterns (2^Q), the saturated model cannot
        be fully identified, so N replaces 2^Q - 1.
    p_value : float
        P-value from chi-squared test on G2.
    npar : int
        Number of free parameters: ``|K| - 1 + 2 * Q``.
    AIC : float
        Akaike Information Criterion: ``-2*LL + 2*npar``.
    BIC : float
        Bayesian Information Criterion using the number of unique
        response patterns: ``-2*LL + ln(n_patterns)*npar``.
        This follows the ``pks`` convention (Heller & Wickelmaier,
        2013). See :attr:`BIC_N` for the standard textbook variant.
    BIC_N : float
        Bayesian Information Criterion using the total sample size:
        ``-2*LL + ln(N)*npar``. This is the standard BIC definition
        (Schwarz, 1978).
    """

    G2: float
    df: int
    p_value: float
    npar: int
    AIC: float
    BIC: float
    BIC_N: float
@dataclass(frozen=True)
class BLIMEstimate:
"""Result of BLIM parameter estimation via EM.
Attributes
----------
beta : np.ndarray
Slip parameters, shape (n_items,). beta[q] = P(incorrect | q mastered).
eta : np.ndarray
Guess parameters, shape (n_items,). eta[q] = P(correct | q not mastered).
pi : np.ndarray
State prior probabilities, shape (n_states,).
log_likelihood : float
Final log-likelihood of the data.
n_iterations : int
Number of EM iterations until convergence.
converged : bool
True if converged within max_iter.
items : list[str]
Item labels corresponding to beta/eta indices.
states : list[frozenset[str]]
Knowledge states corresponding to pi indices (same order).
gof : GoodnessOfFit
Goodness-of-fit statistics (G2, df, p-value, AIC, BIC).
"""
beta: np.ndarray
eta: np.ndarray
pi: np.ndarray
log_likelihood: float
n_iterations: int
converged: bool
items: list[str]
states: list[frozenset[str]]
gof: GoodnessOfFit
[docs]
def beta_for(self, item: str) -> float:
"""Get beta (slip) for a specific item."""
return float(self.beta[self.items.index(item)])
[docs]
def eta_for(self, item: str) -> float:
"""Get eta (guess) for a specific item."""
return float(self.eta[self.items.index(item)])
[docs]
def beta_dict(self) -> dict[str, float]:
"""Return beta as {item: value} dict."""
return dict(zip(self.items, self.beta.tolist(), strict=True))
[docs]
def eta_dict(self) -> dict[str, float]:
"""Return eta as {item: value} dict."""
return dict(zip(self.items, self.eta.tolist(), strict=True))
[docs]
def pi_dict(self) -> dict[frozenset[str], float]:
"""Return pi as {state: probability} dict."""
return dict(zip(self.states, self.pi.tolist(), strict=True))
def _compute_gof(
    log_likelihood: float,
    n_states: int,
    n_items: int,
    n_patterns: int,
    n_respondents: float,
    beta: np.ndarray,
    eta: np.ndarray,
    pi: np.ndarray,
    S: np.ndarray,
    R: np.ndarray,
    counts: np.ndarray,
) -> GoodnessOfFit:
    """Compute goodness-of-fit statistics for a BLIM estimate.

    Follows the R ``pks`` package conventions (Heller & Wickelmaier, 2013).
    G2 is computed over *unique* response patterns with their observed
    frequencies, compared against expected frequencies from the model.

    Note: ``n_patterns`` is unused (unique patterns are recomputed from
    ``R`` here) and is kept only for interface compatibility.
    """
    N = n_respondents
    # Aggregate observed counts onto unique response patterns.
    unique_patterns, inverse = np.unique(R, axis=0, return_inverse=True)
    # Some NumPy 2.x releases return a 2-D inverse for axis= calls;
    # ravel() keeps the indexing robust across versions.
    inverse = np.asarray(inverse).ravel()
    observed_freq = np.zeros(len(unique_patterns))
    # Vectorized scatter-add (replaces a per-row Python loop).
    np.add.at(observed_freq, inverse, counts)
    n_unique = len(unique_patterns)
    # Predicted probability of each unique pattern: P(R) = sum_K P(R|K)*P(K)
    p_correct = S * (1 - beta) + (1 - S) * eta  # (n_states, n_items)
    log_pc = np.log(np.clip(p_correct, 1e-300, None))
    log_pinc = np.log(np.clip(1 - p_correct, 1e-300, None))
    log_lik_rk = unique_patterns @ log_pc.T + (1 - unique_patterns) @ log_pinc.T
    log_prior = np.log(np.clip(pi, 1e-300, None))
    log_joint = log_lik_rk + log_prior
    # Log-sum-exp over states for numerical stability.
    max_log = log_joint.max(axis=1, keepdims=True)
    log_P_R = (max_log + np.log(np.exp(log_joint - max_log).sum(axis=1, keepdims=True))).ravel()
    # Expected counts under the fitted model.
    E = N * np.exp(log_P_R)
    # G2 = 2 * sum_r N_r * ln(N_r / E_r), skipping zero-count patterns
    mask = observed_freq > 0
    G2 = float(2 * np.sum(observed_freq[mask] * np.log(observed_freq[mask] / E[mask])))
    # Free parameters: |K|-1 (pi) + Q (beta) + Q (eta)
    npar = (n_states - 1) + 2 * n_items
    # Degrees of freedom: saturated model has 2^Q - 1 free parameters
    # (one probability per pattern minus the sum-to-1 constraint).
    # df = min(2^Q - 1, N) - npar, floored at 0.
    n_possible = 2**n_items - 1
    if n_items > 20:
        warnings.warn(
            f"Domain has {n_items} items — the full 2^Q pattern space "
            f"({2**n_items}) is very large. GOF df may be unreliable.",
            stacklevel=3,
        )
    df = max(int(min(n_possible, N) - npar), 0)
    # Survival function is more accurate than 1 - cdf in the upper tail.
    p_value = float(stats.chi2.sf(G2, df)) if df > 0 else 1.0
    # Information criteria
    AIC = -2 * log_likelihood + 2 * npar
    BIC = -2 * log_likelihood + np.log(n_unique) * npar  # pks convention
    BIC_N = -2 * log_likelihood + np.log(N) * npar  # standard (Schwarz, 1978)
    return GoodnessOfFit(
        G2=G2,
        df=df,
        p_value=p_value,
        npar=npar,
        AIC=float(AIC),
        BIC=float(BIC),
        BIC_N=float(BIC_N),
    )
def estimate_blim(
    structure: KnowledgeStructure,
    data: ResponseMatrix,
    *,
    max_iter: int = 500,
    tol: float = 1e-6,
    beta_init: float | np.ndarray = 0.1,
    eta_init: float | np.ndarray = 0.1,
) -> BLIMEstimate:
    """Estimate BLIM parameters via Expectation-Maximization.

    Parameters
    ----------
    structure : KnowledgeStructure
        The knowledge structure defining valid states.
    data : ResponseMatrix
        Observed response patterns.
    max_iter : int
        Maximum number of EM iterations. Default 500.
    tol : float
        Convergence tolerance on log-likelihood change. Default 1e-6.
    beta_init : float or np.ndarray
        Initial slip values. Scalar for homogeneous, array for per-item.
    eta_init : float or np.ndarray
        Initial guess values. Scalar for homogeneous, array for per-item.

    Returns
    -------
    BLIMEstimate
        Estimated parameters, log-likelihood, and convergence info.

    Raises
    ------
    ValueError
        If data items don't match structure domain, or if init
        parameters are out of range.
    """
    # Validate hyperparameters
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}.")
    if tol <= 0:
        raise ValueError(f"tol must be > 0, got {tol}.")
    # Validate domain match
    if set(data.items) != structure.domain:
        raise ValueError(
            f"ResponseMatrix items {set(data.items)} don't match "
            f"structure domain {set(structure.domain)}."
        )
    items = data.items
    n_items = len(items)
    counts = data.effective_counts
    # Build state matrix: S[k, q] = 1 if item q is in state k.
    # Sorting by (size, sorted elements) makes the state order deterministic,
    # so pi indices are reproducible across runs.
    states = sorted(structure.states, key=lambda s: (len(s), sorted(s)))
    n_states = len(states)
    item_idx = {item: i for i, item in enumerate(items)}
    S = np.zeros((n_states, n_items), dtype=np.float64)
    for k, state in enumerate(states):
        for item in state:
            S[k, item_idx[item]] = 1.0
    # Initialize and validate parameters.  Scalars broadcast to all items;
    # arrays must match the item count exactly.
    if isinstance(beta_init, (int, float)):
        beta = np.full(n_items, float(beta_init))
    else:
        beta = np.array(beta_init, dtype=np.float64).copy()
        if beta.shape != (n_items,):
            raise ValueError(f"beta_init array has shape {beta.shape}, expected ({n_items},).")
    if isinstance(eta_init, (int, float)):
        eta = np.full(n_items, float(eta_init))
    else:
        eta = np.array(eta_init, dtype=np.float64).copy()
        if eta.shape != (n_items,):
            raise ValueError(f"eta_init array has shape {eta.shape}, expected ({n_items},).")
    if np.any(beta < 0) or np.any(beta >= 1):
        raise ValueError("All beta_init values must be in [0, 1).")
    if np.any(eta < 0) or np.any(eta >= 1):
        raise ValueError("All eta_init values must be in [0, 1).")
    # beta + eta < 1 is the standard BLIM identifiability condition.
    if np.any(beta + eta >= 1):
        raise ValueError("beta_init + eta_init must be < 1 for each item.")
    # Uniform state prior to start.
    pi = np.full(n_states, 1.0 / n_states)
    R = data.patterns.astype(np.float64)  # (n_patterns, n_items)
    prev_ll = -np.inf
    for iteration in range(1, max_iter + 1):
        # ============================================================
        # E-step: compute P(pattern | state) and posterior P(state | pattern)
        # ============================================================
        # P(R_r | K_k) = prod_q [ P(R_rq | K_k, q) ]
        # where P(correct | q in K) = 1-beta_q, P(correct | q not in K) = eta_q
        # Likelihood per (pattern, state, item)
        # P(R_rq=1 | K_k) = S[k,q]*(1-beta[q]) + (1-S[k,q])*eta[q]
        p_correct = S * (1 - beta) + (1 - S) * eta  # (n_states, n_items)
        p_incorrect = 1 - p_correct  # (n_states, n_items)
        # P(R_r | K_k) for each (pattern, state)
        # log for numerical stability; clip guards log(0) at the float floor
        log_p_correct = np.log(np.clip(p_correct, 1e-300, None))
        log_p_incorrect = np.log(np.clip(p_incorrect, 1e-300, None))
        # log P(R_r | K_k) = sum_q [ R_rq * log_p_correct[k,q] + (1-R_rq) * log_p_incorrect[k,q] ]
        log_lik_rk = R @ log_p_correct.T + (1 - R) @ log_p_incorrect.T  # (n_patterns, n_states)
        # Add log prior
        log_joint = log_lik_rk + np.log(np.clip(pi, 1e-300, None))  # (n_patterns, n_states)
        # Log-sum-exp for numerical stability
        max_log = log_joint.max(axis=1, keepdims=True)
        log_marginal = max_log + np.log(np.exp(log_joint - max_log).sum(axis=1, keepdims=True))
        # Posterior: P(K_k | R_r) = exp(log_joint - log_marginal)
        posterior = np.exp(log_joint - log_marginal)  # (n_patterns, n_states)
        # Log-likelihood (patterns weighted by their observed frequencies)
        ll = float((counts * log_marginal.ravel()).sum())
        # Check convergence BEFORE the M-step, so the returned beta/eta/pi
        # are exactly the parameters that produced `ll`.
        if abs(ll - prev_ll) < tol:
            gof = _compute_gof(
                ll,
                n_states,
                n_items,
                data.n_patterns,
                data.n_respondents,
                beta,
                eta,
                pi,
                S,
                R,
                counts,
            )
            return BLIMEstimate(
                beta=beta,
                eta=eta,
                pi=pi,
                log_likelihood=ll,
                n_iterations=iteration,
                converged=True,
                items=items,
                states=states,
                gof=gof,
            )
        prev_ll = ll
        # ============================================================
        # M-step: re-estimate beta, eta, pi from sufficient statistics
        # ============================================================
        # Weighted posterior: w[r, k] = counts[r] * posterior[r, k]
        W = counts[:, np.newaxis] * posterior  # (n_patterns, n_states)
        N_total = counts.sum()
        # Pi: state prior
        pi = W.sum(axis=0) / N_total  # (n_states,)
        # For each item q, compute sufficient statistics
        for q in range(n_items):
            # Expected number of respondents in states containing q
            # N_mastered[q] = sum_r sum_{k: q in K_k} W[r,k]
            states_with_q = S[:, q]  # (n_states,) binary
            N_mastered_q = (W @ states_with_q).sum()
            N_not_mastered_q = N_total - N_mastered_q
            # Expected number of correct among mastered
            # C_mastered[q] = sum_r R[r,q] * sum_{k: q in K_k} W[r,k]
            correct_weighted = R[:, q] * (W @ states_with_q)
            C_mastered_q = correct_weighted.sum()
            # Expected number of correct among not mastered
            states_without_q = 1 - states_with_q
            correct_not_weighted = R[:, q] * (W @ states_without_q)
            C_not_mastered_q = correct_not_weighted.sum()
            # Update beta = P(incorrect | mastered) = 1 - C_mastered / N_mastered
            # Skip the update when the expected mass is ~0 to avoid 0/0;
            # clipping keeps the parameter strictly inside (0, 1).
            if N_mastered_q > 1e-10:
                beta[q] = np.clip(1 - C_mastered_q / N_mastered_q, 1e-6, 1 - 1e-6)
            # Update eta = P(correct | not mastered) = C_not_mastered / N_not_mastered
            if N_not_mastered_q > 1e-10:
                eta[q] = np.clip(C_not_mastered_q / N_not_mastered_q, 1e-6, 1 - 1e-6)
            # Enforce joint identifiability: beta + eta < 1
            _EPS = 1e-6
            if beta[q] + eta[q] >= 1:
                total = beta[q] + eta[q]
                # Proportionally shrink both to satisfy constraint
                beta[q] = beta[q] / total * (1 - _EPS)
                eta[q] = eta[q] / total * (1 - _EPS)
    # Fell out of the loop without converging: report the last E-step
    # log-likelihood (prev_ll) together with the last M-step parameters.
    gof = _compute_gof(
        prev_ll,
        n_states,
        n_items,
        data.n_patterns,
        data.n_respondents,
        beta,
        eta,
        pi,
        S,
        R,
        counts,
    )
    return BLIMEstimate(
        beta=beta,
        eta=eta,
        pi=pi,
        log_likelihood=prev_ll,
        n_iterations=max_iter,
        converged=False,
        items=items,
        states=states,
        gof=gof,
    )
def estimate_blim_restarts(
    structure: KnowledgeStructure,
    data: ResponseMatrix,
    *,
    n_restarts: int = 10,
    max_iter: int = 500,
    tol: float = 1e-6,
    seed: int | None = None,
) -> BLIMEstimate:
    """Estimate BLIM parameters with multiple random restarts.

    Runs :func:`estimate_blim` ``n_restarts`` times with random
    initial values for beta, eta, and selects the result with the
    highest log-likelihood. This helps avoid local optima.
    The R ``pks`` package does not provide this natively — users
    must loop manually with ``randinit=TRUE``.

    Parameters
    ----------
    structure : KnowledgeStructure
        The knowledge structure defining valid states.
    data : ResponseMatrix
        Observed response patterns.
    n_restarts : int
        Number of random restarts. Default 10.
    max_iter : int
        Maximum EM iterations per restart.
    tol : float
        Convergence tolerance per restart.
    seed : int or None
        Random seed for reproducibility.

    Returns
    -------
    BLIMEstimate
        The best result (highest log-likelihood) across all restarts.

    Raises
    ------
    ValueError
        If ``n_restarts`` is less than 1.
    """
    if n_restarts < 1:
        raise ValueError(f"n_restarts must be >= 1, got {n_restarts}.")
    rng = np.random.default_rng(seed)
    n_items = len(data.items)
    best: BLIMEstimate | None = None
    for _ in range(n_restarts):
        # Random init: beta, eta uniform in [0.01, 0.4] with beta+eta < 1
        beta_init = rng.uniform(0.01, 0.4, size=n_items)
        eta_init = rng.uniform(0.01, 0.4, size=n_items)
        # Halve both values until the joint identifiability constraint
        # (with a safety margin of 0.05) holds per item.
        for q in range(n_items):
            while beta_init[q] + eta_init[q] >= 0.95:
                beta_init[q] *= 0.5
                eta_init[q] *= 0.5
        result = estimate_blim(
            structure,
            data,
            max_iter=max_iter,
            tol=tol,
            beta_init=beta_init,
            eta_init=eta_init,
        )
        # Keep the run with the highest final log-likelihood.
        if best is None or result.log_likelihood > best.log_likelihood:
            best = result
    assert best is not None  # guaranteed by n_restarts >= 1
    return best