diff --git a/qadence_libs/qinfo_tools/qng.py b/qadence_libs/qinfo_tools/qng.py index 16baab2..f2aa5a7 100644 --- a/qadence_libs/qinfo_tools/qng.py +++ b/qadence_libs/qinfo_tools/qng.py @@ -215,29 +215,29 @@ def qng_spsa( See :class:`~qadence_libs.qinfo_tools.QuantumNaturalGradient` for details. """ + with torch.no_grad(): + # Get estimation of the QFI matrix + vparams_dict = dict(zip(vparams_keys, vparams_values)) + qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa( + circuit=circuit, + iteration=state["iter"], + vparams_dict=vparams_dict, + previous_qfi_estimator=state["qfi_estimator"], + epsilon=epsilon, + beta=beta, + ) - # Get estimation of the QFI matrix - vparams_dict = dict(zip(vparams_keys, vparams_values)) - qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa( - circuit=circuit, - iteration=state["iter"], - vparams_dict=vparams_dict, - previous_qfi_estimator=state["qfi_estimator"], - epsilon=epsilon, - beta=beta, - ) - - # Get transformed gradient vector solving the least squares problem - transf_grad = torch.linalg.lstsq( - 0.25 * qfi_mat_positive_sd, - grad_vec, - driver="gelsd", - ).solution + # Get transformed gradient vector solving the least squares problem + transf_grad = torch.linalg.lstsq( + 0.25 * qfi_mat_positive_sd, + grad_vec, + driver="gelsd", + ).solution - for i, p in enumerate(vparams_values): - if p.grad is None: - continue - p.data.add_(transf_grad[i], alpha=-lr) + for i, p in enumerate(vparams_values): + if p.grad is None: + continue + p.data.add_(transf_grad[i], alpha=-lr) - state["iter"] += 1 - state["qfi_estimator"] = qfi_estimator + state["iter"] += 1 + state["qfi_estimator"] = qfi_estimator