# Source code for knowledgespaces.estimation.blim_em

"""
EM algorithm for BLIM parameter estimation.

Estimates slip (β) and guess (η) parameters from observed response
patterns using Expectation-Maximization. Supports both global
(homogeneous) and per-item (heterogeneous) parameterization.

The algorithm iterates between:
- E-step: compute posterior P(state | response pattern) for each pattern.
- M-step: re-estimate β, η, and state prior π from sufficient statistics.

References:
    Doignon, J.-P., & Falmagne, J.-C. (1999).
    Knowledge Spaces, Chapter 12. Springer-Verlag.

    Heller, J., & Wickelmaier, F. (2013).
    Minimum discrepancy estimation in probabilistic knowledge structures.
    Electronic Notes in Discrete Mathematics, 42, 49-56.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass

import numpy as np
from scipy import stats

from knowledgespaces.structures.knowledge_structure import KnowledgeStructure


@dataclass
class ResponseMatrix:
    """Observed response patterns from a group of respondents.

    Parameters
    ----------
    items : list[str]
        Item labels (columns), must match the structure's domain.
    patterns : np.ndarray
        Binary matrix of shape (n_respondents, n_items).
        patterns[r, q] = 1 if respondent r answered item q correctly.
    counts : np.ndarray or None
        Optional frequency for each unique pattern. If None, each row
        in patterns is one respondent (count=1 each).

    Raises
    ------
    ValueError
        On duplicate item labels, non-2D or non-binary patterns, a
        column/item count mismatch, invalid counts, or zero total
        respondents.
    """

    items: list[str]
    patterns: np.ndarray
    counts: np.ndarray | None = None

    def __post_init__(self) -> None:
        # Item labels must be unique: they index beta/eta columns downstream.
        seen: set[str] = set()
        for item in self.items:
            if item in seen:
                raise ValueError(f"Duplicate item label: '{item}'.")
            seen.add(item)
        if self.patterns.ndim != 2:
            raise ValueError(f"patterns must be 2D, got {self.patterns.ndim}D.")
        if self.patterns.shape[1] != len(self.items):
            raise ValueError(
                f"patterns has {self.patterns.shape[1]} columns but {len(self.items)} items."
            )
        if not np.isin(self.patterns, [0, 1]).all():
            raise ValueError("patterns must contain only 0 and 1.")
        if self.counts is not None:
            if len(self.counts) != self.patterns.shape[0]:
                raise ValueError(
                    f"counts length {len(self.counts)} != patterns rows {self.patterns.shape[0]}."
                )
            if np.any(self.counts < 0):
                raise ValueError("counts must be non-negative.")
            if self.counts.sum() == 0:
                raise ValueError("Total count must be positive (at least one respondent).")
        elif self.patterns.shape[0] == 0:
            # Fix: with counts=None each row is one respondent, so an empty
            # patterns matrix also means zero respondents. Previously this
            # case slipped through validation and broke estimation later.
            raise ValueError("Total count must be positive (at least one respondent).")

    @property
    def n_patterns(self) -> int:
        """Number of pattern rows (not necessarily unique)."""
        return self.patterns.shape[0]

    @property
    def n_items(self) -> int:
        """Number of items (columns)."""
        return len(self.items)

    @property
    def effective_counts(self) -> np.ndarray:
        """Counts for each pattern (ones if not provided)."""
        if self.counts is not None:
            return self.counts
        return np.ones(self.n_patterns)

    @property
    def n_respondents(self) -> float:
        """Total number of respondents (sum of effective counts)."""
        return float(self.effective_counts.sum())
@dataclass(frozen=True)
class GoodnessOfFit:
    """Goodness-of-fit summary for a fitted BLIM.

    Conventions mirror the R ``pks`` package (Heller & Wickelmaier, 2013,
    J. Stat. Softw.) and Heller & Wickelmaier (2013). The headline
    statistic is the likelihood-ratio deviance G2, referred to a
    chi-squared distribution.

    Attributes
    ----------
    G2 : float
        Likelihood-ratio statistic, ``2 * sum_r N_r ln(N_r / E_r)``.
    df : int
        Degrees of freedom, ``max(min(2^Q - 1, N) - npar, 0)``. Capping
        ``2^Q - 1`` by the sample size N follows ``pks``: with fewer
        respondents than possible response patterns the saturated model
        is not fully identified, so N stands in for ``2^Q - 1``.
    p_value : float
        P-value of the chi-squared test on G2.
    npar : int
        Count of free parameters, ``|K| - 1 + 2 * Q``.
    AIC : float
        Akaike Information Criterion, ``-2*LL + 2*npar``.
    BIC : float
        Bayesian Information Criterion with the number of *unique*
        observed patterns as sample size (``-2*LL + ln(n_patterns)*npar``),
        per the ``pks`` convention. See :attr:`BIC_N` for the textbook form.
    BIC_N : float
        Bayesian Information Criterion with the total sample size
        (``-2*LL + ln(N)*npar``), the standard definition (Schwarz, 1978).
    """

    G2: float
    df: int
    p_value: float
    npar: int
    AIC: float
    BIC: float
    BIC_N: float
[docs] @dataclass(frozen=True) class BLIMEstimate: """Result of BLIM parameter estimation via EM. Attributes ---------- beta : np.ndarray Slip parameters, shape (n_items,). beta[q] = P(incorrect | q mastered). eta : np.ndarray Guess parameters, shape (n_items,). eta[q] = P(correct | q not mastered). pi : np.ndarray State prior probabilities, shape (n_states,). log_likelihood : float Final log-likelihood of the data. n_iterations : int Number of EM iterations until convergence. converged : bool True if converged within max_iter. items : list[str] Item labels corresponding to beta/eta indices. states : list[frozenset[str]] Knowledge states corresponding to pi indices (same order). gof : GoodnessOfFit Goodness-of-fit statistics (G2, df, p-value, AIC, BIC). """ beta: np.ndarray eta: np.ndarray pi: np.ndarray log_likelihood: float n_iterations: int converged: bool items: list[str] states: list[frozenset[str]] gof: GoodnessOfFit
[docs] def beta_for(self, item: str) -> float: """Get beta (slip) for a specific item.""" return float(self.beta[self.items.index(item)])
[docs] def eta_for(self, item: str) -> float: """Get eta (guess) for a specific item.""" return float(self.eta[self.items.index(item)])
[docs] def beta_dict(self) -> dict[str, float]: """Return beta as {item: value} dict.""" return dict(zip(self.items, self.beta.tolist(), strict=True))
[docs] def eta_dict(self) -> dict[str, float]: """Return eta as {item: value} dict.""" return dict(zip(self.items, self.eta.tolist(), strict=True))
[docs] def pi_dict(self) -> dict[frozenset[str], float]: """Return pi as {state: probability} dict.""" return dict(zip(self.states, self.pi.tolist(), strict=True))
def _compute_gof(
    log_likelihood: float,
    n_states: int,
    n_items: int,
    n_patterns: int,
    n_respondents: float,
    beta: np.ndarray,
    eta: np.ndarray,
    pi: np.ndarray,
    S: np.ndarray,
    R: np.ndarray,
    counts: np.ndarray,
) -> GoodnessOfFit:
    """Compute goodness-of-fit statistics for a BLIM estimate.

    Follows the R ``pks`` package conventions (Heller & Wickelmaier, 2013).
    G2 is computed over *unique* response patterns with their observed
    frequencies, compared against expected frequencies from the model.

    Parameters
    ----------
    log_likelihood : float
        Log-likelihood of the data at (beta, eta, pi).
    n_states, n_items : int
        Sizes of the knowledge structure and domain.
    n_patterns : int
        Number of pattern rows in the data. Retained for interface
        compatibility; BIC uses the unique-pattern count computed here.
    n_respondents : float
        Total sample size N.
    beta, eta : np.ndarray
        Per-item slip/guess parameters, shape (n_items,).
    pi : np.ndarray
        State prior, shape (n_states,).
    S : np.ndarray
        State-by-item membership matrix, shape (n_states, n_items).
    R : np.ndarray
        Response patterns, shape (n_patterns, n_items).
    counts : np.ndarray
        Frequency of each pattern row.

    Returns
    -------
    GoodnessOfFit
        G2, df, p-value, npar, AIC, BIC (pks) and BIC_N (Schwarz).
    """
    N = n_respondents

    # Aggregate rows to unique patterns with observed frequencies.
    unique_patterns, inverse = np.unique(R, axis=0, return_inverse=True)
    # bincount replaces the former Python loop; ravel() guards against
    # numpy versions where `inverse` is returned with an extra axis.
    observed_freq = np.bincount(
        inverse.ravel(), weights=counts, minlength=len(unique_patterns)
    )
    n_unique = len(unique_patterns)

    # Predicted probability of each unique pattern: P(R) = sum_K P(R|K)*P(K),
    # computed in log space with log-sum-exp for numerical stability.
    p_correct = S * (1 - beta) + (1 - S) * eta  # (n_states, n_items)
    log_pc = np.log(np.clip(p_correct, 1e-300, None))
    log_pinc = np.log(np.clip(1 - p_correct, 1e-300, None))
    log_lik_rk = unique_patterns @ log_pc.T + (1 - unique_patterns) @ log_pinc.T
    log_prior = np.log(np.clip(pi, 1e-300, None))
    log_joint = log_lik_rk + log_prior
    max_log = log_joint.max(axis=1, keepdims=True)
    log_P_R = (max_log + np.log(np.exp(log_joint - max_log).sum(axis=1, keepdims=True))).ravel()

    # Expected counts under the fitted model.
    E = N * np.exp(log_P_R)

    # G2 = 2 * sum_r N_r * ln(N_r / E_r); zero-count patterns contribute 0
    # (lim x->0 of x ln x = 0), so they are skipped.
    mask = observed_freq > 0
    G2 = float(2 * np.sum(observed_freq[mask] * np.log(observed_freq[mask] / E[mask])))

    # Free parameters: |K|-1 (pi) + Q (beta) + Q (eta)
    npar = (n_states - 1) + 2 * n_items

    # Degrees of freedom: saturated model has 2^Q - 1 free parameters
    # (one probability per pattern minus the sum-to-1 constraint).
    # df = min(2^Q - 1, N) - npar, floored at 0.
    n_possible = 2**n_items - 1
    if n_items > 20:
        warnings.warn(
            f"Domain has {n_items} items — the full 2^Q pattern space "
            f"({2**n_items}) is very large. GOF df may be unreliable.",
            stacklevel=3,
        )
    df = max(int(min(n_possible, N) - npar), 0)

    # sf(x, df) == 1 - cdf(x, df), but numerically accurate in the upper
    # tail where 1 - cdf underflows to 0.
    p_value = float(stats.chi2.sf(G2, df)) if df > 0 else 1.0

    # Information criteria
    AIC = -2 * log_likelihood + 2 * npar
    BIC = -2 * log_likelihood + np.log(n_unique) * npar  # pks convention
    BIC_N = -2 * log_likelihood + np.log(N) * npar  # standard (Schwarz, 1978)

    return GoodnessOfFit(
        G2=G2,
        df=df,
        p_value=p_value,
        npar=npar,
        AIC=float(AIC),
        BIC=float(BIC),
        BIC_N=float(BIC_N),
    )
def estimate_blim(
    structure: KnowledgeStructure,
    data: ResponseMatrix,
    *,
    max_iter: int = 500,
    tol: float = 1e-6,
    beta_init: float | np.ndarray = 0.1,
    eta_init: float | np.ndarray = 0.1,
) -> BLIMEstimate:
    """Estimate BLIM parameters via Expectation-Maximization.

    Parameters
    ----------
    structure : KnowledgeStructure
        The knowledge structure defining valid states.
    data : ResponseMatrix
        Observed response patterns.
    max_iter : int
        Maximum number of EM iterations. Default 500.
    tol : float
        Convergence tolerance on log-likelihood change. Default 1e-6.
    beta_init : float or np.ndarray
        Initial slip values. Scalar for homogeneous, array for per-item.
    eta_init : float or np.ndarray
        Initial guess values. Scalar for homogeneous, array for per-item.

    Returns
    -------
    BLIMEstimate
        Estimated parameters, log-likelihood, and convergence info.

    Raises
    ------
    ValueError
        If data items don't match structure domain, or if init
        parameters are out of range.
    """
    # Validate hyperparameters
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}.")
    if tol <= 0:
        raise ValueError(f"tol must be > 0, got {tol}.")
    # Validate domain match
    if set(data.items) != structure.domain:
        raise ValueError(
            f"ResponseMatrix items {set(data.items)} don't match "
            f"structure domain {set(structure.domain)}."
        )

    items = data.items
    n_items = len(items)
    counts = data.effective_counts

    # Build state matrix: S[k, q] = 1 if item q is in state k.
    # States are sorted (by size, then lexicographically) so that pi's
    # ordering is deterministic across runs.
    states = sorted(structure.states, key=lambda s: (len(s), sorted(s)))
    n_states = len(states)
    item_idx = {item: i for i, item in enumerate(items)}
    S = np.zeros((n_states, n_items), dtype=np.float64)
    for k, state in enumerate(states):
        for item in state:
            S[k, item_idx[item]] = 1.0

    # Initialize and validate parameters.
    # NOTE(review): a numpy scalar init falls through to the array branch
    # and fails the shape check — presumably intended; confirm if numpy
    # scalars should be accepted.
    if isinstance(beta_init, (int, float)):
        beta = np.full(n_items, float(beta_init))
    else:
        beta = np.array(beta_init, dtype=np.float64).copy()
        if beta.shape != (n_items,):
            raise ValueError(f"beta_init array has shape {beta.shape}, expected ({n_items},).")
    if isinstance(eta_init, (int, float)):
        eta = np.full(n_items, float(eta_init))
    else:
        eta = np.array(eta_init, dtype=np.float64).copy()
        if eta.shape != (n_items,):
            raise ValueError(f"eta_init array has shape {eta.shape}, expected ({n_items},).")
    if np.any(beta < 0) or np.any(beta >= 1):
        raise ValueError("All beta_init values must be in [0, 1).")
    if np.any(eta < 0) or np.any(eta >= 1):
        raise ValueError("All eta_init values must be in [0, 1).")
    # beta + eta < 1 is the BLIM identifiability condition.
    if np.any(beta + eta >= 1):
        raise ValueError("beta_init + eta_init must be < 1 for each item.")

    # Uniform state prior to start.
    pi = np.full(n_states, 1.0 / n_states)
    R = data.patterns.astype(np.float64)  # (n_patterns, n_items)

    prev_ll = -np.inf
    for iteration in range(1, max_iter + 1):
        # ============================================================
        # E-step: compute P(pattern | state) and posterior P(state | pattern)
        # ============================================================
        # P(R_r | K_k) = prod_q [ P(R_rq | K_k, q) ]
        # where P(correct | q in K) = 1-beta_q, P(correct | q not in K) = eta_q

        # Likelihood per (pattern, state, item):
        # P(R_rq=1 | K_k) = S[k,q]*(1-beta[q]) + (1-S[k,q])*eta[q]
        p_correct = S * (1 - beta) + (1 - S) * eta  # (n_states, n_items)
        p_incorrect = 1 - p_correct  # (n_states, n_items)

        # P(R_r | K_k) for each (pattern, state) — log for numerical stability;
        # clip avoids log(0) when a parameter sits on a bound.
        log_p_correct = np.log(np.clip(p_correct, 1e-300, None))
        log_p_incorrect = np.log(np.clip(p_incorrect, 1e-300, None))

        # log P(R_r | K_k) = sum_q [ R_rq * log_p_correct[k,q] + (1-R_rq) * log_p_incorrect[k,q] ]
        log_lik_rk = R @ log_p_correct.T + (1 - R) @ log_p_incorrect.T  # (n_patterns, n_states)

        # Add log prior
        log_joint = log_lik_rk + np.log(np.clip(pi, 1e-300, None))  # (n_patterns, n_states)

        # Log-sum-exp over states for numerical stability
        max_log = log_joint.max(axis=1, keepdims=True)
        log_marginal = max_log + np.log(np.exp(log_joint - max_log).sum(axis=1, keepdims=True))

        # Posterior: P(K_k | R_r) = exp(log_joint - log_marginal)
        posterior = np.exp(log_joint - log_marginal)  # (n_patterns, n_states)

        # Log-likelihood, weighted by pattern frequency
        ll = float((counts * log_marginal.ravel()).sum())

        # Check convergence (before the M-step, so the returned beta/eta/pi
        # are exactly the parameters at which ll was evaluated).
        if abs(ll - prev_ll) < tol:
            gof = _compute_gof(
                ll,
                n_states,
                n_items,
                data.n_patterns,
                data.n_respondents,
                beta,
                eta,
                pi,
                S,
                R,
                counts,
            )
            return BLIMEstimate(
                beta=beta,
                eta=eta,
                pi=pi,
                log_likelihood=ll,
                n_iterations=iteration,
                converged=True,
                items=items,
                states=states,
                gof=gof,
            )
        prev_ll = ll

        # ============================================================
        # M-step: re-estimate beta, eta, pi from sufficient statistics
        # ============================================================
        # Weighted posterior: w[r, k] = counts[r] * posterior[r, k]
        W = counts[:, np.newaxis] * posterior  # (n_patterns, n_states)
        N_total = counts.sum()

        # Pi: state prior (posterior rows sum to 1, so columns of W sum to N_total)
        pi = W.sum(axis=0) / N_total  # (n_states,)

        # For each item q, compute sufficient statistics
        for q in range(n_items):
            # Expected number of respondents in states containing q:
            # N_mastered[q] = sum_r sum_{k: q in K_k} W[r,k]
            states_with_q = S[:, q]  # (n_states,) binary
            N_mastered_q = (W @ states_with_q).sum()
            N_not_mastered_q = N_total - N_mastered_q

            # Expected number of correct among mastered:
            # C_mastered[q] = sum_r R[r,q] * sum_{k: q in K_k} W[r,k]
            correct_weighted = R[:, q] * (W @ states_with_q)
            C_mastered_q = correct_weighted.sum()

            # Expected number of correct among not mastered
            states_without_q = 1 - states_with_q
            correct_not_weighted = R[:, q] * (W @ states_without_q)
            C_not_mastered_q = correct_not_weighted.sum()

            # Update beta = P(incorrect | mastered) = 1 - C_mastered / N_mastered
            # (skipped if the expected mastered mass is ~0, to avoid 0/0)
            if N_mastered_q > 1e-10:
                beta[q] = np.clip(1 - C_mastered_q / N_mastered_q, 1e-6, 1 - 1e-6)

            # Update eta = P(correct | not mastered) = C_not_mastered / N_not_mastered
            if N_not_mastered_q > 1e-10:
                eta[q] = np.clip(C_not_mastered_q / N_not_mastered_q, 1e-6, 1 - 1e-6)

            # Enforce joint identifiability: beta + eta < 1
            _EPS = 1e-6
            if beta[q] + eta[q] >= 1:
                total = beta[q] + eta[q]
                # Proportionally shrink both to satisfy constraint
                beta[q] = beta[q] / total * (1 - _EPS)
                eta[q] = eta[q] / total * (1 - _EPS)

    # Not converged within max_iter.
    # NOTE(review): prev_ll was computed before the final M-step, so the
    # reported log-likelihood corresponds to the parameters one update
    # behind the returned beta/eta/pi — presumably acceptable since EM
    # updates do not decrease the likelihood; confirm if exactness matters.
    gof = _compute_gof(
        prev_ll,
        n_states,
        n_items,
        data.n_patterns,
        data.n_respondents,
        beta,
        eta,
        pi,
        S,
        R,
        counts,
    )
    return BLIMEstimate(
        beta=beta,
        eta=eta,
        pi=pi,
        log_likelihood=prev_ll,
        n_iterations=max_iter,
        converged=False,
        items=items,
        states=states,
        gof=gof,
    )
def estimate_blim_restarts(
    structure: KnowledgeStructure,
    data: ResponseMatrix,
    *,
    n_restarts: int = 10,
    max_iter: int = 500,
    tol: float = 1e-6,
    seed: int | None = None,
) -> BLIMEstimate:
    """Estimate BLIM parameters with multiple random restarts.

    Runs :func:`estimate_blim` ``n_restarts`` times from randomly drawn
    starting values for beta and eta, keeping the fit with the highest
    log-likelihood. Restarting guards against EM settling into a local
    optimum. (The R ``pks`` package does not offer this natively — users
    must loop manually with ``randinit=TRUE``.)

    Parameters
    ----------
    structure : KnowledgeStructure
        The knowledge structure defining valid states.
    data : ResponseMatrix
        Observed response patterns.
    n_restarts : int
        Number of random restarts. Default 10.
    max_iter : int
        Maximum EM iterations per restart.
    tol : float
        Convergence tolerance per restart.
    seed : int or None
        Random seed for reproducibility.

    Returns
    -------
    BLIMEstimate
        The best result (highest log-likelihood) across all restarts.

    Raises
    ------
    ValueError
        If ``n_restarts`` is less than 1.
    """
    if n_restarts < 1:
        raise ValueError(f"n_restarts must be >= 1, got {n_restarts}.")

    rng = np.random.default_rng(seed)
    n_items = len(data.items)

    best: BLIMEstimate | None = None
    for _ in range(n_restarts):
        # Draw per-item starting values uniformly from [0.01, 0.4).
        start_beta = rng.uniform(0.01, 0.4, size=n_items)
        start_eta = rng.uniform(0.01, 0.4, size=n_items)
        # Safety net for the beta + eta < 1 identifiability bound: halve
        # any pair that gets too close. (Draws below 0.4 each cannot
        # trigger this; kept as a guard should the range ever widen.)
        for q in range(n_items):
            while start_beta[q] + start_eta[q] >= 0.95:
                start_beta[q] *= 0.5
                start_eta[q] *= 0.5

        candidate = estimate_blim(
            structure,
            data,
            max_iter=max_iter,
            tol=tol,
            beta_init=start_beta,
            eta_init=start_eta,
        )
        if best is None or candidate.log_likelihood > best.log_likelihood:
            best = candidate

    assert best is not None  # guaranteed by n_restarts >= 1
    return best