# Source code for knowledgespaces.estimation.blim_em

"""
EM algorithm for BLIM parameter estimation.

Estimates slip (β) and guess (η) parameters from observed response
patterns using Expectation-Maximization. Slip and guess are always
estimated per item (heterogeneous parameterization); a scalar initial
value only starts every item from the same point.

The algorithm iterates between:
- E-step: compute posterior P(state | response pattern) for each pattern.
- M-step: re-estimate β, η, and state prior π from sufficient statistics.

References:
    Doignon, J.-P., & Falmagne, J.-C. (1999).
    Knowledge Spaces, Chapter 12. Springer-Verlag.

    Heller, J., & Wickelmaier, F. (2013).
    Minimum discrepancy estimation in probabilistic knowledge structures.
    Electronic Notes in Discrete Mathematics, 42, 49-56.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Literal

import numpy as np
from scipy import stats

from knowledgespaces.structures.knowledge_structure import KnowledgeStructure


class ConvergenceWarning(UserWarning):
    """Warning category for iterative estimators that miss their stopping criterion.

    This module also issues it to flag degenerate (non-informative) items
    in a fitted BLIM.
    """
@dataclass
class ResponseMatrix:
    """Observed response patterns from a group of respondents.

    Parameters
    ----------
    items : list[str]
        Item labels (columns). Must be unique and must match the structure's
        domain.
    patterns : np.ndarray
        Binary matrix of shape (n_patterns, n_items).
        ``patterns[r, q] = 1`` if pattern r answers item q correctly.
        Array-likes such as nested lists are converted with ``np.asarray``;
        an ndarray is stored as-is (no copy).
    counts : np.ndarray or None
        Optional 1-D, finite, non-negative frequency for each row of
        ``patterns``. If None, each row in patterns is one respondent
        (count=1 each).

    Raises
    ------
    ValueError
        Raised when any of the following holds:

        - an item label is duplicated;
        - ``patterns`` is not 2-D or is not binary;
        - the column count of ``patterns`` differs from ``len(items)``;
        - ``counts`` is malformed (not 1-D, wrong length, non-finite or
          negative);
        - the total count is zero. This includes an empty ``patterns``
          matrix given without counts.
    """

    items: list[str]
    patterns: np.ndarray
    counts: np.ndarray | None = None

    def __post_init__(self) -> None:
        # Item uniqueness
        seen: set[str] = set()
        for item in self.items:
            if item in seen:
                raise ValueError(f"Duplicate item label: '{item}'.")
            seen.add(item)

        # np.asarray returns an ndarray argument unchanged, so existing
        # callers keep the same object; lists/tuples become arrays.
        self.patterns = np.asarray(self.patterns)
        if self.patterns.ndim != 2:
            raise ValueError(f"patterns must be 2D, got {self.patterns.ndim}D.")
        if self.patterns.shape[1] != len(self.items):
            raise ValueError(
                f"patterns has {self.patterns.shape[1]} columns but {len(self.items)} items."
            )
        if not np.isin(self.patterns, [0, 1]).all():
            raise ValueError("patterns must contain only 0 and 1.")

        if self.counts is not None:
            self.counts = np.asarray(self.counts)
            if self.counts.ndim != 1:
                raise ValueError(f"counts must be 1D, got {self.counts.ndim}D.")
            if len(self.counts) != self.patterns.shape[0]:
                raise ValueError(
                    f"counts length {len(self.counts)} != patterns rows {self.patterns.shape[0]}."
                )
            # NaN/inf would slip past the `< 0` and `== 0` checks below and
            # poison every downstream sum.
            if not np.all(np.isfinite(self.counts)):
                raise ValueError("counts must be finite.")
            if np.any(self.counts < 0):
                raise ValueError("counts must be non-negative.")

        # Covers both all-zero counts and an empty patterns matrix without
        # counts: estimators in this module divide by the total count.
        if self.n_respondents == 0:
            raise ValueError("Total count must be positive (at least one respondent).")

    @property
    def n_patterns(self) -> int:
        """Number of rows in ``patterns``."""
        return self.patterns.shape[0]

    @property
    def n_items(self) -> int:
        """Number of item labels."""
        return len(self.items)

    @property
    def effective_counts(self) -> np.ndarray:
        """Counts for each pattern (ones if not provided)."""
        if self.counts is not None:
            return self.counts
        return np.ones(self.n_patterns)

    @property
    def n_respondents(self) -> float:
        """Total number of respondents (sum of effective counts)."""
        return float(self.effective_counts.sum())
@dataclass(frozen=True)
class GoodnessOfFit:
    """Fit statistics attached to a BLIM estimate.

    Conventions follow Heller & Wickelmaier (2013) and the R ``pks``
    package. The central statistic is the likelihood-ratio deviance ``G2``,
    referred to a chi-squared distribution.

    Attributes
    ----------
    G2 : float
        Likelihood-ratio statistic ``2 * sum_r N_r * ln(N_r / E_r)`` over the
        observed response patterns.
    df : int
        Degrees of freedom, ``max(min(2^Q - 1, N) - npar, 0)``. Capping the
        saturated dimension at the total sample size ``N`` is the ``pks``
        convention. When ``N`` is below the number of possible response
        patterns, the saturated model is not fully identified, so ``N``
        replaces ``2^Q - 1``.
    p_value : float
        Chi-squared p-value of ``G2`` on ``df`` degrees of freedom (1.0 when
        ``df`` is 0).
    npar : int
        Number of free parameters, ``|K| - 1 + 2 * Q``.
    AIC : float
        Akaike Information Criterion, ``-2*LL + 2*npar``.
    BIC : float
        Schwarz (1978) Bayesian Information Criterion using the total
        respondent count: ``-2*LL + ln(N)*npar``. Each of the ``N``
        respondents contributes one i.i.d. draw from the pattern
        distribution, so the Laplace-approximation argument behind the
        ``log(N)*npar`` penalty applies. This is the recommended criterion
        for model selection.
    BIC_npatterns : float
        Variant using the number of distinct observed response patterns:
        ``-2*LL + ln(n_patterns)*npar``. It matches what R ``pks::blim()``
        reports. pks defines no ``BIC`` method of its own; it overrides
        ``nobs.blim`` to return the distinct-pattern count and delegates to
        ``stats::BIC`` (see ``cran/pks/R/blim.R``, ``logLik.blim`` /
        ``nobs.blim``). Use it for cross-package replication only: the
        distinct-pattern count is bounded by ``2^Q`` and so does not satisfy
        the asymptotic-consistency conditions of Schwarz (1978).
    """

    G2: float
    df: int
    p_value: float
    npar: int
    AIC: float
    BIC: float
    BIC_npatterns: float
[docs] @dataclass(frozen=True) class BLIMEstimate: """Result of BLIM parameter estimation via EM. Attributes ---------- beta : np.ndarray Slip parameters, shape (n_items,). beta[q] = P(incorrect | q mastered). eta : np.ndarray Guess parameters, shape (n_items,). eta[q] = P(correct | q not mastered). pi : np.ndarray State prior probabilities, shape (n_states,). log_likelihood : float Final log-likelihood of the data. n_iterations : int Number of EM iterations until convergence. converged : bool True if converged within max_iter. items : list[str] Item labels corresponding to beta/eta indices. states : list[frozenset[str]] Knowledge states corresponding to pi indices (same order). gof : GoodnessOfFit Goodness-of-fit statistics (G2, df, p-value, AIC, BIC). degenerate_items : tuple[str, ...] Items whose final ``beta[q] + eta[q] >= 1 - 1e-3``. Such items are non-informative under the current knowledge structure: a mastering respondent is no more likely to answer correctly than a non-mastering one. The literature (Spoto, Stefanutti & Vidotto, 2013) treats this as a structural diagnostic — typically the item should be removed or the structure revised. Empty tuple when no item is degenerate. """ beta: np.ndarray eta: np.ndarray pi: np.ndarray log_likelihood: float n_iterations: int converged: bool items: list[str] states: list[frozenset[str]] gof: GoodnessOfFit degenerate_items: tuple[str, ...]
[docs] def beta_for(self, item: str) -> float: """Get beta (slip) for a specific item.""" return float(self.beta[self.items.index(item)])
[docs] def eta_for(self, item: str) -> float: """Get eta (guess) for a specific item.""" return float(self.eta[self.items.index(item)])
[docs] def beta_dict(self) -> dict[str, float]: """Return beta as {item: value} dict.""" return dict(zip(self.items, self.beta.tolist(), strict=True))
[docs] def eta_dict(self) -> dict[str, float]: """Return eta as {item: value} dict.""" return dict(zip(self.items, self.eta.tolist(), strict=True))
[docs] def pi_dict(self) -> dict[frozenset[str], float]: """Return pi as {state: probability} dict.""" return dict(zip(self.states, self.pi.tolist(), strict=True))
def _compute_gof( log_likelihood: float, n_states: int, n_items: int, n_patterns: int, n_respondents: float, beta: np.ndarray, eta: np.ndarray, pi: np.ndarray, S: np.ndarray, R: np.ndarray, counts: np.ndarray, ) -> GoodnessOfFit: """Compute goodness-of-fit statistics for a BLIM estimate. Follows the R ``pks`` package conventions (Heller & Wickelmaier, 2013). G2 is computed over *unique* response patterns with their observed frequencies, compared against expected frequencies from the model. """ N = n_respondents # Aggregate to unique patterns with observed frequencies unique_patterns, inverse = np.unique(R, axis=0, return_inverse=True) observed_freq = np.zeros(len(unique_patterns)) for i, idx in enumerate(inverse): observed_freq[idx] += counts[i] n_unique = len(unique_patterns) # Predicted probability of each unique pattern: P(R) = sum_K P(R|K)*P(K) p_correct = S * (1 - beta) + (1 - S) * eta # (n_states, n_items) log_pc = np.log(np.clip(p_correct, 1e-300, None)) log_pinc = np.log(np.clip(1 - p_correct, 1e-300, None)) log_lik_rk = unique_patterns @ log_pc.T + (1 - unique_patterns) @ log_pinc.T log_prior = np.log(np.clip(pi, 1e-300, None)) log_joint = log_lik_rk + log_prior max_log = log_joint.max(axis=1, keepdims=True) log_P_R = (max_log + np.log(np.exp(log_joint - max_log).sum(axis=1, keepdims=True))).ravel() # Expected counts E = N * np.exp(log_P_R) # G2 = 2 * sum_r N_r * ln(N_r / E_r), skipping zero-count patterns mask = observed_freq > 0 G2 = float(2 * np.sum(observed_freq[mask] * np.log(observed_freq[mask] / E[mask]))) # Free parameters: |K|-1 (pi) + Q (beta) + Q (eta) npar = (n_states - 1) + 2 * n_items # Degrees of freedom: the saturated model has one free parameter per # observable cell minus the sum-to-1 constraint, i.e. 2^Q - 1. When # the total sample size N is smaller than 2^Q - 1, the saturated model # is not fully identifiable and the cap drops to N (pks convention). 
n_possible = 2**n_items n_saturated = min(n_possible - 1, round(N)) if n_items > 20: warnings.warn( f"Domain has {n_items} items — the full 2^Q pattern space " f"({n_possible}) is very large. GOF df may be unreliable.", stacklevel=3, ) df = max(int(n_saturated - npar), 0) # P-value p_value = float(1 - stats.chi2.cdf(G2, df)) if df > 0 else 1.0 # Information criteria. # Primary BIC follows Schwarz (1978): N = number of independent # observations (here, total respondents). pks::blim() instead returns # log(n_patterns)*npar via nobs.blim override; that variant is exposed # as BIC_npatterns for cross-package replication but is not the # asymptotically consistent BIC because n_patterns is bounded by 2^Q. AIC = -2 * log_likelihood + 2 * npar BIC = -2 * log_likelihood + np.log(N) * npar # Schwarz (1978) BIC_npatterns = -2 * log_likelihood + np.log(n_unique) * npar # pks variant return GoodnessOfFit( G2=G2, df=df, p_value=p_value, npar=npar, AIC=float(AIC), BIC=float(BIC), BIC_npatterns=float(BIC_npatterns), ) def _log_likelihood_at_params( structure: KnowledgeStructure, data: ResponseMatrix, beta: np.ndarray, eta: np.ndarray, pi: np.ndarray, ) -> float: """Compute the BLIM log-likelihood at the given parameters. Mirrors the E-step log-marginal used inside :func:`estimate_blim`. Used to keep ``BLIMEstimate.log_likelihood`` consistent with the returned ``beta/eta/pi`` in the non-convergence branch, and exposed for regression testing. 
""" items = data.items states = sorted(structure.states, key=lambda s: (len(s), sorted(s))) n_items = len(items) n_states = len(states) item_idx = {item: i for i, item in enumerate(items)} S = np.zeros((n_states, n_items), dtype=np.float64) for k, state in enumerate(states): for item in state: S[k, item_idx[item]] = 1.0 R = data.patterns.astype(np.float64) counts = data.effective_counts p_correct = S * (1 - beta) + (1 - S) * eta log_p_correct = np.log(np.clip(p_correct, 1e-300, None)) log_p_incorrect = np.log(np.clip(1 - p_correct, 1e-300, None)) log_lik_rk = R @ log_p_correct.T + (1 - R) @ log_p_incorrect.T log_joint = log_lik_rk + np.log(np.clip(pi, 1e-300, None)) max_log = log_joint.max(axis=1, keepdims=True) log_marginal = max_log + np.log(np.exp(log_joint - max_log).sum(axis=1, keepdims=True)) return float((counts * log_marginal.ravel()).sum()) _DEGENERATE_TOL = 1e-3 """Slack on the ``beta + eta < 1`` informativeness condition. Items whose final ``beta[q] + eta[q] >= 1 - _DEGENERATE_TOL`` are flagged as degenerate.""" def _degenerate_items(items: list[str], beta: np.ndarray, eta: np.ndarray) -> tuple[str, ...]: """Return items whose ``beta + eta`` is at the degenerate boundary.""" threshold = 1.0 - _DEGENERATE_TOL return tuple(items[q] for q in range(len(items)) if beta[q] + eta[q] >= threshold) def _warn_degenerate(degenerate: tuple[str, ...]) -> None: """Emit a ConvergenceWarning if any item is at the degenerate boundary.""" if not degenerate: return warnings.warn( f"BLIM EM converged with degenerate items {list(degenerate)} " f"(beta + eta >= 1 - {_DEGENERATE_TOL:g}). Such items are " f"non-informative under the current knowledge structure: a " f"mastering respondent is no more likely to answer correctly " f"than a non-mastering one. Consider removing the item or " f"revising the structure (Spoto, Stefanutti & Vidotto 2013, " f"Behav. Res. Methods 45:1197-1211).", category=ConvergenceWarning, stacklevel=3, )
def _expand_item_init(value: float | np.ndarray, n_items: int, name: str) -> np.ndarray:
    """Return a fresh float64 array of length ``n_items`` from a scalar or per-item init.

    A scalar (Python or NumPy, including a 0-d array) is broadcast to every
    item. Any other value must already have shape ``(n_items,)``. Range
    checks are left to the caller.

    Raises
    ------
    ValueError
        If a non-scalar ``value`` does not have shape ``(n_items,)``.
    """
    # np.array copies by default, so the caller's array is never mutated by
    # the in-place M-step updates.
    arr = np.array(value, dtype=np.float64)
    if arr.ndim == 0:
        return np.full(n_items, float(arr))
    if arr.shape != (n_items,):
        raise ValueError(f"{name} array has shape {arr.shape}, expected ({n_items},).")
    return arr


def estimate_blim(
    structure: KnowledgeStructure,
    data: ResponseMatrix,
    *,
    max_iter: int = 500,
    tol: float = 1e-6,
    beta_init: float | np.ndarray = 0.1,
    eta_init: float | np.ndarray = 0.1,
    max_memory_bytes: int = 8_000_000_000,
) -> BLIMEstimate:
    """Estimate BLIM parameters via Expectation-Maximization.

    Parameters
    ----------
    structure : KnowledgeStructure
        The knowledge structure defining valid states.
    data : ResponseMatrix
        Observed response patterns.
    max_iter : int
        Maximum number of EM iterations. Default 500.
    tol : float
        Convergence tolerance on the absolute log-likelihood change.
        Must be > 0; NaN is rejected. Default 1e-6.
    beta_init : float or np.ndarray
        Initial slip values. A scalar (Python or NumPy) starts every item
        at the same value; an array gives per-item starting values. Slip is
        always estimated per item.
    eta_init : float or np.ndarray
        Initial guess values, with the same conventions as ``beta_init``.
    max_memory_bytes : int
        Hard cap on the estimated allocation for the posterior matrix
        (``n_patterns × n_states × 8`` bytes). If the estimate exceeds this
        value a :class:`MemoryError` is raised before any large array is
        allocated. Default 8 GB.

    Returns
    -------
    BLIMEstimate
        Estimated parameters, log-likelihood, and convergence info.

    Raises
    ------
    ValueError
        If ``max_iter < 1``, ``tol`` is not > 0 (including NaN), data items
        don't match the structure domain, or the init parameters have the
        wrong shape or are out of range (including NaN).
    MemoryError
        If the estimated posterior allocation would exceed
        ``max_memory_bytes``.

    Notes
    -----
    The M-step independently clips ``beta[q]`` and ``eta[q]`` into
    ``[1e-6, 1 - 1e-6]``, mirroring R ``pks::blim()``. The canonical BLIM
    parameter space is the open box per item. The joint condition
    ``beta[q] + eta[q] < 1`` is the *informative item* condition
    (Falmagne & Doignon 2011 §11), not part of the parameter space, and is
    therefore *not* enforced inside the loop. Enforcing it would break EM
    monotonicity (Dempster, Laird & Rubin 1977). Items whose final
    ``beta[q] + eta[q] >= 1 - 1e-3`` are surfaced via
    :attr:`BLIMEstimate.degenerate_items`, and a :class:`ConvergenceWarning`
    is emitted. Such items are non-informative under the current knowledge
    structure. The recommended remedy is structural: drop the item or
    revise the structure (Spoto, Stefanutti & Vidotto 2013).

    References
    ----------
    Dempster, A. P., Laird, N. M., & Rubin, D. B. (1977). Maximum likelihood
    from incomplete data via the EM algorithm. *J. R. Stat. Soc. B*, 39(1),
    1-38.

    Heller, J., & Wickelmaier, F. (2013). Minimum discrepancy estimation in
    probabilistic knowledge structures. *ENDM*, 42, 49-56.

    Spoto, A., Stefanutti, L., & Vidotto, G. (2013). Assessing the local
    identifiability of probabilistic knowledge structures. *Behavior
    Research Methods*, 45(4), 1197-1211.
    """
    # Validate hyperparameters. `not tol > 0` also rejects NaN, which would
    # otherwise make the convergence test below never fire.
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}.")
    if not tol > 0:
        raise ValueError(f"tol must be > 0, got {tol}.")

    # Validate domain match
    if set(data.items) != structure.domain:
        raise ValueError(
            f"ResponseMatrix items {set(data.items)} don't match "
            f"structure domain {set(structure.domain)}."
        )

    items = data.items
    n_items = len(items)
    counts = data.effective_counts

    # Build state matrix: S[k, q] = 1 if item q is in state k
    states = sorted(structure.states, key=lambda s: (len(s), sorted(s)))
    n_states = len(states)

    # Preflight: the E-step allocates several (n_patterns × n_states) float64
    # matrices (log_lik_rk, log_joint, posterior, W). For powerset structures
    # over large domains this explodes to tens of GB and the OS-kills the
    # process without a useful message. Estimate the dominant allocation and
    # fail fast above a hard cap (default 8 GB); warn above 1 GB.
    estimated_bytes = data.n_patterns * n_states * 8
    if estimated_bytes > max_memory_bytes:
        raise MemoryError(
            f"BLIM EM would allocate ~{estimated_bytes / 1e9:.2f} GB for the "
            f"posterior matrix ({data.n_patterns} patterns x {n_states} states "
            f"x 8 bytes). Exceeds max_memory_bytes="
            f"{max_memory_bytes / 1e9:.2f} GB. Reduce domain size, use a "
            f"sparser knowledge structure, or pass max_memory_bytes=... to "
            f"override."
        )
    if estimated_bytes > 1_000_000_000:
        warnings.warn(
            f"BLIM EM will allocate ~{estimated_bytes / 1e9:.2f} GB for the "
            f"posterior matrix ({data.n_patterns} patterns x {n_states} "
            f"states x 8 bytes).",
            category=ResourceWarning,
            stacklevel=2,
        )

    item_idx = {item: i for i, item in enumerate(items)}
    S = np.zeros((n_states, n_items), dtype=np.float64)
    for k, state in enumerate(states):
        for item in state:
            S[k, item_idx[item]] = 1.0

    # Initialize and validate parameters. The range checks are phrased as
    # "not all inside" rather than "any outside" so NaN entries are rejected.
    beta = _expand_item_init(beta_init, n_items, "beta_init")
    eta = _expand_item_init(eta_init, n_items, "eta_init")
    if not np.all((beta >= 0) & (beta < 1)):
        raise ValueError("All beta_init values must be in [0, 1).")
    if not np.all((eta >= 0) & (eta < 1)):
        raise ValueError("All eta_init values must be in [0, 1).")
    if not np.all(beta + eta < 1):
        raise ValueError("beta_init + eta_init must be < 1 for each item.")

    pi = np.full(n_states, 1.0 / n_states)
    R = data.patterns.astype(np.float64)  # (n_patterns, n_items)

    # Loop invariants.
    S_absent = 1.0 - S  # S_absent[k, q] = 1 if item q is NOT in state k
    N_total = counts.sum()

    prev_ll = -np.inf
    for iteration in range(1, max_iter + 1):
        # ============================================================
        # E-step: compute P(pattern | state) and posterior P(state | pattern)
        # ============================================================
        # P(R_rq=1 | K_k) = S[k,q]*(1-beta[q]) + (1-S[k,q])*eta[q]
        p_correct = S * (1 - beta) + S_absent * eta  # (n_states, n_items)
        p_incorrect = 1 - p_correct  # (n_states, n_items)

        # log for numerical stability
        log_p_correct = np.log(np.clip(p_correct, 1e-300, None))
        log_p_incorrect = np.log(np.clip(p_incorrect, 1e-300, None))

        # log P(R_r | K_k) = sum_q [ R_rq * log_p_correct[k,q] + (1-R_rq) * log_p_incorrect[k,q] ]
        log_lik_rk = R @ log_p_correct.T + (1 - R) @ log_p_incorrect.T  # (n_patterns, n_states)

        # Add log prior
        log_joint = log_lik_rk + np.log(np.clip(pi, 1e-300, None))  # (n_patterns, n_states)

        # Log-sum-exp for numerical stability
        max_log = log_joint.max(axis=1, keepdims=True)
        log_marginal = max_log + np.log(np.exp(log_joint - max_log).sum(axis=1, keepdims=True))

        # Posterior: P(K_k | R_r) = exp(log_joint - log_marginal)
        posterior = np.exp(log_joint - log_marginal)  # (n_patterns, n_states)

        # Log-likelihood at the parameters used in this E-step
        ll = float((counts * log_marginal.ravel()).sum())

        # Check convergence. The parameters returned here are exactly those
        # `ll` was computed at (the M-step below has not run yet).
        if abs(ll - prev_ll) < tol:
            gof = _compute_gof(
                ll,
                n_states,
                n_items,
                data.n_patterns,
                data.n_respondents,
                beta,
                eta,
                pi,
                S,
                R,
                counts,
            )
            degenerate = _degenerate_items(items, beta, eta)
            _warn_degenerate(degenerate)
            return BLIMEstimate(
                beta=beta,
                eta=eta,
                pi=pi,
                log_likelihood=ll,
                n_iterations=iteration,
                converged=True,
                items=items,
                states=states,
                gof=gof,
                degenerate_items=degenerate,
            )
        prev_ll = ll

        # ============================================================
        # M-step: re-estimate beta, eta, pi from sufficient statistics
        # ============================================================
        # Weighted posterior: W[r, k] = counts[r] * posterior[r, k]
        W = counts[:, np.newaxis] * posterior  # (n_patterns, n_states)

        # Pi: state prior
        pi = W.sum(axis=0) / N_total  # (n_states,)

        # Per-pattern expected mass in states that contain / lack each item,
        # all items at once: mass_in[r, q] = sum_{k: q in K_k} W[r, k].
        mass_in = W @ S  # (n_patterns, n_items)
        mass_out = W @ S_absent  # (n_patterns, n_items)

        N_mastered = mass_in.sum(axis=0)  # expected respondents mastering q
        N_not_mastered = N_total - N_mastered
        C_mastered = (R * mass_in).sum(axis=0)  # expected correct among mastered
        C_not_mastered = (R * mass_out).sum(axis=0)  # expected correct among not mastered

        # beta = P(incorrect | mastered) = 1 - C_mastered / N_mastered
        # eta  = P(correct | not mastered) = C_not_mastered / N_not_mastered
        # Per-coordinate clip into [eps, 1-eps] mirrors `pks::blim()` (see
        # `cran/pks/R/blim.R`, `blimEM`). The joint constraint beta+eta < 1
        # is the "informative item" condition, not part of the parameter
        # space (Falmagne & Doignon 2011 §11), so it is *not* enforced here;
        # doing so would break EM monotonicity (Dempster, Laird & Rubin
        # 1977). Items with (near-)zero expected mass keep their value.
        upd_beta = N_mastered > 1e-10
        beta[upd_beta] = np.clip(
            1 - C_mastered[upd_beta] / N_mastered[upd_beta], 1e-6, 1 - 1e-6
        )
        upd_eta = N_not_mastered > 1e-10
        eta[upd_eta] = np.clip(
            C_not_mastered[upd_eta] / N_not_mastered[upd_eta], 1e-6, 1 - 1e-6
        )

    warnings.warn(
        f"BLIM EM did not converge in {max_iter} iterations "
        f"(|Δ log-likelihood| > tol = {tol:.3e}). "
        f"Consider increasing max_iter, using estimate_blim_restarts, "
        f"or inspecting the knowledge structure for identifiability issues.",
        category=ConvergenceWarning,
        stacklevel=2,
    )

    # Recompute log-likelihood on the post-M-step parameters so that the
    # returned `log_likelihood` is consistent with the returned beta/eta/pi.
    # Inside the loop, `ll` is computed at the E-step (before that
    # iteration's M-step), so at loop exit `prev_ll` reflects the params
    # from one M-step ago — off-by-one vs. the returned parameters.
    final_ll = _log_likelihood_at_params(structure, data, beta, eta, pi)
    gof = _compute_gof(
        final_ll,
        n_states,
        n_items,
        data.n_patterns,
        data.n_respondents,
        beta,
        eta,
        pi,
        S,
        R,
        counts,
    )
    degenerate = _degenerate_items(items, beta, eta)
    _warn_degenerate(degenerate)
    return BLIMEstimate(
        beta=beta,
        eta=eta,
        pi=pi,
        log_likelihood=final_ll,
        n_iterations=max_iter,
        converged=False,
        items=items,
        states=states,
        gof=gof,
        degenerate_items=degenerate,
    )
def estimate_blim_restarts(
    structure: KnowledgeStructure,
    data: ResponseMatrix,
    *,
    n_restarts: int = 10,
    max_iter: int = 500,
    tol: float = 1e-6,
    seed: int | None = None,
    max_memory_bytes: int = 8_000_000_000,
    init_range: tuple[float, float] = (0.01, 0.4),
    init_strategy: Literal["uniform", "pks"] = "uniform",
) -> BLIMEstimate:
    """Fit a BLIM several times from random starts and keep the best fit.

    Calls :func:`estimate_blim` ``n_restarts`` times, each time with freshly
    drawn ``beta``/``eta`` starting values. It returns the run with the
    highest log-likelihood, which reduces the risk of settling in a local
    optimum. R's ``pks`` has no built-in equivalent; there the user loops by
    hand with ``randinit=TRUE``.

    Parameters
    ----------
    structure : KnowledgeStructure
        The knowledge structure defining valid states.
    data : ResponseMatrix
        Observed response patterns.
    n_restarts : int
        How many random starts to try. Default 10.
    max_iter : int
        EM iteration cap for each start.
    tol : float
        Convergence tolerance for each start.
    seed : int or None
        Seed for ``np.random.default_rng``, for reproducibility.
    max_memory_bytes : int
        Passed through to :func:`estimate_blim`. Default 8 GB.
    init_range : tuple[float, float]
        ``(low, high)`` bounds of the ``U(low, high)`` draw used for both
        ``beta_init`` and ``eta_init`` when ``init_strategy="uniform"``.
        Default ``(0.01, 0.4)``. It is not used by the ``"pks"`` strategy,
        but it is still validated.
    init_strategy : {"uniform", "pks"}
        How random starts are drawn.

        - ``"uniform"`` (default): draw both parameters from
          ``U(*init_range)``, then halve both on any item until
          ``beta[q] + eta[q] < 0.95``.
        - ``"pks"``: reproduce ``pks::blim(..., randinit=TRUE)`` (R source
          ``cran/pks/R/blim.R``). Draw each parameter from ``U(0, 1)``, then
          map ``x -> 1 - x`` on items whose ``beta[q] + eta[q] >= 1``.

    Returns
    -------
    BLIMEstimate
        The fit with the highest log-likelihood over all starts.

    Raises
    ------
    ValueError
        If ``n_restarts < 1``, ``init_range`` does not satisfy
        ``0 <= low < high < 1``, or ``init_strategy`` is unknown.

    Notes
    -----
    The default ``init_range=(0.01, 0.4)`` keeps draws away from the
    identifiability frontier ``beta + eta = 1``. Near that frontier EM can
    stall in a degenerate-item attractor (cf. Spoto, Stefanutti & Vidotto
    2013). The ``"pks"`` strategy exists for reproducibility against R.
    pks draws from ``runif(nitems)``, i.e. ``U(0, 1)`` with reflection, and
    **not** from ``U(0, 0.5)`` as is sometimes stated.

    References
    ----------
    Heller, J., & Wickelmaier, F. (2013). Minimum discrepancy estimation in
    probabilistic knowledge structures. *ENDM*, 42, 49-56.
    """
    if n_restarts < 1:
        raise ValueError(f"n_restarts must be >= 1, got {n_restarts}.")
    low, high = init_range
    if not (0.0 <= low < high < 1.0):
        raise ValueError(f"init_range must satisfy 0 <= low < high < 1, got {init_range}.")
    if init_strategy not in ("uniform", "pks"):
        raise ValueError(f"init_strategy must be 'uniform' or 'pks', got {init_strategy!r}.")

    generator = np.random.default_rng(seed)
    item_count = len(data.items)
    winner: BLIMEstimate | None = None
    unconverged = 0

    for attempt in range(n_restarts):
        # Slip draws always precede guess draws, so a given seed yields the
        # same start sequence under either strategy's code path.
        if init_strategy == "pks":
            slip = generator.uniform(0.0, 1.0, size=item_count)
            guess = generator.uniform(0.0, 1.0, size=item_count)
            # Reflect every item whose pair violates beta + eta < 1.
            reflect = slip + guess >= 1.0
            slip = np.where(reflect, 1.0 - slip, slip)
            guess = np.where(reflect, 1.0 - guess, guess)
        else:
            slip = generator.uniform(low, high, size=item_count)
            guess = generator.uniform(low, high, size=item_count)
            # Halve offending items (both parameters) until every pair sums
            # below 0.95; each item is halved independently of the others.
            while True:
                crowded = slip + guess >= 0.95
                if not crowded.any():
                    break
                slip[crowded] *= 0.5
                guess[crowded] *= 0.5

        # Keep starts strictly inside (0, 1).
        slip = np.clip(slip, 1e-6, 1 - 1e-6)
        guess = np.clip(guess, 1e-6, 1 - 1e-6)

        # Per-start ConvergenceWarnings (non-convergence, degenerate items)
        # are muted; one aggregate warning and one degenerate-item warning
        # for the chosen fit are issued after the loop. The memory-preflight
        # ResourceWarning is identical for every start, so only the first
        # start is allowed to emit it.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", ConvergenceWarning)
            if attempt:
                warnings.simplefilter("ignore", ResourceWarning)
            candidate = estimate_blim(
                structure,
                data,
                max_iter=max_iter,
                tol=tol,
                beta_init=slip,
                eta_init=guess,
                max_memory_bytes=max_memory_bytes,
            )

        if not candidate.converged:
            unconverged += 1
        # Strict `>` keeps the earliest start on ties.
        if winner is None or candidate.log_likelihood > winner.log_likelihood:
            winner = candidate

    if unconverged == n_restarts:
        warnings.warn(
            f"BLIM EM did not converge in any of the {n_restarts} restarts "
            f"(max_iter={max_iter}, tol={tol:.3e}). "
            f"The returned estimate is the best (highest log-likelihood) "
            f"non-converged fit. Consider increasing max_iter or n_restarts.",
            category=ConvergenceWarning,
            stacklevel=2,
        )

    assert winner is not None  # guaranteed by n_restarts >= 1
    _warn_degenerate(winner.degenerate_items)
    return winner