Skip to content

Commit

Permalink
infra: reuse AWS calls across integ tests (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
AbeCoull authored Feb 19, 2024
1 parent e11bb31 commit 0d23ac1
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 115 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ __pycache__/
/build
/venv
/dist
/model.tar.gz
88 changes: 88 additions & 0 deletions test/integ_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,62 @@
# language governing permissions and limitations under the License.

import os
import random
import string

import boto3
import pytest
from botocore.exceptions import ClientError

from braket.aws.aws_device import AwsDevice
from braket.aws.aws_quantum_job import AwsQuantumJob
from braket.aws.aws_session import AwsSession

SV1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/sv1"
DM1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/dm1"
TN1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/tn1"
SIMULATOR_ARNS = [SV1_ARN, DM1_ARN, TN1_ARN]

job_complete_name = "".join(random.choices(string.ascii_lowercase + string.digits, k=12))
job_fail_name = "".join(random.choices(string.ascii_lowercase + string.digits, k=12))


def pytest_configure_node(node):
"""xdist hook"""
node.workerinput["JOB_COMPLETED_NAME"] = job_complete_name
node.workerinput["JOB_FAILED_NAME"] = job_fail_name


def pytest_xdist_node_collection_finished(ids):
"""Uses the pytest xdist hook to check whether tests with jobs are to be ran.
If they are, the first reporting worker sets a flag that it created the tests
to avoid concurrency limits. This is the first time in the pytest setup the
controller has all the tests to be ran from the worker nodes.
"""
run_jobs = any("job" in test for test in ids)
profile_name = os.environ["AWS_PROFILE"]
aws_session = AwsSession(boto3.session.Session(profile_name=profile_name))
if run_jobs and os.getenv("JOBS_STARTED") is None:
AwsQuantumJob.create(
"arn:aws:braket:::device/quantum-simulator/amazon/sv1",
job_name=job_fail_name,
source_module="test/integ_tests/job_test_script.py",
entry_point="job_test_script:start_here",
aws_session=aws_session,
wait_until_complete=False,
hyperparameters={"test_case": "failed"},
)
AwsQuantumJob.create(
"arn:aws:braket:::device/quantum-simulator/amazon/sv1",
job_name=job_complete_name,
source_module="test/integ_tests/job_test_script.py",
entry_point="job_test_script:start_here",
aws_session=aws_session,
wait_until_complete=False,
hyperparameters={"test_case": "completed"},
)
os.environ["JOBS_STARTED"] = "True"


@pytest.fixture(scope="session")
def boto_session():
Expand Down Expand Up @@ -82,3 +131,42 @@ def s3_prefix():
@pytest.fixture(scope="module")
def s3_destination_folder(s3_bucket, s3_prefix):
return AwsSession.S3DestinationFolder(s3_bucket, s3_prefix)


@pytest.fixture(scope="session")
def braket_simulators(aws_session):
return {
simulator_arn: AwsDevice(simulator_arn, aws_session) for simulator_arn in SIMULATOR_ARNS
}


@pytest.fixture(scope="session")
def braket_devices():
return AwsDevice.get_devices(statuses=["RETIRED", "ONLINE", "OFFLINE"])


@pytest.fixture(scope="session", autouse=True)
def created_braket_devices(aws_session, braket_devices):
return {device.arn: device for device in braket_devices}


@pytest.fixture(scope="session")
def job_completed_name(request):
return request.config.workerinput["JOB_COMPLETED_NAME"]


@pytest.fixture(scope="session")
def job_failed_name(request):
return request.config.workerinput["JOB_FAILED_NAME"]


@pytest.fixture(scope="session", autouse=True)
def completed_quantum_job(aws_session, job_completed_name):
job = AwsQuantumJob(arn=f"arn:aws:braket:us-west-2:046073650652:job/{job_completed_name}")
return job


@pytest.fixture(scope="session", autouse=True)
def failed_quantum_job(aws_session, job_failed_name):
job = AwsQuantumJob(arn=f"arn:aws:braket:us-west-2:046073650652:job/{job_failed_name}")
return job
10 changes: 5 additions & 5 deletions test/integ_tests/gate_model_device_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def result_types_observable_not_in_instructions(device: Device, run_kwargs: Dict
.variance(observable=Observable.Y(), target=[3])
)
bell_qasm = bell.to_ir(ir_type=IRType.OPENQASM)
for task in (bell, bell_qasm):
result = device.run(task, **run_kwargs).result()
results = device.run_batch([bell, bell_qasm], **run_kwargs).results()
for result in results:
assert np.allclose(result.values[0], 0, **tol)
assert np.allclose(result.values[1], 1, **tol)

Expand All @@ -103,9 +103,9 @@ def result_types_zero_shots_bell_pair_testing(
circuit.amplitude(["01", "10", "00", "11"])
if include_state_vector:
circuit.state_vector()
tasks = (circuit, circuit.to_ir(ir_type=IRType.OPENQASM))
for task in tasks:
result = device.run(task, **run_kwargs).result()
tasks = [circuit, circuit.to_ir(ir_type=IRType.OPENQASM)]
results = device.run_batch(tasks, **run_kwargs).results()
for result in results:
assert len(result.result_types) == 3 if include_state_vector else 2
assert np.allclose(
result.get_value_by_result_type(
Expand Down
4 changes: 2 additions & 2 deletions test/integ_tests/job_test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def completed_job_script():
device = AwsDevice(get_job_device_arn())

bell = Circuit().h(0).cnot(0, 1)
for count in range(5):
task = device.run(bell, shots=100)
for count in range(3):
task = device.run(bell, shots=10)
print(task.result().measurement_counts)
save_job_result({"converged": True, "energy": -0.2})
save_job_checkpoint({"some_data": "abc"}, checkpoint_file_suffix="plain_data")
Expand Down
1 change: 1 addition & 0 deletions test/integ_tests/test_create_local_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_completed_local_job(aws_session, capsys):

for data in logs_to_validate:
assert data in log_data

finally:
os.chdir(current_dir)

Expand Down
47 changes: 16 additions & 31 deletions test/integ_tests/test_create_quantum_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import sys
import tempfile
import time
from pathlib import Path

import job_test_script
Expand All @@ -36,27 +37,24 @@ def decorator_python_version():
return int(major_version), int(minor_version)


def test_failed_quantum_job(aws_session, capsys):
def test_failed_quantum_job(aws_session, capsys, failed_quantum_job):
"""Asserts the hybrid job is failed with the output, checkpoints,
quantum tasks not created in bucket and only input is uploaded to s3. Validate the
results/download results have the response raising RuntimeError. Also,
check if the logs displays the Assertion Error.
"""

job = AwsQuantumJob.create(
"arn:aws:braket:::device/quantum-simulator/amazon/sv1",
source_module="test/integ_tests/job_test_script.py",
entry_point="job_test_script:start_here",
aws_session=aws_session,
wait_until_complete=True,
hyperparameters={"test_case": "failed"},
)
job = failed_quantum_job
job_name = job.name

pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$"
assert re.match(pattern=pattern, string=job.arn)

# Check job is in failed state.
assert job.state() == "FAILED"
while True:
time.sleep(5)
if job.state() in AwsQuantumJob.TERMINAL_STATES:
break
assert job.state(use_cached_value=True) == "FAILED"

# Check whether the respective folder with files are created for script,
# output, tasks and checkpoints.
Expand Down Expand Up @@ -97,27 +95,22 @@ def test_failed_quantum_job(aws_session, capsys):
)


def test_completed_quantum_job(aws_session, capsys):
def test_completed_quantum_job(aws_session, capsys, completed_quantum_job):
"""Asserts the hybrid job is completed with the output, checkpoints, quantum tasks and
script folder created in S3 for respective hybrid job. Validate the results are
downloaded and results are what we expect. Also, assert that logs contains all the
necessary steps for setup and running the hybrid job and is displayed to the user.
"""

job = AwsQuantumJob.create(
"arn:aws:braket:::device/quantum-simulator/amazon/sv1",
source_module="test/integ_tests/job_test_script.py",
entry_point="job_test_script:start_here",
wait_until_complete=True,
aws_session=aws_session,
hyperparameters={"test_case": "completed"},
)

job = completed_quantum_job
job_name = job.name
pattern = f"^arn:aws:braket:{aws_session.region}:\\d{{12}}:job/[a-z0-9-]+$"
assert re.match(pattern=pattern, string=job.arn)

# check job is in completed state.
assert job.state() == "COMPLETED"
# Check the job has completed
job.result()

assert job.state(use_cached_value=True) == "COMPLETED"

# Check whether the respective folder with files are created for script,
# output, tasks and checkpoints.
Expand Down Expand Up @@ -179,19 +172,11 @@ def test_completed_quantum_job(aws_session, capsys):
== expected_data
)

# Check downloaded results exists in the file system after the call.
downloaded_result = f"{job_name}/{AwsQuantumJob.RESULTS_FILENAME}"
current_dir = Path.cwd()

with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
try:
job.download_result()
assert (
Path(AwsQuantumJob.RESULTS_TAR_FILENAME).exists()
and Path(downloaded_result).exists()
)

# Check results match the expectations.
assert job.result() == {"converged": True, "energy": -0.2}
finally:
Expand Down
2 changes: 1 addition & 1 deletion test/integ_tests/test_density_matrix_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from braket.aws import AwsDevice
from braket.circuits import Circuit, Noise, Observable

SHOTS = 1000
SHOTS = 500
DM1_ARN = "arn:aws:braket:::device/quantum-simulator/amazon/dm1"
SIMULATOR_ARNS = [DM1_ARN]

Expand Down
33 changes: 16 additions & 17 deletions test/integ_tests/test_device_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
@pytest.mark.parametrize(
"arn", [(RIGETTI_ARN), (IONQ_ARN), (OQC_ARN), (SIMULATOR_ARN), (PULSE_ARN)]
)
def test_device_creation(arn, aws_session):
device = AwsDevice(arn, aws_session=aws_session)
def test_device_creation(arn, created_braket_devices):
device = created_braket_devices[arn]
assert device.arn == arn
assert device.name
assert device.status
Expand All @@ -39,17 +39,17 @@ def test_device_creation(arn, aws_session):


@pytest.mark.parametrize("arn", [(PULSE_ARN)])
def test_device_pulse_properties(arn, aws_session):
device = AwsDevice(arn, aws_session=aws_session)
def test_device_pulse_properties(arn, aws_session, created_braket_devices):
device = created_braket_devices[arn]
assert device.ports
assert device.frames


def test_device_across_regions(aws_session):
def test_device_across_regions(aws_session, created_braket_devices):
# assert QPUs across different regions can be created using the same aws_session
AwsDevice(RIGETTI_ARN, aws_session)
AwsDevice(IONQ_ARN, aws_session)
AwsDevice(OQC_ARN, aws_session)
created_braket_devices[RIGETTI_ARN]
created_braket_devices[IONQ_ARN]
created_braket_devices[OQC_ARN]


@pytest.mark.parametrize("arn", [(RIGETTI_ARN), (IONQ_ARN), (OQC_ARN), (SIMULATOR_ARN)])
Expand All @@ -59,8 +59,8 @@ def test_get_devices_arn(arn):


@pytest.mark.parametrize("arn", [(PULSE_ARN)])
def test_device_gate_calibrations(arn, aws_session):
device = AwsDevice(arn, aws_session=aws_session)
def test_device_gate_calibrations(arn, aws_session, created_braket_devices):
device = created_braket_devices[arn]
assert device.gate_calibrations


Expand All @@ -76,8 +76,8 @@ def test_get_devices_others():
assert result.status in statuses


def test_get_devices_all():
result_arns = [result.arn for result in AwsDevice.get_devices()]
def test_get_devices_all(braket_devices):
result_arns = [result.arn for result in braket_devices]
for arn in [RIGETTI_ARN, IONQ_ARN, SIMULATOR_ARN, OQC_ARN]:
assert arn in result_arns

Expand Down Expand Up @@ -127,17 +127,16 @@ def _validate_device(device: AwsDevice, active_providers: Set[str]):
assert getattr(getattr(Devices, provider_name), device_name) == device.arn


def test_device_enum():
aws_devices = AwsDevice.get_devices()
active_providers = _get_active_providers(aws_devices)
def test_device_enum(braket_devices, created_braket_devices):
active_providers = _get_active_providers(braket_devices)

# validate all devices in API
for device in aws_devices:
for device in braket_devices:
_validate_device(device, active_providers)

# validate all devices in enum
providers = [getattr(Devices, attr) for attr in dir(Devices) if not attr.startswith("__")]
for provider in providers:
for device_arn in provider:
device = AwsDevice(device_arn)
device = created_braket_devices[device_arn]
_validate_device(device, active_providers)
17 changes: 7 additions & 10 deletions test/integ_tests/test_queue_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from braket.aws import AwsDevice, AwsQuantumJob

from braket.aws import AwsDevice
from braket.aws.queue_information import (
HybridJobQueueInfo,
QuantumTaskQueueInfo,
Expand Down Expand Up @@ -47,15 +48,11 @@ def test_task_queue_position():
assert queue_information.message is None


def test_job_queue_position(aws_session):
job = AwsQuantumJob.create(
device=Devices.Amazon.SV1,
source_module="test/integ_tests/job_test_script.py",
entry_point="job_test_script:start_here",
aws_session=aws_session,
wait_until_complete=True,
hyperparameters={"test_case": "completed"},
)
def test_job_queue_position(aws_session, completed_quantum_job):
job = completed_quantum_job

# Check the job is complete
job.result()

# call the queue_position method.
queue_information = job.queue_position()
Expand Down
Loading

0 comments on commit 0d23ac1

Please sign in to comment.