Skip to content

Commit

Permalink
[Bug] Use torch.no_grad() in qng-spsa (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
inafergra authored Oct 14, 2024
1 parent 3bdf515 commit 9f8fcc6
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions qadence_libs/qinfo_tools/qng.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,29 +215,29 @@ def qng_spsa(
See :class:`~qadence_libs.qinfo_tools.QuantumNaturalGradient` for details.
"""
with torch.no_grad():
# Get estimation of the QFI matrix
vparams_dict = dict(zip(vparams_keys, vparams_values))
qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa(
circuit=circuit,
iteration=state["iter"],
vparams_dict=vparams_dict,
previous_qfi_estimator=state["qfi_estimator"],
epsilon=epsilon,
beta=beta,
)

# Get estimation of the QFI matrix
vparams_dict = dict(zip(vparams_keys, vparams_values))
qfi_estimator, qfi_mat_positive_sd = get_quantum_fisher_spsa(
circuit=circuit,
iteration=state["iter"],
vparams_dict=vparams_dict,
previous_qfi_estimator=state["qfi_estimator"],
epsilon=epsilon,
beta=beta,
)

# Get transformed gradient vector solving the least squares problem
transf_grad = torch.linalg.lstsq(
0.25 * qfi_mat_positive_sd,
grad_vec,
driver="gelsd",
).solution
# Get transformed gradient vector solving the least squares problem
transf_grad = torch.linalg.lstsq(
0.25 * qfi_mat_positive_sd,
grad_vec,
driver="gelsd",
).solution

for i, p in enumerate(vparams_values):
if p.grad is None:
continue
p.data.add_(transf_grad[i], alpha=-lr)
for i, p in enumerate(vparams_values):
if p.grad is None:
continue
p.data.add_(transf_grad[i], alpha=-lr)

state["iter"] += 1
state["qfi_estimator"] = qfi_estimator
state["iter"] += 1
state["qfi_estimator"] = qfi_estimator

0 comments on commit 9f8fcc6

Please sign in to comment.