"""Evaluate the quantum learnability semidefinite programs."""
import warnings
from itertools import combinations
from typing import Any, Iterable, Sequence
import cvxpy as cp
import cvxpy.settings as cp_settings
import numpy as np
import scipy.sparse as sp
from toqito.matrix_ops import to_density_matrix, vectors_to_gram_matrix
from toqito.matrix_props import is_positive_semidefinite, is_rank_one
[docs]
def learnability(
states: Sequence[np.ndarray],
k: int,
*,
solver: str | None = "SCS",
solver_kwargs: dict[str, Any] | None = None,
verify_reduced: bool = True,
verify_tolerance: float = 1e-4,
tol: float = 1e-8,
) -> dict[str, float | str | None | dict]:
r"""Compute the average error value of the learnability semidefinite program.
This routine minimizes
\[
\frac{1}{n} \sum_{i = 1}^n \left\langle \rho_i,
\sum_{S: i \notin S} M_S \right\rangle.
\]
over POVM elements \((M_S)\) indexed by ``k``-element subsets, subject to
\(\sum_S M_S = \mathbb{I}\) and \(M_S \succeq 0\). When all inputs are pure, the
reduced Gram-matrix SDP
\[
\sum_{i = 1}^n \bra{i} \sum_{S: i \notin S} W_S \ket{i}.
\]
with constraint \(\sum_S W_S = G\) (Gram matrix) and \(W_S \succeq 0\)
is also solved as a consistency check.
Examples:
```python exec="1" source="above"
from toqito.state_props import learnability
from toqito.states import basis
e0, e1 = basis(2, 0), basis(2, 1)
print(learnability(
[e0, e1],
k=1,
solver="SCS",
solver_kwargs={"eps": 1e-6, "max_iters": 5_000},
))
```
Args:
states: Sequence of state vectors or density matrices acting on the same space.
k: Subset size for the POVM outcomes; must satisfy `1 <= k <= len(states)`.
solver: Optional CVXPY solver name. Defaults to `"SCS"`.
solver_kwargs: Extra keyword arguments forwarded to :meth:`cvxpy.Problem.solve`.
verify_reduced: If `True` and the states are pure, also solve the reduced SDP.
verify_tolerance: Absolute tolerance used when comparing the two optimal values.
tol: Numerical tolerance used when validating positivity and rank-one states.
Returns:
Dictionary with keys `value`, `total_value`, `status`, `measurement_operators`, and optionally `reduced_value`,
`reduced_total_value`, `reduced_status`, `reduced_operators`.
Raises:
ValueError: If the data are inconsistent with valid quantum states or if `k` lies outside the permissible range.
cvxpy.error.SolverError: If the selected solver reports a failure.
"""
if not states:
raise ValueError("The list of states must be non-empty.")
density_matrices, candidate_vectors = _convert_states(states, tol=tol)
general_value, general_status, measurement_variables = _solve_learnability_general(
density_matrices,
k,
solver=solver,
solver_kwargs=solver_kwargs,
)
operator_values = {combo: measurement_variables[combo].value for combo in measurement_variables}
result: dict[str, float | str | None | dict] = {
"value": float(np.real(general_value)),
"status": general_status,
"reduced_value": None,
"reduced_status": None,
"measurement_operators": operator_values,
"reduced_operators": None,
"total_value": float(np.real(general_value)) * len(density_matrices),
}
result["reduced_total_value"] = None
if verify_reduced and candidate_vectors is not None:
gram = vectors_to_gram_matrix(candidate_vectors)
reduced_value, reduced_status, reduced_variables = _solve_learnability_reduced(
gram,
k,
solver=solver,
solver_kwargs=solver_kwargs,
)
reduced_operator_values = {combo: var.value for combo, var in reduced_variables.items()}
result["reduced_value"] = float(np.real(reduced_value))
result["reduced_status"] = reduced_status
result["reduced_operators"] = reduced_operator_values
result["reduced_total_value"] = float(np.real(reduced_value)) * len(density_matrices)
if abs(result["value"] - result["reduced_value"]) > verify_tolerance:
warnings.warn(
(
"General and reduced SDP optimal values differ by more than "
f"{verify_tolerance}. General value: {result['value']}, "
f"reduced value: {result['reduced_value']}."
),
RuntimeWarning,
)
return result
def _solve_learnability_general(
density_matrices: Sequence[np.ndarray],
k: int,
*,
solver: str | None,
solver_kwargs: dict[str, Any] | None,
) -> tuple[float, str, dict[tuple[int, ...], cp.Variable]]:
n = len(density_matrices)
if not 1 <= k <= n:
raise ValueError(f"k must satisfy 1 <= k <= n (= {n}).")
dim = density_matrices[0].shape[0]
combos = list(combinations(range(n), k))
variables = {combo: cp.Variable((dim, dim), hermitian=True) for combo in combos}
constraints = [var >> 0 for var in variables.values()]
constraints.append(_sum_expressions(variables.values()) == np.eye(dim, dtype=np.complex128))
objective_terms = []
for idx, rho in enumerate(density_matrices):
without_idx = [var for combo, var in variables.items() if idx not in combo]
if not without_idx:
objective_terms.append(0.0)
continue
summed = _sum_expressions(without_idx)
objective_terms.append(cp.real(cp.trace(rho @ summed)) / n)
problem = cp.Problem(cp.Minimize(cp.sum(objective_terms)), constraints)
value, status = _solve_problem(problem, solver, solver_kwargs)
return value, status, variables
def _solve_learnability_reduced(
gram_matrix: np.ndarray,
k: int,
*,
solver: str | None,
solver_kwargs: dict[str, Any] | None,
) -> tuple[float, str, dict[tuple[int, ...], cp.Variable]]:
n = gram_matrix.shape[0]
if not 1 <= k <= n:
raise ValueError(f"k must satisfy 1 <= k <= n (= {n}).")
combos = list(combinations(range(n), k))
variables = {combo: cp.Variable((n, n), hermitian=True) for combo in combos}
constraints = [var >> 0 for var in variables.values()]
constraints.append(_sum_expressions(variables.values()) == gram_matrix)
objective_terms = []
for idx in range(n):
without_idx = [var for combo, var in variables.items() if idx not in combo]
if not without_idx:
objective_terms.append(0.0)
continue
summed = _sum_expressions(without_idx)
objective_terms.append(cp.real(summed[idx, idx]) / n)
problem = cp.Problem(cp.Minimize(cp.sum(objective_terms)), constraints)
value, status = _solve_problem(problem, solver, solver_kwargs)
return value, status, variables
def _convert_states(
states: Sequence[np.ndarray],
*,
tol: float,
) -> tuple[list[np.ndarray], list[np.ndarray] | None]:
r"""Normalize input states and detect whether they are uniformly pure.
Each entry in `states` may be a state vector or a density matrix. The
routine converts every element to a unit-trace density matrix, checks
positivity, and records the original pure state vectors when all inputs are
rank one.
Args:
states: Collection of quantum states to normalize.
tol: Numerical tolerance used for positivity and rank checks.
Returns:
List of density matrices and, when available, the corresponding state vectors.
"""
density_matrices: list[np.ndarray] = []
pure_vectors: list[np.ndarray] = []
all_pure = True
dim: int | None = None
for raw_state in states:
state_array = np.asarray(raw_state, dtype=np.complex128)
rho = to_density_matrix(state_array)
rho = (rho + rho.conj().T) / 2
trace = np.trace(rho)
if np.isclose(trace, 0.0, atol=tol):
raise ValueError("Each state must have strictly positive trace.")
rho = rho / trace
if dim is None:
dim = rho.shape[0]
elif rho.shape != (dim, dim):
raise ValueError("All states must act on the same Hilbert space.")
if not is_positive_semidefinite(rho, atol=tol):
raise ValueError("Each state must be positive semidefinite.")
density_matrices.append(rho)
if all_pure:
if is_rank_one(rho, tol=tol):
pure_vectors.append(_extract_state_vector(state_array, rho))
else:
all_pure = False
if not all_pure:
pure_vectors = None
return density_matrices, pure_vectors
def _solve_problem(
problem: cp.Problem,
solver: str | None,
solver_kwargs: dict[str, Any] | None,
) -> tuple[float, str]:
"""Solve a CVXPY problem and return both the optimal value and status."""
solve_kwargs = dict(solver_kwargs or {})
if _is_scs_solver(solver):
return _solve_problem_with_scs(problem, solve_kwargs)
if solver is None:
value = problem.solve(**solve_kwargs)
else:
value = problem.solve(solver=solver, **solve_kwargs)
return value, problem.status
def _solve_problem_with_scs(
problem: cp.Problem,
solver_kwargs: dict[str, Any],
) -> tuple[float, str]:
"""Solve with SCS ensuring sparse matrices use CSC format to avoid warnings."""
warm_start = bool(solver_kwargs.pop("warm_start", False))
verbose = bool(solver_kwargs.pop("verbose", False))
data, chain, inverse_data = problem.get_problem_data(cp.SCS)
for key in (cp_settings.A, cp_settings.P):
if key in data and data[key] is not None:
data[key] = sp.csc_matrix(data[key])
solution = chain.solve_via_data(
problem,
data,
warm_start=warm_start,
verbose=verbose,
solver_opts=solver_kwargs,
)
problem.unpack_results(solution, chain, inverse_data)
return problem.value, problem.status
def _is_scs_solver(solver: Any | None) -> bool:
"""Return True when the requested solver corresponds to SCS."""
if solver is None:
return False
if solver is cp.SCS:
return True
if isinstance(solver, str) and solver.strip().upper() == "SCS":
return True
return False
def _extract_state_vector(
original: np.ndarray,
density: np.ndarray,
) -> np.ndarray:
if original.ndim == 1:
vector = original
elif original.ndim == 2 and 1 in original.shape:
vector = original.reshape(-1)
else:
eigenvalues, eigenvectors = np.linalg.eigh(density)
vector = eigenvectors[:, np.argmax(eigenvalues)]
norm = np.linalg.norm(vector)
return (vector / norm).astype(np.complex128)
def _sum_expressions(expressions: Iterable[cp.expressions.expression.Expression]):
iterator = iter(expressions)
try:
total = next(iterator)
except StopIteration:
return 0.0
for expr in iterator:
total = total + expr
return total