Skip to content

Commit

Permalink
refactor: Lift result object creation to base Acquisition
Browse files Browse the repository at this point in the history
  • Loading branch information
stavros11 committed Feb 19, 2024
1 parent 0aaf828 commit 27893f8
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions src/qibolab/instruments/qm/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class Acquisition(ABC):

keys: list[str] = field(default_factory=list)

RESULT_CLS = IntegratedResults
"""Result object type that corresponds to this acquisition type."""
AVERAGED_RESULT_CLS = AveragedIntegratedResults
"""Averaged result object type that corresponds to this acquisition
type."""

@property
def npulses(self):
return len(self.keys)
Expand Down Expand Up @@ -73,6 +79,13 @@ def download(self, *dimensions):
def fetch(self):
"""Fetch downloaded streams to host device."""

def result(self, data):
"""Creates Qibolab result object that is returned to the platform."""
res_cls = self.AVERAGED_RESULT_CLS if self.average else self.RESULT_CLS
if self.npulses > 1:
return [res_cls(data[..., i]) for i in range(self.npulses)]
return [res_cls(data)]


@dataclass
class RawAcquisition(Acquisition):
Expand All @@ -83,6 +96,9 @@ class RawAcquisition(Acquisition):
)
"""Stream to collect raw ADC data."""

_result_cls = RawWaveformResults
_averaged_result_cls = AveragedRawWaveformResults

def assign_element(self, element):
pass

Expand All @@ -104,9 +120,7 @@ def fetch(self, handles):
# convert raw ADC signal to volts
u = unit()
signal = u.raw2volts(ires) + 1j * u.raw2volts(qres)
if self.average:
return [AveragedRawWaveformResults(signal)]
return [RawWaveformResults(signal)]
return self.result(signal)


@dataclass
Expand All @@ -120,6 +134,9 @@ class IntegratedAcquisition(Acquisition):
qstream: _ResultSource = field(default_factory=lambda: declare_stream())
"""Streams to collect the results of all shots."""

_result_cls = IntegratedResults
_averaged_result_cls = AveragedIntegratedResults

def assign_element(self, element):
assign_variables_to_element(element, self.i, self.q)

Expand Down Expand Up @@ -152,11 +169,7 @@ def download(self, *dimensions):
def fetch(self, handles):
ires = handles.get(f"{self.name}_I").fetch_all()
qres = handles.get(f"{self.name}_Q").fetch_all()
signal = ires + 1j * qres
res_cls = AveragedIntegratedResults if self.average else IntegratedResults
if self.npulses > 1:
return [res_cls(signal[..., i]) for i in range(self.npulses)]
return [res_cls(signal)]
return self.result(ires + 1j * qres)


@dataclass
Expand All @@ -179,6 +192,9 @@ class ShotsAcquisition(Acquisition):
shots: _ResultSource = field(default_factory=lambda: declare_stream())
"""Stream to collect multiple shots."""

_result_cls = SampleResults
_averaged_result_cls = AveragedSampleResults

def __post_init__(self):
self.cos = np.cos(self.angle)
self.sin = np.sin(self.angle)
Expand Down Expand Up @@ -212,10 +228,7 @@ def download(self, *dimensions):

def fetch(self, handles):
shots = handles.get(f"{self.name}_shots").fetch_all()
res_cls = AveragedSampleResults if self.average else SampleResults
if self.npulses > 1:
return [res_cls(shots[..., i]) for i in range(self.npulses)]
return [res_cls(shots.astype(int))]
return self.result(shots)


ACQUISITION_TYPES = {
Expand Down

0 comments on commit 27893f8

Please sign in to comment.