diff --git a/.github/workflows/mypytmp.yml b/.github/workflows/mypytmp.yml new file mode 100644 index 000000000..d16f174fc --- /dev/null +++ b/.github/workflows/mypytmp.yml @@ -0,0 +1,58 @@ +name: mypytmp + +on: + push: + branches: [ basic-mypy-infrastructure ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + # pip cache dependencies to save time + - uses: actions/cache@v3 + if: startsWith(runner.os, 'Linux') + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Get pip cache dir + id: pip-cache + run: | + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v3 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + + # - name: Install OS prerequisites + # if: startsWith(runner.os, 'Linux') + # run: sudo apt-get install ripgrep + + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: | + python -u -m pip install --upgrade pip setuptools wheel + pip install -r server/requirements3.7.txt + pip install --upgrade -e . + + - name: Run mypy + run: | + mypy Collector/ CovReporter/ EC2Reporter/ FTB/ Reporter/ TaskStatusReporter/ misc/ server/ setup.py diff --git a/.gitignore b/.gitignore index b96ec0eee..d01d95720 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,7 @@ coverage/ # NPM server/frontend/node_modules server/frontend/dist + +# IDE / Editor +.vscode/ +.idea/ diff --git a/Collector/Collector.py b/Collector/Collector.py index 209acec5e..93fbb64a2 100755 --- a/Collector/Collector.py +++ b/Collector/Collector.py @@ -15,6 +15,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import argparse import base64 import hashlib @@ -22,29 +25,58 @@ import os import shutil import sys +from collections.abc import Iterator from tempfile import mkstemp +from typing import Dict, cast from zipfile import ZipFile +from typing_extensions import NotRequired, TypedDict + from FTB.ProgramConfiguration import ProgramConfiguration from FTB.Running.AutoRunner import AutoRunner from FTB.Signatures.CrashInfo import CrashInfo from FTB.Signatures.CrashSignature import CrashSignature from Reporter.Reporter import Reporter, remote_checks, signature_checks -__all__ = [] +__all__: list[str] = [] __version__ = 0.1 __date__ = "2014-10-01" __updated__ = "2014-10-01" +class DataType(TypedDict): + """Type information for the data dictionary.""" + + rawStdout: NotRequired[str] + rawStderr: NotRequired[str] + rawCrashData: NotRequired[str] + testcase: NotRequired[bytes | str] + testcase_isbinary: NotRequired[bool] + testcase_quality: NotRequired[int] + testcase_ext: NotRequired[str] + testcase_size: NotRequired[int] + platform: NotRequired[str] + product: NotRequired[str] + product_version: NotRequired[str] + os: NotRequired[str] + client: NotRequired[str] + tool: NotRequired[str] + metadata: NotRequired[str] + env: NotRequired[str] + args: NotRequired[str] + + class Collector(Reporter): @remote_checks @signature_checks - def refresh(self): + def refresh(self) -> None: """ Refresh signatures by contacting the server, downloading new signatures and 
invalidating old ones. """ + assert self.serverHost is not None + assert self.serverPort is not None + assert self.serverProtocol is not None url = "%s://%s:%d/crashmanager/rest/signatures/download/" % ( self.serverProtocol, self.serverHost, @@ -62,12 +94,13 @@ def refresh(self): os.remove(zipFileName) @signature_checks - def refreshFromZip(self, zipFileName): + def refreshFromZip(self, zipFileName: str) -> None: """ Refresh signatures from a local zip file, adding new signatures and invalidating old ones. (This is a non-standard use case; you probably want to use refresh() instead.) """ + assert self.sigCacheDir is not None with ZipFile(zipFileName, "r") as zipFile: if zipFile.testzip(): raise RuntimeError(f"Bad CRC for downloaded zipfile {zipFileName}") @@ -88,36 +121,30 @@ def refreshFromZip(self, zipFileName): @remote_checks def submit( self, - crashInfo, - testCase=None, - testCaseQuality=0, - testCaseSize=None, - metaData=None, - ): + crashInfo: CrashInfo, + testCase: str | None = None, + testCaseQuality: int = 0, + testCaseSize: int | None = None, + metaData: dict[str, object] | None = None, + ) -> dict[str, object]: """ Submit the given crash information and an optional testcase/metadata to the server for processing and storage. - @type crashInfo: CrashInfo @param crashInfo: CrashInfo instance obtained from L{CrashInfo.fromRawCrashData} - - @type testCase: string @param testCase: A file containing a testcase for reproduction - - @type testCaseQuality: int @param testCaseQuality: A value indicating the quality of the test (less is better) - - @type testCaseSize: int or None @param testCaseSize: The size of the testcase to report. If None, use the file size. - - @type metaData: map @param metaData: A map containing arbitrary (application-specific) data which will be stored on the server in JSON format. This metadata is combined with possible metadata stored in the L{ProgramConfiguration} inside crashInfo. 
""" + assert self.serverHost is not None + assert self.serverPort is not None + assert self.serverProtocol is not None url = "%s://%s:%d/crashmanager/rest/crashes/" % ( self.serverProtocol, self.serverHost, @@ -126,7 +153,7 @@ def submit( # Serialize our crash information, testcase and metadata into a dictionary to # POST - data = {} + data: DataType = {} data["rawStdout"] = os.linesep.join(crashInfo.rawStdout) data["rawStderr"] = os.linesep.join(crashInfo.rawStderr) @@ -147,6 +174,7 @@ def submit( data["testcase_size"] = testCaseSize data["testcase_ext"] = os.path.splitext(testCase)[1].lstrip(".") + assert crashInfo.configuration is not None data["platform"] = crashInfo.configuration.platform data["product"] = crashInfo.configuration.product data["os"] = crashInfo.configuration.os @@ -154,11 +182,13 @@ def submit( if crashInfo.configuration.version: data["product_version"] = crashInfo.configuration.version + assert self.clientId is not None + assert self.tool is not None data["client"] = self.clientId data["tool"] = self.tool if crashInfo.configuration.metadata or metaData: - aggrMetaData = {} + aggrMetaData: dict[str, object] = {} if crashInfo.configuration.metadata: aggrMetaData.update(crashInfo.configuration.metadata) @@ -174,22 +204,21 @@ def submit( if crashInfo.configuration.args: data["args"] = json.dumps(crashInfo.configuration.args) - return self.post(url, data).json() + return cast(Dict[str, object], self.post(url, data).json()) @signature_checks - def search(self, crashInfo): + def search( + self, crashInfo: CrashInfo + ) -> tuple[str | None, dict[str, object] | None]: """ Searches within the local signature cache directory for a signature matching the given crash. - @type crashInfo: CrashInfo @param crashInfo: CrashInfo instance obtained from L{CrashInfo.fromRawCrashData} - - @rtype: tuple @return: Tuple containing filename of the signature and metadata matching, or None if no match. """ - + assert self.sigCacheDir is not None cachedSigFiles = os.listdir(self.sigCacheDir) for sigFile in cachedSigFiles: @@ -215,27 +244,20 @@ def search(self, crashInfo): @signature_checks def generate( self, - crashInfo, - forceCrashAddress=None, - forceCrashInstruction=None, - numFrames=None, - ): + crashInfo: CrashInfo, + forceCrashAddress: bool = False, + forceCrashInstruction: bool = False, + numFrames: int = 8, + ) -> str | None: """ Generates a signature in the local cache directory. It will be deleted when L{refresh} is called on the same local cache directory. - @type crashInfo: CrashInfo @param crashInfo: CrashInfo instance obtained from L{CrashInfo.fromRawCrashData} - - @type forceCrashAddress: bool @param forceCrashAddress: Force including the crash address into the signature - @type forceCrashInstruction: bool @param forceCrashInstruction: Force including the crash instruction into the signature (GDB only) - @type numFrames: int @param numFrames: How many frames to include in the signature - - @rtype: string @return: File containing crash signature in JSON format """ @@ -250,17 +272,17 @@ def generate( return self.__store_signature_hashed(sig) @remote_checks - def download(self, crashId): + def download(self, crashId: int) -> tuple[str, dict[str, str]] | None: """ Download the testcase for the specified crashId. 
- @type crashId: int @param crashId: ID of the requested crash entry on the server side - - @rtype: tuple @return: Tuple containing name of the file where the test was stored and the raw JSON response """ + assert self.serverHost is not None + assert self.serverPort is not None + assert self.serverProtocol is not None url = "%s://%s:%d/crashmanager/rest/crashes/%s/" % ( self.serverProtocol, self.serverHost, @@ -295,17 +317,19 @@ def download(self, crashId): return (local_filename, resp_json) @remote_checks - def download_all(self, bucketId): + def download_all(self, bucketId: int) -> Iterator[str]: """ Download all testcases for the specified bucketId. - @type bucketId: int @param bucketId: ID of the requested bucket on the server side - - @rtype: generator @return: generator of filenames where tests were stored. """ - params = {"query": json.dumps({"op": "OR", "bucket": bucketId})} + assert self.serverHost is not None + assert self.serverPort is not None + assert self.serverProtocol is not None + params: dict[str, str] | None = { + "query": json.dumps({"op": "OR", "bucket": bucketId}) + } next_url = "%s://%s:%d/crashmanager/rest/crashes/" % ( self.serverProtocol, self.serverHost, @@ -349,17 +373,14 @@ def download_all(self, bucketId): yield local_filename - def __store_signature_hashed(self, signature): + def __store_signature_hashed(self, signature: CrashSignature) -> str: """ Store a signature, using the sha1 hash hex representation as filename. - @type signature: CrashSignature @param signature: CrashSignature to store - - @rtype: string @return: Name of the file that the signature was written to - """ + assert self.sigCacheDir is not None h = hashlib.new("sha1") if str is bytes: h.update(str(signature)) @@ -372,14 +393,11 @@ def __store_signature_hashed(self, signature): return sigfile @staticmethod - def read_testcase(testCase): + def read_testcase(testCase: str) -> tuple[bytes, bool]: """ Read a testcase file, return the content and indicate if it is binary or not. 
- @type testCase: string @param testCase: Filename of the file to open - - @rtype: tuple(string, bool) @return: Tuple containing the file contents and a boolean indicating if the content is binary """ @@ -393,7 +411,7 @@ def read_testcase(testCase): return (testCaseData, isBinary) -def main(args=None): +def main(args: list[str] | None = None) -> int: """Command line options.""" # setup argparser @@ -605,10 +623,11 @@ def main(args=None): crashInfo = None args = None env = None - metadata = {} + metadata: dict[str, object] | None = {} if opts.search or opts.generate or opts.submit or opts.autosubmit: if opts.metadata: + assert metadata is not None metadata.update(dict(kv.split("=", 1) for kv in opts.metadata)) if opts.autosubmit: @@ -738,6 +757,7 @@ def main(args=None): if opts.autosubmit: runner = AutoRunner.fromBinaryArgs(opts.rargs[0], opts.rargs[1:]) if runner.run(): + assert configuration is not None crashInfo = runner.getCrashInfo(configuration) collector.submit( crashInfo, testcase, opts.testcasequality, opts.testcasesize, metadata @@ -750,13 +770,18 @@ def main(args=None): return 1 if opts.download: - (retFile, retJSON) = collector.download(opts.download) + collector_download_ret_val = collector.download(opts.download) + if collector_download_ret_val: + (retFile, retJSON) = collector_download_ret_val + else: + raise AssertionError("collector.download function returned a None") if not retFile: print("Specified crash entry does not have a testcase", file=sys.stderr) return 1 if "args" in retJSON and retJSON["args"]: args = json.loads(retJSON["args"]) + assert args is not None print( "Command line arguments:", " ".join(args), @@ -765,6 +790,7 @@ def main(args=None): if "env" in retJSON and retJSON["env"]: env = json.loads(retJSON["env"]) + assert env is not None print( "Environment variables:", " ".join(f"{k} = {v}" for (k, v) in env.items()), @@ -773,6 +799,7 @@ def main(args=None): if "metadata" in retJSON and retJSON["metadata"]: metadata = json.loads(retJSON["metadata"]) + assert metadata is not None print("== Metadata ==") for k, v in metadata.items(): print(f"{k} = {v}") @@ -798,6 +825,8 @@ def main(args=None): print(collector.clientId) return 0 + return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/Collector/tests/test_Collector.py b/Collector/tests/test_Collector.py index 0ea43cc97..a30e78e5d 100644 --- a/Collector/tests/test_Collector.py +++ b/Collector/tests/test_Collector.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import json import os import platform @@ -21,6 +24,8 @@ import pytest import requests +from django.contrib.auth.models import User +from pytest_django.live_server_helper import LiveServer from Collector.Collector import Collector, main from crashmanager.models import CrashEntry @@ -38,7 +43,7 @@ pytest_plugins = ("server.tests",) -def test_collector_help(capsys): +def test_collector_help(capsys: pytest.CaptureFixture[str]) -> None: """Test that help prints without throwing""" with pytest.raises(SystemExit): main() @@ -48,7 +53,9 @@ def test_collector_help(capsys): @patch("os.path.expanduser") @patch("time.sleep", new=Mock()) -def test_collector_submit(mock_expanduser, live_server, tmp_path, fm_user): +def test_collector_submit( + mock_expanduser: Mock, live_server: LiveServer, tmp_path: Path, fm_user: User +) -> None: """Test crash submission""" mock_expanduser.side_effect = lambda path: str( tmp_path @@ -100,6 +107,7 @@ def test_collector_submit(mock_expanduser, live_server, tmp_path, fm_user): 
assert entry.args == "" # create a test config + assert url.port is not None with (tmp_path / ".fuzzmanagerconf").open("w") as fp: fp.write("[Main]\n") fp.write(f"serverhost = {url.hostname}\n") @@ -120,7 +128,7 @@ def test_collector_submit(mock_expanduser, live_server, tmp_path, fm_user): crashdata_path = tmp_path / "crashdata.txt" with crashdata_path.open("w") as fp: fp.write(asan_trace_crash) - result = main( + result_return_code = main( [ "--submit", "--tool", @@ -153,7 +161,7 @@ def test_collector_submit(mock_expanduser, live_server, tmp_path, fm_user): str(crashdata_path), ] ) - assert result == 0 + assert result_return_code == 0 entry = CrashEntry.objects.get( pk__gt=entry.id ) # newer than the last result, will fail if the test db is active @@ -185,7 +193,7 @@ class response_t: collector.submit(crashInfo, str(testcase_path)) -def test_collector_refresh(capsys, tmp_path): +def test_collector_refresh(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None: """Test signature downloads""" # create a test signature zip test2_path = tmp_path / "test2.signature" @@ -202,16 +210,18 @@ def test_collector_refresh(capsys, tmp_path): (sigs_path / "other.txt").touch() assert {f.name for f in sigs_path.iterdir()} == {"test1.signature", "other.txt"} - with outzip_path.open("rb") as fp: + with outzip_path.open("rb") as fp2: class response_t: status_code = requests.codes["ok"] text = "OK" - raw = fp + raw = fp2 # this asserts the expected arguments and returns the open handle to out.zip as # 'raw' which is read by refresh() - def myget(url, stream=None, headers=None): + def myget( + url: str, stream: bool | None = None, headers: dict[str, str] | None = None + ) -> response_t: assert url == "gopher://aol.com:70/crashmanager/rest/signatures/download/" assert stream is True assert headers == {"Authorization": "Token token"} @@ -251,28 +261,28 @@ class response_t: # noqa collector.refresh() # check that bad zips raise errors - with (sigs_path / "other.txt").open("rb") as fp: + with (sigs_path / "other.txt").open("rb") as fp3: class response_t: # noqa status_code = requests.codes["ok"] text = "OK" - raw = fp + raw = fp3 collector._session.get = lambda *_, **__: response_t() with pytest.raises(zipfile.BadZipfile, match="not a zip file"): collector.refresh() - with outzip_path.open("r+b") as fp: + with outzip_path.open("r+b") as fp4: # corrupt the CRC field for the signature file in the zip - fp.seek(0x42) - fp.write(b"\xFF") - with outzip_path.open("rb") as fp: + fp4.seek(0x42) + fp4.write(b"\xFF") + with outzip_path.open("rb") as fp5: class response_t: # noqa status_code = requests.codes["ok"] text = "OK" - raw = fp + raw = fp5 collector._session.get = lambda *_, **__: response_t() @@ -280,7 +290,7 @@ class response_t: # noqa collector.refresh() -def test_collector_generate_search(tmp_path): +def test_collector_generate_search(tmp_path: Path) -> None: """Test sigcache generation and search""" # create a cache dir cache_dir = tmp_path / "sigcache" @@ -304,6 +314,7 @@ def test_collector_generate_search(tmp_path): assert meta is None # write metadata and make sure that's returned if it exists + assert sig is not None sigBase, _ = os.path.splitext(sig) with open(sigBase + ".metadata", "w") as f: f.write("{}") @@ -322,7 +333,7 @@ def test_collector_generate_search(tmp_path): assert result is None -def test_collector_download(tmp_path, monkeypatch): +def test_collector_download(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Test testcase downloads""" # create Collector collector = Collector( 
@@ -337,7 +348,7 @@ class response1_t: status_code = requests.codes["ok"] text = "OK" - def json(self): + def json(self) -> dict[str, object]: return {"id": 123, "testcase": "path/to/testcase.txt"} class response2_t: @@ -347,7 +358,7 @@ class response2_t: content = b"testcase\xFF" # myget1 mocks requests.get to return the rest response to the crashentry get - def myget1(url, headers=None): + def myget1(url: str, headers: dict[str, str] | None = None) -> response1_t: assert url == "gopher://aol.com:70/crashmanager/rest/crashes/123/" assert headers == {"Authorization": "Token token"} @@ -356,7 +367,7 @@ def myget1(url, headers=None): return response1_t() # myget2 mocks requests.get to return the testcase data specified in myget1 - def myget2(url, headers=None): + def myget2(url: str, headers: dict[str, str] | None = None) -> response2_t: assert url == "gopher://aol.com:70/crashmanager/rest/crashes/123/download/" assert headers == {"Authorization": "Token token"} return response2_t() @@ -385,7 +396,7 @@ class response1_t: # noqa status_code = requests.codes["ok"] text = "OK" - def json(self): + def json(self) -> dict[str, str]: return {"testcase": ""} collector._session.get = myget1 @@ -397,7 +408,7 @@ class response1_t: # noqa status_code = requests.codes["ok"] text = "OK" - def json(self): + def json(self) -> list[str]: return [] collector._session.get = myget1 diff --git a/CovReporter/CovReporter.py b/CovReporter/CovReporter.py index 258ffe787..b874684c0 100755 --- a/CovReporter/CovReporter.py +++ b/CovReporter/CovReporter.py @@ -15,15 +15,19 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import argparse import json import os import sys +from typing import Iterable, Mapping from FTB import CoverageHelper from Reporter.Reporter import Reporter, remote_checks -__all__ = [] +__all__: list[str] = [] __version__ = 0.1 __date__ = "2017-07-10" __updated__ = "2017-07-10" @@ -32,30 +36,24 @@ class CovReporter(Reporter): def __init__( self, - serverHost=None, - serverPort=None, - serverProtocol=None, - serverAuthToken=None, - clientId=None, - tool=None, - repository=None, - ): + serverHost: str | None = None, + serverPort: int | None = None, + serverProtocol: str | None = None, + serverAuthToken: str | None = None, + clientId: str | None = None, + tool: str | None = None, + repository: str | None = None, + ) -> None: """ Initialize the Reporter. This constructor will also attempt to read a configuration file to populate any missing properties that have not been passed to this constructor. - @type serverHost: string @param serverHost: Server host to contact for refreshing signatures - @type serverPort: int @param serverPort: Server port to use when contacting server - @type serverAuthToken: string @param serverAuthToken: Token for server authentication - @type clientId: string @param clientId: Client ID stored in the server when submitting - @type tool: string @param tool: Name of the tool that created this coverage - @type repository: string @param repository: Name of the repository that this coverage was measured on """ @@ -75,33 +73,27 @@ def __init__( @remote_checks def submit( - self, coverage, preprocessed=False, version=None, description="", stats=None - ): + self, + coverage: Mapping[str, object], + preprocessed: bool = False, + version: dict[str, str] | None = None, + description: str = "", + stats: dict[str, str] | None = None, + ) -> None: """ Send coverage data to server. 
- @type coverage: dict @param coverage: Coverage Data - - @type covformat: int - @param covformat: Format of the coverage data (COVERALLS or COVMAN). - - @type version: dict @param version: A dictionary containing keys 'revision' and 'branch', just as returned by version_info_from_coverage_data. If left empty, the implementation will attempt to extract the information from the coverage data itself. - - @type description: string @param description: Optional description for this coverage data - - @type stats: dict @param stats: An optional stats object as returned by create_combined_coverage """ - url = "{}://{}:{}/covmanager/rest/collections/".format( - self.serverProtocol, - self.serverHost, - self.serverPort, + url = ( + f"{self.serverProtocol}://{self.serverHost}:{self.serverPort}" + "/covmanager/rest/collections/" ) if version is None: @@ -139,17 +131,16 @@ def submit( self.post(url, data) @staticmethod - def preprocess_coverage_data(coverage): + def preprocess_coverage_data( + coverage: Mapping[str, object] + ) -> Mapping[str, object]: """ Preprocess the given coverage data. Preprocessing includes structuring the coverage data by directory for better performance as well as computing coverage summaries per directory. - @type coverage: dict @param coverage: Coverage Data - - @rtype dict @return Preprocessed Coverage Data """ @@ -159,6 +150,7 @@ def preprocess_coverage_data(coverage): # Coveralls format source_files = coverage["source_files"] + assert isinstance(source_files, Iterable) # Process every source file and store the coverage data in our tree # structure for source_file in source_files: @@ -197,7 +189,7 @@ def preprocess_coverage_data(coverage): return ret @staticmethod - def version_info_from_coverage_data(coverage): + def version_info_from_coverage_data(coverage) -> dict[str, str]: """ Extract various version fields from the given coverage data. @@ -210,11 +202,8 @@ def version_info_from_coverage_data(coverage): revision branch - @type coverage: string @param coverage: Coverage Data - @return Dictionary with version data - @rtype dict """ ret = {} @@ -227,19 +216,17 @@ def version_info_from_coverage_data(coverage): raise RuntimeError("Unknown coverage format") @staticmethod - def create_combined_coverage(coverage_files, version=None): + def create_combined_coverage( + coverage_files: list[int | str], version: dict[str, str] | None = None + ): """ Read coverage data from multiple files and return a single dictionary containing the merged data (already preprocessed). 
- @type coverage_files: list @param coverage_files: List of filenames containing coverage data - @type version: dict @param version: Dictionary containing branch and revision - @return Dictionary with combined coverage data, version information and debug statistics - @rtype tuple(dict,dict,dict) """ ret = None stats = None @@ -271,7 +258,7 @@ def create_combined_coverage(coverage_files, version=None): return (ret, version, stats) -def main(argv=None): +def main(argv: list[str] | None = None) -> int: """Command line options.""" # setup argparser diff --git a/CovReporter/tests/test_CovReporter.py b/CovReporter/tests/test_CovReporter.py index 4a4d9b68a..bb2e95c3b 100644 --- a/CovReporter/tests/test_CovReporter.py +++ b/CovReporter/tests/test_CovReporter.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import json import os import tempfile @@ -21,14 +24,14 @@ FIXTURE_PATH = Path(__file__).parent / "fixtures" -def test_CovReporterCoverallsVersionData(): +def test_CovReporterCoverallsVersionData() -> None: coveralls_data = json.loads((FIXTURE_PATH / "coveralls_data.json").read_text()) ret = CovReporter.version_info_from_coverage_data(coveralls_data) assert ret["revision"] == "1a0d9545b9805f50a70de703a3c04fc0d22e3839" assert ret["branch"] == "master" -def test_CovReporterPreprocessData(): +def test_CovReporterPreprocessData() -> None: coveralls_data = json.loads((FIXTURE_PATH / "coveralls_data.json").read_text()) result = CovReporter.preprocess_coverage_data(coveralls_data) @@ -41,6 +44,7 @@ def test_CovReporterPreprocessData(): coveragePercent = "coveragePercent" # Check that we have all the topdirs + assert isinstance(result, dict) assert "topdir1" in result[children], "topdir1 missing in result" assert "topdir2" in result[children], "topdir2 missing in result" @@ -128,7 +132,7 @@ def test_CovReporterPreprocessData(): ) -def test_CovReporterMergeData(): +def test_CovReporterMergeData() -> None: # result = CovReporter.preprocess_coverage_data(coverallsData) # result2 = CovReporter.preprocess_coverage_data(coverallsAddData) @@ -157,6 +161,7 @@ def test_CovReporterMergeData(): os.remove(cov_file1) os.remove(cov_file2) + assert isinstance(version, dict) assert version["revision"] == "1a0d9545b9805f50a70de703a3c04fc0d22e3839" assert version["branch"] == "master" @@ -169,6 +174,7 @@ def test_CovReporterMergeData(): coveragePercent = "coveragePercent" # Check that we have all the topdirs + assert isinstance(result, dict) assert "topdir1" in result[children], "topdir1 missing in result" assert "topdir2" in result[children], "topdir2 missing in result" assert "topdir3" in result[children], "topdir2 missing in result" diff --git a/EC2Reporter/EC2Reporter.py b/EC2Reporter/EC2Reporter.py index 8c1c93214..a1f206faa 100755 --- a/EC2Reporter/EC2Reporter.py +++ b/EC2Reporter/EC2Reporter.py @@ -15,12 +15,16 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import argparse import functools import os import random import sys import time +from typing import Any import requests from fasteners import InterProcessLock @@ -28,7 +32,7 @@ from FTB.ConfigurationFiles import ConfigurationFiles # noqa from Reporter.Reporter import Reporter, remote_checks -__all__ = [] +__all__: list[str] = [] __version__ = 0.1 __date__ = "2014-10-01" __updated__ = "2014-10-01" @@ -36,18 +40,17 @@ class EC2Reporter(Reporter): @functools.wraps(Reporter.__init__) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: 
kwargs.setdefault( "tool", "N/A" ) # tool is required by remote_checks, but unused by EC2Reporter super().__init__(*args, **kwargs) @remote_checks - def report(self, text): + def report(self, text: str) -> None: """ Send textual report to server, overwriting any existing reports. - @type text: string @param text: Report text to send """ url = "{}://{}:{}/ec2spotmanager/rest/report/".format( @@ -65,11 +68,10 @@ def report(self, text): self.post(url, data) @remote_checks - def cycle(self, poolid): + def cycle(self, poolid: int) -> None: """ Cycle the pool with the given id. - @type poolid: int @param poolid: ID of the pool to cycle """ url = "{}://{}:{}/ec2spotmanager/rest/pool/{}/cycle/".format( @@ -82,11 +84,10 @@ def cycle(self, poolid): self.post(url, {}, expected=requests.codes["ok"]) @remote_checks - def disable(self, poolid): + def disable(self, poolid: int) -> None: """ Disable the pool with the given id. - @type poolid: int @param poolid: ID of the pool to disable """ url = "{}://{}:{}/ec2spotmanager/rest/pool/{}/disable/".format( @@ -99,11 +100,10 @@ def disable(self, poolid): self.post(url, {}, expected=requests.codes["ok"]) @remote_checks - def enable(self, poolid): + def enable(self, poolid: int) -> None: """ Enable the pool with the given id. - @type poolid: int @param poolid: ID of the pool to enable """ url = "{}://{}:{}/ec2spotmanager/rest/pool/{}/enable/".format( @@ -116,7 +116,7 @@ def enable(self, poolid): self.post(url, {}, expected=requests.codes["ok"]) -def main(argv=None): +def main(argv: list[str] | None = None) -> int: """Command line options.""" # setup argparser diff --git a/EC2Reporter/tests/test_EC2Reporter.py b/EC2Reporter/tests/test_EC2Reporter.py index 8ac7895b3..19ab40111 100644 --- a/EC2Reporter/tests/test_EC2Reporter.py +++ b/EC2Reporter/tests/test_EC2Reporter.py @@ -1,8 +1,13 @@ +from __future__ import annotations + +from pathlib import Path from unittest.mock import Mock, patch from urllib.parse import urlsplit import pytest +from django.contrib.auth.models import User from django.utils import timezone +from pytest_django.live_server_helper import LiveServer from EC2Reporter.EC2Reporter import EC2Reporter, main from ec2spotmanager.models import Instance, InstancePool @@ -12,7 +17,7 @@ pytest_plugins = "server.tests" -def test_ec2reporter_help(capsys): +def test_ec2reporter_help(capsys: pytest.CaptureFixture[str]) -> None: """Test that help prints without throwing""" with pytest.raises(SystemExit): main() @@ -22,7 +27,9 @@ def test_ec2reporter_help(capsys): @patch("os.path.expanduser") @patch("time.sleep", new=Mock()) -def test_ec2reporter_report(mock_expanduser, live_server, tmp_path, fm_user): +def test_ec2reporter_report( + mock_expanduser: Mock, live_server: LiveServer, tmp_path: Path, fm_user: User +) -> None: """Test report submission""" mock_expanduser.side_effect = lambda path: str( tmp_path @@ -68,7 +75,9 @@ def test_ec2reporter_report(mock_expanduser, live_server, tmp_path, fm_user): @patch("os.path.expanduser") @patch("time.sleep", new=Mock()) -def test_ec2reporter_xable(mock_expanduser, live_server, tmp_path, fm_user): +def test_ec2reporter_xable( + mock_expanduser: Mock, live_server: LiveServer, tmp_path: Path, fm_user: User +) -> None: """Test EC2Reporter enable/disable""" mock_expanduser.side_effect = lambda path: str( tmp_path @@ -113,7 +122,9 @@ def test_ec2reporter_xable(mock_expanduser, live_server, tmp_path, fm_user): @patch("os.path.expanduser") @patch("time.sleep", new=Mock()) -def test_ec2reporter_cycle(mock_expanduser, live_server, 
tmp_path, fm_user): +def test_ec2reporter_cycle( + mock_expanduser: Mock, live_server: LiveServer, tmp_path: Path, fm_user: User +) -> None: """Test EC2Reporter cycle""" mock_expanduser.side_effect = lambda path: str( tmp_path diff --git a/FTB/AssertionHelper.py b/FTB/AssertionHelper.py index 2a9d89051..b4db7dd74 100644 --- a/FTB/AssertionHelper.py +++ b/FTB/AssertionHelper.py @@ -14,6 +14,8 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import re RE_ASSERTION = re.compile(r"^ASSERTION \d+: \(.+\)") @@ -25,17 +27,16 @@ RE_V8_END = re.compile(r"^") -def getAssertion(output): +def getAssertion(output: list[str]) -> list[str] | str | None: """ This helper method provides a way to extract and process the different types of assertions from a given buffer. The problem here is that pretty much every software has its own type of assertions with different output formats. - @type output: list @param output: List of strings to be searched """ - lastLine = None + lastLine: list[str] | str | None = None endRegex = None # Use this to ignore the ASan head line in case of an assertion @@ -129,17 +130,16 @@ def getAssertion(output): return lastLine -def getAuxiliaryAbortMessage(output): +def getAuxiliaryAbortMessage(output: list[str]) -> list[str] | str | None: """ This helper method provides a way to extract and process additional abort messages or other useful messages produced by helper tools like sanitizers. These messages can be helpful in signatures if there is no abort message from the program itself. - @type output: list @param output: List of strings to be searched """ - lastLine = None + lastLine: list[str] | str | None = None needASanRW = False needTSanRW = False @@ -188,16 +188,14 @@ def getAuxiliaryAbortMessage(output): return lastLine -def getSanitizedAssertionPattern(msgs): +def getSanitizedAssertionPattern(msgs: list[str] | str | None) -> list[str] | str: """ This method provides a way to strip out unwanted dynamic information from assertions and replace it with pattern matching elements, e.g. for use in signature matching. - @type msgs: string or list @param msgs: Assertion message(s) to be sanitized - @rtype: string @return: Sanitized assertion message (regular expression) """ assert msgs is not None @@ -211,7 +209,7 @@ def getSanitizedAssertionPattern(msgs): for msg in msgs: # remember the position of all backslashes in the input - bsPositions = [] + bsPositions: list[int] = [] for chunk in msg.split("\\"): if not bsPositions: bsPositions.append(len(chunk)) @@ -271,7 +269,7 @@ def getSanitizedAssertionPattern(msgs): for replacementPattern in replacementPatterns: - def _handleMatch(match): + def _handleMatch(match: re.Match[str]) -> str: start = match.start(0) end = match.end(0) lengthDiff = len(replacementPattern) - len(match.group(0)) @@ -318,15 +316,12 @@ def _handleMatch(match): return sanitizedMsgs -def escapePattern(msg): +def escapePattern(msg: str) -> str: """ This method escapes regular expression characters in the string. And no, this is not re.escape, which would escape many more characters. 
- @type msg: string @param msg: String that needs to be quoted - - @rtype: string @return: Escaped string for use in regular expressions """ diff --git a/FTB/ConfigurationFiles.py b/FTB/ConfigurationFiles.py index c652cedf0..55dc29f8f 100755 --- a/FTB/ConfigurationFiles.py +++ b/FTB/ConfigurationFiles.py @@ -12,13 +12,16 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import configparser import sys class ConfigurationFiles: - def __init__(self, configFiles): - self.mainConfig = {} + def __init__(self, configFiles: list[str]) -> None: + self.mainConfig: dict[str, str] = {} self.metadataConfig = {} if configFiles: @@ -45,7 +48,7 @@ def __init__(self, configFiles): file=sys.stderr, ) - def getSectionMap(self, section): + def getSectionMap(self, section: str): ret = {} try: options = self.parser.options(section) diff --git a/FTB/CoverageHelper.py b/FTB/CoverageHelper.py index 2bb34fb96..1e7f49cb0 100644 --- a/FTB/CoverageHelper.py +++ b/FTB/CoverageHelper.py @@ -11,10 +11,15 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import re +from server.covmanager.models import Collection + -def merge_coverage_data(r, s): +def merge_coverage_data(r: Collection, s) -> dict[str, int]: # These variables are mainly for debugging purposes. We count the number # of warnings we encounter during merging, which are mostly due to # bugs in GCOV. These statistics can be included in the report description @@ -25,7 +30,7 @@ def merge_coverage_data(r, s): "coverable_mismatch_count": 0, } - def merge_recursive(r, s): + def merge_recursive(r: Collection, s): assert r["name"] == s["name"] if "children" in s: @@ -109,7 +114,7 @@ def merge_recursive(r, s): return stats -def calculate_summary_fields(node, name=None): +def calculate_summary_fields(node: Collection, name: str | None = None) -> None: node["name"] = name node["linesTotal"] = 0 node["linesCovered"] = 0 @@ -144,7 +149,7 @@ def calculate_summary_fields(node, name=None): node["coveragePercent"] = 0.0 -def apply_include_exclude_directives(node, directives): +def apply_include_exclude_directives(node: Collection, directives: list[str]) -> None: """ Applies the given include and exclude directives to the given nodeself. Directives either start with a + or a - for include or exclude, followed @@ -153,9 +158,7 @@ def apply_include_exclude_directives(node, directives): are forward slashes, must not have a trailing slash and glob characters are not allowed. ** is additionally supported for recursive directory matching. @param node: The coverage node to modify, in server-side recursive format - @type node: dict @param directives: The directives to apply - @type directives: list(str) This method modifies the node in-place, nothing is returned. IMPORTANT: This method does *not* recalculate any total/summary fields. 
You *must* call L{calculate_summary_fields} after applying @@ -202,13 +205,13 @@ def apply_include_exclude_directives(node, directives): # convert glob pattern to regex part = part.replace("\\*", ".*").replace("\\?", ".") # compile the resulting regex - parts.append(re.compile(part)) + parts.append(str(re.compile(part))) directives_new.append((what, parts)) - def _is_dir(node): + def _is_dir(node: Collection) -> bool: return "children" in node - def __apply_include_exclude_directives(node, directives): + def __apply_include_exclude_directives(node: Collection, directives) -> None: if not _is_dir(node): return @@ -332,7 +335,7 @@ def __apply_include_exclude_directives(node, directives): __apply_include_exclude_directives(node, directives_new) -def get_flattened_names(node, prefix=""): +def get_flattened_names(node: Collection, prefix: str = "") -> set[str | None]: """ Returns a list of flattened paths (files and directories) of the given node. @@ -340,16 +343,13 @@ def get_flattened_names(node, prefix=""): All slashes in paths will be forward slashes and not use any trailing slashes. @param node: The coverage node to process, in server-side recursive format - @type node: dict - @param prefix: An optional prefix to prepend to each name - @type prefix: str - @return The list of all paths occurring in the given node. - @rtype: list(str) """ - def __get_flattened_names(node, prefix, result): + def __get_flattened_names( + node: Collection, prefix: str, result: set[str | None] + ) -> set[str | None]: current_name = node["name"] if current_name is None: new_prefix = "" diff --git a/FTB/ProgramConfiguration.py b/FTB/ProgramConfiguration.py index aacb0b1ed..c857357c3 100644 --- a/FTB/ProgramConfiguration.py +++ b/FTB/ProgramConfiguration.py @@ -14,6 +14,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import os import sys @@ -22,14 +25,18 @@ class ProgramConfiguration: def __init__( - self, product, platform, os, version=None, env=None, args=None, metadata=None + self, + product: str, + platform: str, + os: str, + version: str | None = None, + env: dict[str, str] | None = None, + args: list[str] | None = None, + metadata: dict[str, object] | None = None, ): """ - @type product: string @param product: The name of the product/program/branch tested - @type platform: string @param platform: Platform on which is tested (e.g. x86, x86-64 or arm) - @type os: string @param os: Operating system on which is tested (e.g. linux, windows, macosx) """ self.product = product.lower() @@ -55,7 +62,7 @@ def __init__( self.metadata = metadata @staticmethod - def fromBinary(binaryPath): + def fromBinary(binaryPath: str) -> ProgramConfiguration | None: binaryConfig = f"{binaryPath}.fuzzmanagerconf" if not os.path.exists(binaryConfig): print( @@ -86,33 +93,30 @@ def fromBinary(binaryPath): metadata=config.metadataConfig, ) - def addEnvironmentVariables(self, env): + def addEnvironmentVariables(self, env: dict[str, str]) -> None: """ Add (additional) environment variable definitions. Existing definitions will be overwritten if they are redefined in the given environment. - @type env: dict @param env: Dictionary containing the environment variables """ assert isinstance(env, dict) self.env.update(env) - def addProgramArguments(self, args): + def addProgramArguments(self, args: list[str]) -> None: """ Add (additional) program arguments. 
- @type args: list @param args: List containing the program arguments """ assert isinstance(args, list) self.args.extend(args) - def addMetadata(self, metadata): + def addMetadata(self, metadata: dict[str, object]) -> None: """ Add (additional) metadata definitions. Existing definitions will be overwritten if they are redefined in the given metadata. - @type metadata: dict @param metadata: Dictionary containing the metadata """ assert isinstance(metadata, dict) diff --git a/FTB/Running/AutoRunner.py b/FTB/Running/AutoRunner.py index 752be79b4..0af5093e3 100644 --- a/FTB/Running/AutoRunner.py +++ b/FTB/Running/AutoRunner.py @@ -12,6 +12,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import os import re import signal @@ -20,6 +23,7 @@ from abc import ABCMeta from distutils import spawn +from FTB.ProgramConfiguration import ProgramConfiguration from FTB.Signatures.CrashInfo import CrashInfo @@ -29,7 +33,14 @@ class AutoRunner(metaclass=ABCMeta): for running the given program and obtaining crash information. """ - def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): + def __init__( + self, + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + stdin: bytes | None = None, + ) -> None: self.binary = binary self.cwd = cwd self.stdin = stdin @@ -56,20 +67,26 @@ def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): assert isinstance(self.args, list) # The command that we will run for obtaining crash information - self.cmdArgs = [] + self.cmdArgs: list[str | bytes] = [] # These will hold our results from running - self.stdout = None - self.stderr = None - self.auxCrashData = None + self.stdout: str | None = None + self.stderr: list[str] | str | None = None + self.auxCrashData: list[str] | str | None = None - def getCrashInfo(self, configuration): + def getCrashInfo(self, configuration: ProgramConfiguration) -> CrashInfo: return CrashInfo.fromRawCrashData( self.stdout, self.stderr, configuration, self.auxCrashData ) @staticmethod - def fromBinaryArgs(binary, args=None, env=None, cwd=None, stdin=None): + def fromBinaryArgs( + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + stdin: bytes | None = None, + ) -> ASanRunner | GDBRunner: process = subprocess.Popen( ["nm", "-g", binary], stdin=subprocess.PIPE, @@ -80,13 +97,13 @@ def fromBinaryArgs(binary, args=None, env=None, cwd=None, stdin=None): ) (stdout, _) = process.communicate() - stdout = stdout.decode("utf-8", errors="ignore") + stdout_decoded = stdout.decode("utf-8", errors="ignore") force_gdb = bool(os.environ.get("FTB_FORCE_GDB", False)) if not force_gdb and ( - stdout.find(" __asan_init") >= 0 - or stdout.find("__ubsan_default_options") >= 0 + stdout_decoded.find(" __asan_init") >= 0 + or stdout_decoded.find("__ubsan_default_options") >= 0 ): return ASanRunner(binary, args, env, cwd, stdin) @@ -94,7 +111,15 @@ def fromBinaryArgs(binary, args=None, env=None, cwd=None, stdin=None): class GDBRunner(AutoRunner): - def __init__(self, binary, args=None, env=None, cwd=None, core=None, stdin=None): + def __init__( + self, + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + core: bytes | None = None, + stdin: bytes | None = None, + ) -> None: AutoRunner.__init__(self, binary, args, env, cwd, stdin) # This can be used to force GDBRunner to first generate a core and then @@ -141,11 +166,13 @@ def __init__(self, 
binary, args=None, env=None, cwd=None, core=None, stdin=None) if core is not None: self.cmdArgs.append(core) else: + assert self.args is not None self.cmdArgs.extend(self.args) - def run(self): + def run(self) -> bool: if self.force_core: plainCmdArgs = [self.binary] + assert self.args is not None plainCmdArgs.extend(self.args) process = subprocess.Popen( @@ -214,10 +241,18 @@ def run(self): class ASanRunner(AutoRunner): - def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): + def __init__( + self, + binary: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, + stdin: bytes | None = None, + ) -> None: AutoRunner.__init__(self, binary, args, env, cwd, stdin) self.cmdArgs.append(self.binary) + assert self.args is not None self.cmdArgs.extend(self.args) if "ASAN_SYMBOLIZER_PATH" not in self.env: @@ -228,9 +263,11 @@ def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): os.path.dirname(binary), "llvm-symbolizer" ) if not os.path.isfile(self.env["ASAN_SYMBOLIZER_PATH"]): - self.env["ASAN_SYMBOLIZER_PATH"] = spawn.find_executable( + spawn_find_llvm_symbolizer = spawn.find_executable( "llvm-symbolizer" ) + assert spawn_find_llvm_symbolizer is not None + self.env["ASAN_SYMBOLIZER_PATH"] = spawn_find_llvm_symbolizer if not self.env["ASAN_SYMBOLIZER_PATH"]: raise RuntimeError("Unable to locate llvm-symbolizer") @@ -262,7 +299,7 @@ def __init__(self, binary, args=None, env=None, cwd=None, stdin=None): # for bucketing. This is helpful when assertions are hit in debug builds self.env["ASAN_OPTIONS"] = "allocator_may_return_null=1:handle_abort=1" - def run(self): + def run(self) -> bool: process = subprocess.Popen( self.cmdArgs, stdin=subprocess.PIPE, @@ -275,14 +312,14 @@ def run(self): (stdout, stderr) = process.communicate(input=self.stdin) self.stdout = stdout.decode("utf-8", errors="ignore") - stderr = stderr.decode("utf-8", errors="ignore") + stderr_decoded = stderr.decode("utf-8", errors="ignore") inASanTrace = False inUBSanTrace = False inTSanTrace = False self.auxCrashData = [] self.stderr = [] - for line in stderr.splitlines(): + for line in stderr_decoded.splitlines(): if inASanTrace or inUBSanTrace or inTSanTrace: self.auxCrashData.append(line) if (inASanTrace or inUBSanTrace) and line.find("==ABORTING") >= 0: diff --git a/FTB/Running/GDB.py b/FTB/Running/GDB.py index 8b82c78b3..e7993bdcc 100644 --- a/FTB/Running/GDB.py +++ b/FTB/Running/GDB.py @@ -12,37 +12,39 @@ @contact: choller@mozilla.com """ +from __future__ import annotations -def is64bit(): - return not str(gdb.parse_and_eval("$rax")) == "void" # noqa @UndefinedVariable +def is64bit() -> bool: + return not str(gdb.parse_and_eval("$rax")) == "void" # type: ignore[name-defined] # noqa @UndefinedVariable -def isARM(): - return not str(gdb.parse_and_eval("$r0")) == "void" # noqa @UndefinedVariable +def isARM() -> bool: + return not str(gdb.parse_and_eval("$r0")) == "void" # type: ignore[name-defined] # noqa @UndefinedVariable -def isARM64(): - return not str(gdb.parse_and_eval("$x0")) == "void" # noqa @UndefinedVariable +def isARM64() -> bool: + return not str(gdb.parse_and_eval("$x0")) == "void" # type: ignore[name-defined] # noqa @UndefinedVariable -def regAsHexStr(reg): + +def regAsHexStr(reg: str) -> str: if is64bit(): mask = 0xFFFFFFFFFFFFFFFF else: mask = 0xFFFFFFFF - val = int(str(gdb.parse_and_eval("$" + reg)), 0) & mask # noqa @UndefinedVariable + val = int(str(gdb.parse_and_eval("$" + reg)), 0) & mask # type: ignore[name-defined] # noqa 
@UndefinedVariable return f"0x{val:x}" -def regAsIntStr(reg): - return str(int(str(gdb.parse_and_eval("$" + reg)), 0)) # noqa @UndefinedVariable +def regAsIntStr(reg: str) -> str: + return str(int(str(gdb.parse_and_eval("$" + reg)), 0)) # type: ignore[name-defined] # noqa @UndefinedVariable -def regAsRaw(reg): - return str(gdb.parse_and_eval("$" + reg)) # noqa @UndefinedVariable +def regAsRaw(reg: str) -> str: + return str(gdb.parse_and_eval("$" + reg)) # type: ignore[name-defined] # noqa @UndefinedVariable -def printImportantRegisters(): +def printImportantRegisters() -> None: if is64bit(): regs = ( "rax rbx rcx rdx rsi rdi rbp rsp r8 r9 r10 r11 r12 r13 r14 r15 rip".split( diff --git a/FTB/Running/PersistentApplication.py b/FTB/Running/PersistentApplication.py index 954558c05..06c6868e5 100644 --- a/FTB/Running/PersistentApplication.py +++ b/FTB/Running/PersistentApplication.py @@ -17,6 +17,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import os import queue import signal @@ -75,7 +78,7 @@ def __init__( persistentMode=PersistentMode.NONE, processingTimeout=10, inputFile=None, - ): + ) -> None: self.binary = binary self.cwd = cwd @@ -114,19 +117,20 @@ def __init__( self.spfpPrefix = "" self.spfpSuffix = "" # To support - def start(self, test=None): + def start(self, test: str | None = None) -> int: pass - def stop(self): + def stop(self) -> None: pass - def runTest(self, test): + def runTest(self, test: str) -> int: pass - def status(self): + def status(self) -> None: pass - def _crashed(self): + def _crashed(self) -> bool: + assert self.process is not None if self.process.returncode < 0: crashSignals = [ # POSIX.1-1990 signals @@ -157,7 +161,7 @@ def __init__( persistentMode=PersistentMode.NONE, processingTimeout=10, inputFile=None, - ): + ) -> None: PersistentApplication.__init__( self, binary, args, env, cwd, persistentMode, processingTimeout, inputFile ) @@ -169,7 +173,8 @@ def __init__( self.outCollector = None self.errCollector = None - def _write_log_test(self, test): + def _write_log_test(self, test: str) -> None: + assert self.testLog is not None self.testLog.append(test) if self.inputFile: @@ -193,7 +198,8 @@ def _write_log_test(self, test): self.process.stdin.write(test) self.process.stdin.close() - def _wait_child_stopped(self): + def _wait_child_stopped(self) -> bool: + assert self.process is not None monitor = WaitpidMonitor(self.process.pid, os.WUNTRACED) monitor.start() monitor.join(self.processingTimeout) @@ -208,10 +214,11 @@ def _wait_child_stopped(self): return True - def start(self, test=None): + def start(self, test: str | None = None) -> int: assert self.process is None or self.process.poll() is not None # Reset the test log + assert self.testLog is not None self.testLog = [] if self.persistentMode == PersistentMode.NONE: @@ -315,7 +322,7 @@ def start(self, test=None): return ret - def stop(self): + def stop(self) -> None: self._terminateProcess() # Ensure we leave no dangling threads when stopping @@ -328,7 +335,7 @@ def stop(self): self.stdout = self.outCollector.output self.stderr = self.errCollector.output - def runTest(self, test): + def runTest(self, test: str) -> int: if self.process is None or self.process.poll() is not None: self.start() @@ -341,6 +348,7 @@ def runTest(self, test): block=True, timeout=self.processingTimeout ) except queue.Empty: + assert self.process is not None if self.process.poll() is None: # The process is still running, force it to stop and return timeout # code @@ -374,6 +382,8 @@ def runTest(self, 
test): ) # Update stdout/err available for the last run + assert self.errCollector is not None + assert self.outCollector is not None self.stdout = self.outCollector.output self.stderr = self.errCollector.output @@ -387,6 +397,7 @@ def runTest(self, test): ) elif self.persistentMode == PersistentMode.SIGSTOP: # Resume the process + assert self.process is not None os.kill(self.process.pid, signal.SIGCONT) # Wait for process to stop itself again @@ -415,7 +426,7 @@ def runTest(self, test): return ApplicationStatus.OK - def _terminateProcess(self): + def _terminateProcess(self) -> None: if self.process: if self.process.poll() is None: # Try to terminate the process gracefully first diff --git a/FTB/Running/StreamCollector.py b/FTB/Running/StreamCollector.py index e050c4295..5a139b2e5 100644 --- a/FTB/Running/StreamCollector.py +++ b/FTB/Running/StreamCollector.py @@ -11,12 +11,22 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import queue import threading +from typing import TextIO class StreamCollector(threading.Thread): - def __init__(self, fd, responseQueue, logResponses=False, maxBacklog=None): + def __init__( + self, + fd: TextIO, + responseQueue: queue.Queue, + logResponses: bool = False, + maxBacklog: int | None = None, + ) -> None: assert callable(fd.readline) assert isinstance(responseQueue, queue.Queue) @@ -24,12 +34,12 @@ def __init__(self, fd, responseQueue, logResponses=False, maxBacklog=None): self.fd = fd self.queue = responseQueue - self.output = [] - self.responsePrefixes = [] + self.output: list[str] = [] + self.responsePrefixes: list[str] = [] self.logResponses = logResponses self.maxBacklog = maxBacklog - def run(self): + def run(self) -> None: while True: line = self.fd.readline(4096) @@ -53,5 +63,5 @@ def run(self): self.fd.close() - def addResponsePrefix(self, prefix): + def addResponsePrefix(self, prefix: str) -> None: self.responsePrefixes.append(prefix) diff --git a/FTB/Running/WaitpidMonitor.py b/FTB/Running/WaitpidMonitor.py index 846bcae74..2a267cab1 100644 --- a/FTB/Running/WaitpidMonitor.py +++ b/FTB/Running/WaitpidMonitor.py @@ -12,20 +12,23 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import os import threading class WaitpidMonitor(threading.Thread): - def __init__(self, pid, options): + def __init__(self, pid: int, options: int) -> None: threading.Thread.__init__(self) - self.pid = pid - self.options = options + self.pid: int = pid + self.options: int = options - self.childPid = None - self.childExit = None + self.childPid: int | None = None + self.childExit: int | None = None - def run(self): + def run(self) -> None: while not self.childPid: (self.childPid, self.childExit) = os.waitpid(self.pid, self.options) diff --git a/FTB/Running/tests/test_persistent.py b/FTB/Running/tests/test_persistent.py index a08ba1c1a..c105f0628 100644 --- a/FTB/Running/tests/test_persistent.py +++ b/FTB/Running/tests/test_persistent.py @@ -12,9 +12,12 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import os import sys import time +from pathlib import Path import pytest @@ -27,13 +30,14 @@ TEST_PATH = os.path.dirname(__file__) -def test_PersistentApplicationTestModeNone(tmp_path): - def _check(spa): +def test_PersistentApplicationTestModeNone(tmp_path: Path) -> None: + def _check(spa: SimplePersistentApplication) -> None: try: ret = spa.start("aa") assert ret == ApplicationStatus.OK + assert spa.stdout is not None assert spa.stdout[0] == "Stdout test1" assert spa.stdout[1] == "Stdout test2" 
@@ -75,8 +79,8 @@ def _check(spa): @pytest.mark.xfail -def test_PersistentApplicationTestOtherModes(tmp_path): - def _check(spa): +def test_PersistentApplicationTestOtherModes(tmp_path: Path) -> None: + def _check(spa: SimplePersistentApplication) -> None: try: ret = spa.start() @@ -142,11 +146,12 @@ def _check(spa): @pytest.mark.xfail -def test_PersistentApplicationTestPerf(tmp_path): - def _check(spa): +def test_PersistentApplicationTestPerf(tmp_path: Path) -> None: + def _check(spa: SimplePersistentApplication) -> None: try: spa.start() + assert spa.process is not None oldPid = spa.process.pid startTime = time.time() @@ -187,7 +192,7 @@ def _check(spa): ) -def test_PersistentApplicationTestFaultySigstop(tmp_path): +def test_PersistentApplicationTestFaultySigstop(tmp_path: Path) -> None: inputFile = tmp_path / "input.tmp" inputFile.touch() spa = SimplePersistentApplication( @@ -201,7 +206,7 @@ def test_PersistentApplicationTestFaultySigstop(tmp_path): spa.start() -def test_PersistentApplicationTestStopWithoutStart(tmp_path): +def test_PersistentApplicationTestStopWithoutStart(tmp_path: Path) -> None: inputFile = tmp_path / "input.tmp" inputFile.touch() spa = SimplePersistentApplication( diff --git a/FTB/Running/tests/test_shell.py b/FTB/Running/tests/test_shell.py index 36e5ace89..96b32b29d 100644 --- a/FTB/Running/tests/test_shell.py +++ b/FTB/Running/tests/test_shell.py @@ -11,29 +11,33 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import os import signal import sys import time +from typing import TextIO -def crash(): +def crash() -> None: import ctypes # Causes a NULL deref ctypes.string_at(0) -def hang(): +def hang() -> None: while True: time.sleep(1) -def stop(): +def stop() -> None: os.kill(os.getpid(), signal.SIGSTOP) -def processInput(mode, inputFd): +def processInput(mode: str, inputFd: TextIO) -> None: received_aa = False if mode == "none": @@ -90,7 +94,7 @@ def processInput(mode, inputFd): sys.exit(0) -def main(): +def main() -> int: if len(sys.argv) < 2: print("Need at least one argument (mode)", file=sys.stderr) sys.exit(1) diff --git a/FTB/Signatures/CrashInfo.py b/FTB/Signatures/CrashInfo.py index 974cae490..d2a60e214 100644 --- a/FTB/Signatures/CrashInfo.py +++ b/FTB/Signatures/CrashInfo.py @@ -14,11 +14,17 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import json import os import re import sys from abc import ABCMeta +from collections.abc import Mapping + +from typing_extensions import TypedDict from FTB import AssertionHelper from FTB.ProgramConfiguration import ProgramConfiguration @@ -26,12 +32,12 @@ from FTB.Signatures.CrashSignature import CrashSignature -def _is_unfinished(symbol, operators): +def _is_unfinished(symbol: list[str], operators: str) -> bool: start, end = operators return bool(symbol.count(start) > symbol.count(end)) -def uint32(val): +def uint32(val: int) -> int: """Force `val` into unsigned 32-bit range. Note that the input is returned as an int, therefore @@ -51,7 +57,7 @@ def uint32(val): return val & 0xFFFFFFFF -def int32(val): +def int32(val: int) -> int: """Force `val` into signed 32-bit range. Note that the input is returned as an int, therefore @@ -74,7 +80,7 @@ def int32(val): return val -def uint64(val): +def uint64(val: int) -> int: """Force `val` into unsigned 64-bit range. Note that the input is returned as an int, therefore @@ -94,7 +100,7 @@ def uint64(val): return val & 0xFFFFFFFFFFFFFFFF -def int64(val): +def int64(val: int) -> int: """Force `val` into signed 64-bit range. 
Note that the input is returned as an int, therefore @@ -117,36 +123,46 @@ def int64(val): return val +class CacheObject(TypedDict): + """CacheObject type specification.""" + + backtrace: list[str] + registers: dict[str, int] + crashAddress: int | None + crashInstruction: str | None + failureReason: str | None + + class CrashInfo(metaclass=ABCMeta): """ Abstract base class that provides a method to instantiate the right sub class. It also supports generating a CrashSignature based on the stored information. """ - def __init__(self): + def __init__(self) -> None: # Store the raw data - self.rawStdout = [] - self.rawStderr = [] - self.rawCrashData = [] + self.rawStdout: list[str] = [] + self.rawStderr: list[str] = [] + self.rawCrashData: list[str] = [] # Store processed data - self.backtrace = [] - self.registers = {} - self.crashAddress = None - self.crashInstruction = None + self.backtrace: list[str] = [] + self.registers: dict[str, int] = {} + self.crashAddress: int | None = None + self.crashInstruction: str | None = None # Store configuration data (platform, product, os, etc.) - self.configuration = None + self.configuration: ProgramConfiguration | None = None # This is an optional testcase that is not stored with the crashInfo but # can be "attached" before matching signatures that might require the # testcase. - self.testcase = None + self.testcase: bytes | str | None = None # This can be used to record failures during signature creation - self.failureReason = None + self.failureReason: str | None = None - def __str__(self): + def __str__(self) -> str: buf = [] buf.append("Crash trace:") buf.append("") @@ -168,52 +184,45 @@ def __str__(self): return "\n".join(buf) - def toCacheObject(self): + def toCacheObject(self) -> CacheObject: """ Create a cache object for restoring the class instance later on without parsing the crash data again. This object includes all class fields except for the storage heavy raw objects like stdout, stderr and raw crashdata. - @rtype: dict @return: Dictionary containing expensive class fields """ - cacheObject = {} - cacheObject["backtrace"] = self.backtrace - cacheObject["registers"] = self.registers - - if self.crashAddress is not None: - cacheObject["crashAddress"] = int(self.crashAddress) - else: - cacheObject["crashAddress"] = None - - cacheObject["crashInstruction"] = self.crashInstruction - cacheObject["failureReason"] = self.failureReason + cacheObject: CacheObject = { + "backtrace": self.backtrace, + "registers": self.registers, + "crashAddress": self.crashAddress, + "crashInstruction": self.crashInstruction, + "failureReason": self.failureReason, + } return cacheObject @staticmethod def fromRawCrashData( - stdout, stderr, configuration, auxCrashData=None, cacheObject=None - ): + stdout: list[str] | str | None, + stderr: list[str] | str | None, + configuration: ProgramConfiguration, + auxCrashData: list[str] | str | None = None, + cacheObject: CacheObject | None = None, + ) -> CrashInfo: """ Create appropriate CrashInfo instance from raw crash data - @type stdout: List of strings @param stdout: List of lines as they appeared on stdout - @type stderr: List of strings @param stderr: List of lines as they appeared on stderr - @type configuration: ProgramConfiguration @param configuration: Exact program configuration that is associated with the crash - @type auxCrashData: List of strings @param auxCrashData: Optional additional crash output (e.g. GDB). If not specified, stderr is used. 
- @type cacheObject: Dictionary @param cacheObject: The cache object that should be used to restore the class fields instead of parsing the crash data. The appropriate object can be created by calling the toCacheObject method. - @rtype: CrashInfo @return: Crash information object """ @@ -287,7 +296,7 @@ def fromRawCrashData( if stderr is not None: lines.extend(stderr) - result = None + result: CrashInfo | None = None for line in lines: if ubsanString in line and re.match(ubsanRegex, line) is not None: result = UBSanCrashInfo(stdout, stderr, configuration, auxCrashData) @@ -351,9 +360,8 @@ def fromRawCrashData( return result - def createShortSignature(self): + def createShortSignature(self) -> str: """ - @rtype: String @return: A string representing this crash (short signature) """ # See if we have an abort message and if so, use that as short signature @@ -376,27 +384,22 @@ def createShortSignature(self): def createCrashSignature( self, - forceCrashAddress=False, - forceCrashInstruction=False, - maxFrames=8, - minimumSupportedVersion=13, - ): + forceCrashAddress: bool = False, + forceCrashInstruction: bool = False, + maxFrames: int = 8, + minimumSupportedVersion: int = 13, + ) -> CrashSignature | None: """ @param forceCrashAddress: If True, the crash address will be included in any case - @type forceCrashAddress: bool @param forceCrashInstruction: If True, the crash instruction will be included in any case - @type forceCrashInstruction: bool @param maxFrames: How many frames (at most) should be included in the signature - @type maxFrames: int @param minimumSupportedVersion: The minimum crash signature standard version that the generated signature should be valid for (10 => 1.0, 13 => 1.3) - @type minimumSupportedVersion: int - @rtype: CrashSignature @return: A crash signature object """ # Determine the actual number of frames based on how many we got @@ -405,7 +408,7 @@ def createCrashSignature( else: numFrames = len(self.backtrace) - symptomArr = [] + symptomArr: list[Mapping[str, object]] = [] # Memorize where we find our abort messages abortMsgInCrashdata = False @@ -437,8 +440,8 @@ def createCrashSignature( if not isinstance(abortMsgs, list): abortMsgs = [abortMsgs] - for abortMsg in abortMsgs: - abortMsg = AssertionHelper.getSanitizedAssertionPattern(abortMsg) + for msg in abortMsgs: + abortMsg = AssertionHelper.getSanitizedAssertionPattern(msg) abortMsgSrc = "stderr" if abortMsgInCrashdata: abortMsgSrc = "crashdata" @@ -447,6 +450,7 @@ def createCrashSignature( # Versions below 1.2 only support the full object PCRE style, # for anything newer, use the short form with forward slashes # to increase the readability of the signatures. + symptomObj: dict[str, object] if minimumSupportedVersion < 12: stringObj = {"value": abortMsg, "matchType": "pcre"} symptomObj = { @@ -571,15 +575,13 @@ def createCrashSignature( return CrashSignature(json.dumps(sigObj, indent=2, sort_keys=True)) @staticmethod - def sanitizeStackFrame(frame): + def sanitizeStackFrame(frame: str) -> str: """ This function removes function arguments and other non-generic parts of the function frame, returning a (hopefully) generic function name. 
@param frame: The stack frame to sanitize - @type forceCrashAddress: str - @rtype: str @return: Sanitized stack frame """ @@ -615,7 +617,13 @@ def sanitizeStackFrame(frame): class NoCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -635,7 +643,13 @@ def __init__(self, stdout, stderr, configuration, crashData=None): class ASanCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -681,75 +695,77 @@ def __init__(self, stdout, stderr, configuration, crashData=None): expectedIndex = 0 reportFound = False - for traceLine in asanOutput: - if not reportFound or self.crashAddress is None: - match = re.search(asanCrashAddressPattern, traceLine) - if match is not None: - reportFound = True - try: - self.crashAddress = int(match.group(1), 16) - except TypeError: - pass # No crash address available - - # Crash Address and Registers are in the same line for ASan - match = re.search(asanRegisterPattern, traceLine) - # Collect register values if they are available in the ASan trace + if asanOutput is not None: + for traceLine in asanOutput: + if not reportFound or self.crashAddress is None: + match = re.search(asanCrashAddressPattern, traceLine) if match is not None: - self.registers["pc"] = int(match.group(1), 16) - self.registers[match.group(2)] = int(match.group(3), 16) - self.registers[match.group(4)] = int(match.group(5), 16) - elif not reportFound: - # Not in the ASan output yet. - # Some lines in eg. debug+asan builds might error if we continue. + reportFound = True + try: + self.crashAddress = int(match.group(1), 16) + except TypeError: + pass # No crash address available + + # Crash Address and Registers are in the same line for ASan + match = re.search(asanRegisterPattern, traceLine) + # Collect register values if they are available in the ASan + # trace + if match is not None: + self.registers["pc"] = int(match.group(1), 16) + self.registers[match.group(2)] = int(match.group(3), 16) + self.registers[match.group(4)] = int(match.group(5), 16) + elif not reportFound: + # Not in the ASan output yet. + # Some lines in eg. 
debug+asan builds might error if we continue + continue + + index, parts = self.split_frame(traceLine) + if index is None: continue - index, parts = self.split_frame(traceLine) - if index is None: - continue - - # We may see multiple traces in ASAN - if index == 0: - expectedIndex = 0 + # We may see multiple traces in ASAN + if index == 0: + expectedIndex = 0 - if not expectedIndex == index: - raise RuntimeError( - f"Fatal error parsing ASan trace (Index mismatch, got index {index}" - f" but expected {expectedIndex})" - ) + if not expectedIndex == index: + raise RuntimeError( + "Fatal error parsing ASan trace (Index mismatch, got index " + f"{index} but expected {expectedIndex})" + ) - component = None - # TSan doesn't include address, symbol will be immediately following the - # frame number - if len(parts) > 1 and not parts[1].startswith("0x"): - if parts[1] == "": - # the last part is either `(lib.so+0xoffset)` or `(0xaddress)` - if "+" in parts[-1]: - # Remove parentheses around component - component = parts[-1][1:-1] + component = None + # TSan doesn't include address, symbol will be immediately following the + # frame number + if len(parts) > 1 and not parts[1].startswith("0x"): + if parts[1] == "": + # the last part is either `(lib.so+0xoffset)` or `(0xaddress)` + if "+" in parts[-1]: + # Remove parentheses around component + component = parts[-1][1:-1] + else: + component = "" else: + component = " ".join(parts[1:-2]) + elif len(parts) > 2: + if parts[2] == "in": + component = " ".join(parts[3:-1]) + elif parts[2:] == ["()"]: component = "" + else: + # Remove parentheses around component + component = parts[2][1:-1] else: - component = " ".join(parts[1:-2]) - elif len(parts) > 2: - if parts[2] == "in": - component = " ".join(parts[3:-1]) - elif parts[2] == "()": + print( + f"Warning: Missing component in this line: {traceLine}", + file=sys.stderr, + ) component = "" - else: - # Remove parentheses around component - component = parts[2][1:-1] - else: - print( - f"Warning: Missing component in this line: {traceLine}", - file=sys.stderr, - ) - component = "" - self.backtrace.append(CrashInfo.sanitizeStackFrame(component)) - expectedIndex += 1 + self.backtrace.append(CrashInfo.sanitizeStackFrame(component)) + expectedIndex += 1 @staticmethod - def split_frame(line): + def split_frame(line: str) -> tuple[int, list[str]]: parts = line.strip().split() # We only want stack frames @@ -786,9 +802,8 @@ def split_frame(line): return frame_no, parts - def createShortSignature(self): + def createShortSignature(self) -> str: """ - @rtype: String @return: A string representing this crash (short signature) """ # Always prefer using regular program aborts @@ -861,7 +876,13 @@ def createShortSignature(self): class LSanCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
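A minimal sketch (not from the FuzzManager patch itself) of the Optional-narrowing idiom that the ASan hunk above and the LSan/UBSan hunks below all rely on: the crash-data input is now typed `list[str] | None`, so the parsing loop is wrapped in an `if ... is not None:` guard instead of iterating the possibly-None value directly. Names below are illustrative only.

    from __future__ import annotations


    def collect_frames(crash_data: list[str] | None) -> list[str]:
        """Collect lines that look like numbered stack frames (illustrative only)."""
        frames: list[str] = []
        # Iterating crash_data directly would not type-check, because the value
        # may be None; the guard narrows it to list[str] inside the block.
        if crash_data is not None:
            for line in crash_data:
                if line.lstrip().startswith("#"):
                    frames.append(line.strip())
        return frames


    print(collect_frames(None))           # -> []
    print(collect_frames(["  #0 main"]))  # -> ['#0 main']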
@@ -885,38 +906,39 @@ def __init__(self, stdout, stderr, configuration, crashData=None): lsanPatternSeen = False expectedIndex = 0 - for traceLine in lsanOutput: - if not lsanErrorPattern: - if lsanErrorPattern in traceLine: - lsanPatternSeen = True - continue + if lsanOutput is not None: + for traceLine in lsanOutput: + if not lsanErrorPattern: + if lsanErrorPattern in traceLine: + lsanPatternSeen = True + continue - index, parts = ASanCrashInfo.split_frame(traceLine) - if index is None: - continue + index, parts = ASanCrashInfo.split_frame(traceLine) + if index is None: + continue - if expectedIndex != index: - raise RuntimeError( - f"Fatal error parsing LSan trace (Index mismatch, got index {index}" - f" but expected {expectedIndex})" - ) + if expectedIndex != index: + raise RuntimeError( + f"Fatal error parsing LSan trace (Index mismatch, got index {index}" + f" but expected {expectedIndex})" + ) - component = None - if len(parts) > 2: - if parts[2] == "in": - component = " ".join(parts[3:-1]) + component = None + if len(parts) > 2: + if parts[2] == "in": + component = " ".join(parts[3:-1]) + else: + # Remove parentheses around component + component = parts[2][1:-1] else: - # Remove parentheses around component - component = parts[2][1:-1] - else: - print( - f"Warning: Missing component in this line: {traceLine}", - file=sys.stderr, - ) - component = "" + print( + f"Warning: Missing component in this line: {traceLine}", + file=sys.stderr, + ) + component = "" - self.backtrace.append(CrashInfo.sanitizeStackFrame(component)) - expectedIndex += 1 + self.backtrace.append(CrashInfo.sanitizeStackFrame(component)) + expectedIndex += 1 if not self.backtrace and lsanPatternSeen: # We've seen the crash address but no frames, so this is likely @@ -924,9 +946,8 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # frame so it doesn't show up as "No crash detected" self.backtrace.append("??") - def createShortSignature(self): + def createShortSignature(self) -> str: """ - @rtype: String @return: A string representing this crash (short signature) """ # Try to find the LSan message on stderr and use that as short signature @@ -948,7 +969,13 @@ def createShortSignature(self): class UBSanCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
@@ -972,38 +999,39 @@ def __init__(self, stdout, stderr, configuration, crashData=None): ubsanPatternSeen = False expectedIndex = 0 - for traceLine in ubsanOutput: - if not ubsanPatternSeen: - if re.search(ubsanErrorPattern, traceLine) is not None: - ubsanPatternSeen = True - continue + if ubsanOutput is not None: + for traceLine in ubsanOutput: + if not ubsanPatternSeen: + if re.search(ubsanErrorPattern, traceLine) is not None: + ubsanPatternSeen = True + continue - index, parts = ASanCrashInfo.split_frame(traceLine) - if index is None: - continue + index, parts = ASanCrashInfo.split_frame(traceLine) + if index is None: + continue - if expectedIndex != index: - raise RuntimeError( - "Fatal error parsing UBSan trace (Index mismatch, got index " - f"{index} but expected {expectedIndex})" - ) + if expectedIndex != index: + raise RuntimeError( + "Fatal error parsing UBSan trace (Index mismatch, got index " + f"{index} but expected {expectedIndex})" + ) - component = None - if len(parts) > 2: - if parts[2] == "in": - component = " ".join(parts[3:-1]) + component = None + if len(parts) > 2: + if parts[2] == "in": + component = " ".join(parts[3:-1]) + else: + # Remove parentheses around component + component = parts[2][1:-1] else: - # Remove parentheses around component - component = parts[2][1:-1] - else: - print( - f"Warning: Missing component in this line: {traceLine}", - file=sys.stderr, - ) - component = "" + print( + f"Warning: Missing component in this line: {traceLine}", + file=sys.stderr, + ) + component = "" - self.backtrace.append(CrashInfo.sanitizeStackFrame(component)) - expectedIndex += 1 + self.backtrace.append(CrashInfo.sanitizeStackFrame(component)) + expectedIndex += 1 if not self.backtrace and ubsanPatternSeen: # We've seen the crash address but no frames, so this is likely @@ -1011,9 +1039,8 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # frame so it doesn't show up as "No crash detected" self.backtrace.append("??") - def createShortSignature(self): + def createShortSignature(self) -> str: """ - @rtype: String @return: A string representing this crash (short signature) """ # Try to find the UBSan message on stderr and use that as short signature @@ -1035,7 +1062,13 @@ def createShortSignature(self): class GDBCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -1057,6 +1090,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): if crashData: gdbOutput = crashData else: + assert stderr is not None gdbOutput = stderr gdbFramePatterns = [ @@ -1159,6 +1193,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): # This is a workaround for GDB throwing an error while resolving # function arguments in the trace and aborting. We try to remove the # error message to at least recover the function name properly. 
+ assert functionName is not None gdbErrorIdx = functionName.find(" (/build/buildd/gdb") if gdbErrorIdx > 0: functionName = functionName[:gdbErrorIdx] @@ -1194,16 +1229,16 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.crashAddress = uint64(self.crashAddress) @staticmethod - def calculateCrashAddress(crashInstruction, registerMap): + def calculateCrashAddress( + crashInstruction: str, + registerMap: dict[str, int], + ) -> str | int | None: """ Calculate the crash address given the crash instruction and register contents - @type crashInstruction: string @param crashInstruction: Crash instruction string as provided by GDB - @type registerMap: Map from string to int @param registerMap: Map of register names to values - @rtype: int @return The calculated crash address On error, a string containing the failure message is returned instead. @@ -1265,12 +1300,12 @@ def calculateCrashAddress(crashInstruction, registerMap): # e.g. shrb -0x69(%rdx,%rbx,8) # When we fail, try storing a reason here - failureReason = "Unknown failure." + failureReason: str | None = "Unknown failure." - def isDerefOp(op): + def isDerefOp(op: str) -> bool: return "(" in op and ")" in op - def calculateDerefOpAddress(derefOp): + def calculateDerefOpAddress(derefOp: str) -> tuple[int | None, str | None]: match = re.match("\\*?((?:\\-?0x[0-9a-f]+)?)\\(%([a-z0-9]+)\\)", derefOp) if match is not None: offset = 0 @@ -1391,14 +1426,17 @@ def calculateDerefOpAddress(derefOp): if "(" in parts[0] and ")" in parts[2]: complexDerefOp = parts[0] + "," + parts[1] + "," + parts[2] - (result, reason) = GDBCrashInfo.calculateComplexDerefOpAddress( + ( + result_parts_3, + reason, + ) = GDBCrashInfo.calculateComplexDerefOpAddress( complexDerefOp, registerMap ) - if result is None: + if result_parts_3 is None: failureReason = reason else: - return result + return result_parts_3 else: raise RuntimeError( f"Unexpected instruction pattern: {crashInstruction}" @@ -1409,14 +1447,14 @@ def calculateDerefOpAddress(derefOp): elif "(" not in parts[0] and ")" not in parts[0]: complexDerefOp = parts[1] + "," + parts[2] + "," + parts[3] - (result, reason) = GDBCrashInfo.calculateComplexDerefOpAddress( + (result_parts_4, reason) = GDBCrashInfo.calculateComplexDerefOpAddress( complexDerefOp, registerMap ) - if result is None: + if result_parts_4 is None: failureReason = reason else: - return result + return result_parts_4 else: raise RuntimeError( "Unexpected length after splitting operands of this instruction: %s" @@ -1427,13 +1465,15 @@ def calculateDerefOpAddress(derefOp): # Anything that is not explicitly handled now is considered unsupported failureReason = "Unsupported instruction in incomplete ARM/ARM64 support." 
- def getARMImmConst(val): + def getARMImmConst(val: str) -> int: val = val.replace("#", "").strip() if val.startswith("0x"): return int(val, 16) return int(val) - def calculateARMDerefOpAddress(derefOp): + def calculateARMDerefOpAddress( + derefOp: str, + ) -> tuple[int | None, str | None]: derefOps = derefOp.split(",") if len(derefOps) > 2: @@ -1477,11 +1517,13 @@ def calculateARMDerefOpAddress(derefOp): # Load/Store instruction match = re.match("^\\s*\\[(.*)\\]$", parts[1]) if match is not None: - (result, reason) = calculateARMDerefOpAddress(match.group(1)) - if result is None: + (result_parts_2, reason) = calculateARMDerefOpAddress( + match.group(1) + ) + if result_parts_2 is None: failureReason += f" ({reason})" else: - return result + return result_parts_2 else: failureReason = "Architecture is not supported." @@ -1493,7 +1535,10 @@ def calculateARMDerefOpAddress(derefOp): return failureReason @staticmethod - def calculateComplexDerefOpAddress(complexDerefOp, registerMap): + def calculateComplexDerefOpAddress( + complexDerefOp: str, + registerMap: dict[str, int], + ) -> tuple[int | None, str | None]: match = re.match( "((?:\\-?0x[0-9a-f]+)?)\\(%([a-z0-9]+),%([a-z0-9]+),([0-9]+)\\)", @@ -1527,7 +1572,13 @@ def calculateComplexDerefOpAddress(complexDerefOp, registerMap): class MinidumpCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -1550,6 +1601,7 @@ def __init__(self, stdout, stderr, configuration, crashData=None): if crashData: minidumpOuput = crashData else: + assert stderr is not None minidumpOuput = stderr crashThread = None @@ -1580,7 +1632,13 @@ def __init__(self, stdout, stderr, configuration, crashData=None): class AppleCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
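Another recurring pattern in the hunks above (the GDB and minidump handling): when an Optional parameter is known to be present on a given path, the patch narrows it with `assert ... is not None` before use instead of widening the annotation. A small sketch of the idiom, with hypothetical names not taken from the patch:

    from __future__ import annotations


    def pick_output(crash_data: list[str] | None, stderr: list[str] | None) -> list[str]:
        # Mirrors the "use crashData if given, otherwise fall back to stderr" shape.
        if crash_data:
            return crash_data
        # Without this assert, mypy reports the return value as list[str] | None.
        assert stderr is not None
        return stderr


    print(pick_output(None, ["ERROR: example fallback output"]))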
@@ -1601,47 +1659,48 @@ def __init__(self, stdout, stderr, configuration, crashData=None): apple_crash_data = crashData or stderr inCrashingThread = False - for line in apple_crash_data: - # Crash address - if line.startswith("Exception Codes:"): - # Example: - # Exception Type: EXC_BAD_ACCESS (SIGABRT) - # Exception Codes: KERN_INVALID_ADDRESS at 0x00000001374b893e - address = line.split(" ")[-1] - if address.startswith("0x"): - self.crashAddress = int(address, 16) - - # Start of stack for crashing thread - if re.match(r"Thread \d+ Crashed:", line): - inCrashingThread = True - continue - - if line.strip() == "": - inCrashingThread = False - continue - - if inCrashingThread: - # Example: - # noqa "0 js-dbg-64-dm-darwin-a523d4c7efe2 0x00000001004b04c4 js::jit::MacroAssembler::Pop(js::jit::Register) + 180 (MacroAssembler-inl.h:50)" - components = line.split(None, 3) - stackEntry = components[3] - if stackEntry.startswith("0"): - self.backtrace.append("??") - else: - stackEntry = AppleCrashInfo.removeFilename(stackEntry) - stackEntry = AppleCrashInfo.removeOffset(stackEntry) - stackEntry = CrashInfo.sanitizeStackFrame(stackEntry) - self.backtrace.append(stackEntry) + if apple_crash_data is not None: + for line in apple_crash_data: + # Crash address + if line.startswith("Exception Codes:"): + # Example: + # Exception Type: EXC_BAD_ACCESS (SIGABRT) + # Exception Codes: KERN_INVALID_ADDRESS at 0x00000001374b893e + address = line.split(" ")[-1] + if address.startswith("0x"): + self.crashAddress = int(address, 16) + + # Start of stack for crashing thread + if re.match(r"Thread \d+ Crashed:", line): + inCrashingThread = True + continue + + if line.strip() == "": + inCrashingThread = False + continue + + if inCrashingThread: + # Example: + # noqa "0 js-dbg-64-dm-darwin-a523d4c7efe2 0x00000001004b04c4 js::jit::MacroAssembler::Pop(js::jit::Register) + 180 (MacroAssembler-inl.h:50)" + components = line.split(None, 3) + stackEntry = components[3] + if stackEntry.startswith("0"): + self.backtrace.append("??") + else: + stackEntry = AppleCrashInfo.removeFilename(stackEntry) + stackEntry = AppleCrashInfo.removeOffset(stackEntry) + stackEntry = CrashInfo.sanitizeStackFrame(stackEntry) + self.backtrace.append(stackEntry) @staticmethod - def removeFilename(stackEntry): + def removeFilename(stackEntry: str) -> str: match = re.match(r"(.*) \(\S+\)", stackEntry) if match: return match.group(1) return stackEntry @staticmethod - def removeOffset(stackEntry): + def removeOffset(stackEntry: str) -> str: match = re.match(r"(.*) \+ \d+", stackEntry) if match: return match.group(1) @@ -1649,7 +1708,13 @@ def removeOffset(stackEntry): class CDBCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
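The same narrowing applies to regular-expression results throughout these parsers: `re.match` and `re.search` return `re.Match[str] | None`, so the result is compared against None before `.group()` is used. A self-contained sketch, loosely modelled on the `Thread \d+ Crashed:` check above; the helper name is invented:

    from __future__ import annotations

    import re


    def crashed_thread(line: str) -> int | None:
        match = re.match(r"Thread (\d+) Crashed:", line)
        if match is not None:
            # mypy narrows `match` to re.Match[str] here, so .group() is safe.
            return int(match.group(1))
        return None


    print(crashed_thread("Thread 0 Crashed:"))  # -> 0
    print(crashed_thread("unrelated output"))   # -> None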
@@ -1678,114 +1743,115 @@ def __init__(self, stdout, stderr, configuration, crashData=None): cdb_crash_data = crashData or stderr - for line in cdb_crash_data: - # Start of .ecxr data - if re.match(r"0:000> \.ecxr", line): - inEcxrData = True - continue - - if inEcxrData: - # 32-bit example: - # 0:000> .ecxr - # noqa eax=02efff01 ebx=016fddb8 ecx=2b2b2b2b edx=016fe490 esi=02e00310 edi=02e00310 - # noqa eip=00404c59 esp=016fdc2c ebp=016fddb8 iopl=0 nv up ei pl nz na po nc - # noqa cs=0023 ss=002b ds=002b es=002b fs=0053 gs=002b efl=00010202 - # 64-bit example: - # 0:000> .ecxr - # rax=00007ff74d8fee30 rbx=00000285ef400420 rcx=2b2b2b2b2b2b2b2b - # rdx=00000285ef21b940 rsi=000000e87fbfc340 rdi=00000285ef400420 - # rip=00007ff74d469ff3 rsp=000000e87fbfc040 rbp=fffe000000000000 - # r8=000000e87fbfc140 r9=000000000001fffc r10=0000000000000649 - # r11=00000285ef25a000 r12=00007ff74d9239a0 r13=fffa7fffffffffff - # r14=000000e87fbfd0e0 r15=0000000000000003 - # iopl=0 nv up ei pl nz na pe nc - # noqa cs=0033 ss=002b ds=002b es=002b fs=0053 gs=002b efl=00010200 - if line.startswith("cs="): - inEcxrData = False - continue - - # Crash address - # Extract the eip/rip specifically for use later - if "eip=" in line: - address = line.split("eip=")[1].split()[0] - self.crashAddress = int(address, 16) - elif "rip=" in line: - address = line.split("rip=")[1].split()[0] - self.crashAddress = int(address, 16) - - # First extract the line - # 32-bit example: - # eax=02efff01 ebx=016fddb8 ecx=2b2b2b2b esi=02e00310 edi=02e00310 - # 64-bit example: - # rax=00007ff74d8fee30 rbx=00000285ef400420 rcx=2b2b2b2b2b2b2b2b - matchLine = re.search(RegisterHelper.getRegisterPattern(), line) - if matchLine is not None: - ecxrData.extend(line.split()) - - # Next, put the eax/rax, ebx/rbx, etc. entries into a list of their own, - # then iterate - match = re.search(cdbRegisterPattern, line) - for instr in ecxrData: - match = re.search(cdbRegisterPattern, instr) - if match is not None: - register = match.group(1) - value = int(match.group(2), 16) - self.registers[register] = value - - # Crash instruction - # Start of crash instruction details - if line.startswith("FAULTING_IP"): - inCrashInstruction = True - continue - - if inCrashInstruction and not cInstruction: - if "PROCESS_NAME" in line: - inCrashInstruction = False + if cdb_crash_data is not None: + for line in cdb_crash_data: + # Start of .ecxr data + if re.match(r"0:000> \.ecxr", line): + inEcxrData = True continue - # 64-bit binaries have a backtick in their addresses, - # e.g. 
00007ff7`1e424e62 - lineWithoutBacktick = line.replace("`", "", 1) - if address and lineWithoutBacktick.startswith(address): - # 32-bit examples: - # 25d80b01 cc int 3 - # 00404c59 8b39 mov edi,dword ptr [ecx] + if inEcxrData: + # 32-bit example: + # 0:000> .ecxr + # noqa eax=02efff01 ebx=016fddb8 ecx=2b2b2b2b edx=016fe490 esi=02e00310 edi=02e00310 + # noqa eip=00404c59 esp=016fdc2c ebp=016fddb8 iopl=0 nv up ei pl nz na po nc + # noqa cs=0023 ss=002b ds=002b es=002b fs=0053 gs=002b efl=00010202 # 64-bit example: - # 00007ff7`4d469ff3 4c8b01 mov r8,qword ptr [rcx] - cInstruction = line.split(None, 2)[-1] - # There may be multiple spaces inside the faulting instruction - cInstruction = " ".join(cInstruction.split()) - self.crashInstruction = cInstruction - - # Start of stack for crashing thread - if line.startswith("STACK_TEXT:"): - inCrashingThread = True - continue - - if inCrashingThread: - # 32-bit example: - # noqa "016fdc38 004b2387 01e104e8 016fe490 016fe490 js_32_dm_windows_62f79d676e0e!JSObject::allocKindForTenure+0x9" - # 64-bit example: - # noqa "000000e8`7fbfc040 00007ff7`4d53a984 : 000000e8`7fbfc2c0 00000285`ef7bb400 00000285`ef21b000 00007ff7`4d4254b9 : js_64_dm_windows_62f79d676e0e!JSObject::allocKindForTenure+0x13" - if ( - "STACK_COMMAND" in line - or "SYMBOL_NAME" in line - or "THREAD_SHA1_HASH_MOD_FUNC" in line - or "FAULTING_SOURCE_CODE" in line - ): - inCrashingThread = False + # 0:000> .ecxr + # noqa rax=00007ff74d8fee30 rbx=00000285ef400420 rcx=2b2b2b2b2b2b2b2b + # noqa rdx=00000285ef21b940 rsi=000000e87fbfc340 rdi=00000285ef400420 + # noqa rip=00007ff74d469ff3 rsp=000000e87fbfc040 rbp=fffe000000000000 + # noqa r8=000000e87fbfc140 r9=000000000001fffc r10=0000000000000649 + # noqa r11=00000285ef25a000 r12=00007ff74d9239a0 r13=fffa7fffffffffff + # r14=000000e87fbfd0e0 r15=0000000000000003 + # iopl=0 nv up ei pl nz na pe nc + # noqa cs=0033 ss=002b ds=002b es=002b fs=0053 gs=002b efl=00010200 + if line.startswith("cs="): + inEcxrData = False + continue + + # Crash address + # Extract the eip/rip specifically for use later + if "eip=" in line: + address = line.split("eip=")[1].split()[0] + self.crashAddress = int(address, 16) + elif "rip=" in line: + address = line.split("rip=")[1].split()[0] + self.crashAddress = int(address, 16) + + # First extract the line + # 32-bit example: + # noqa eax=02efff01 ebx=016fddb8 ecx=2b2b2b2b edx=016fe490 esi=02e00310 edi=02e00310 + # 64-bit example: + # noqa rax=00007ff74d8fee30 rbx=00000285ef400420 rcx=2b2b2b2b2b2b2b2b + matchLine = re.search(RegisterHelper.getRegisterPattern(), line) + if matchLine is not None: + ecxrData.extend(line.split()) + + # Next, put the eax/rax, ebx/rbx, etc. entries into a list of their + # own, then iterate + match = re.search(cdbRegisterPattern, line) + for instr in ecxrData: + match = re.search(cdbRegisterPattern, instr) + if match is not None: + register = match.group(1) + value = int(match.group(2), 16) + self.registers[register] = value + + # Crash instruction + # Start of crash instruction details + if line.startswith("FAULTING_IP"): + inCrashInstruction = True continue - # Ignore cdb error and empty lines - if "Following frames may be wrong." in line or line.strip() == "": + if inCrashInstruction and not cInstruction: + if "PROCESS_NAME" in line: + inCrashInstruction = False + continue + + # 64-bit binaries have a backtick in their addresses, + # e.g. 
00007ff7`1e424e62 + lineWithoutBacktick = line.replace("`", "", 1) + if address and lineWithoutBacktick.startswith(address): + # 32-bit examples: + # 25d80b01 cc int 3 + # 00404c59 8b39 mov edi,dword ptr [ecx] + # 64-bit example: + # noqa 00007ff7`4d469ff3 4c8b01 mov r8,qword ptr [rcx] + cInstruction = line.split(None, 2)[-1] + # There may be multiple spaces inside the faulting instruction + cInstruction = " ".join(cInstruction.split()) + self.crashInstruction = cInstruction + + # Start of stack for crashing thread + if line.startswith("STACK_TEXT:"): + inCrashingThread = True continue - stackEntry = CDBCrashInfo.removeFilenameAndOffset(line) - stackEntry = CrashInfo.sanitizeStackFrame(stackEntry) - self.backtrace.append(stackEntry) + if inCrashingThread: + # 32-bit example: + # noqa "016fdc38 004b2387 01e104e8 016fe490 016fe490 js_32_dm_windows_62f79d676e0e!JSObject::allocKindForTenure+0x9" + # 64-bit example: + # noqa "000000e8`7fbfc040 00007ff7`4d53a984 : 000000e8`7fbfc2c0 00000285`ef7bb400 00000285`ef21b000 00007ff7`4d4254b9 : js_64_dm_windows_62f79d676e0e!JSObject::allocKindForTenure+0x13" + if ( + "STACK_COMMAND" in line + or "SYMBOL_NAME" in line + or "THREAD_SHA1_HASH_MOD_FUNC" in line + or "FAULTING_SOURCE_CODE" in line + ): + inCrashingThread = False + continue + + # Ignore cdb error and empty lines + if "Following frames may be wrong." in line or line.strip() == "": + continue + + stackEntry = CDBCrashInfo.removeFilenameAndOffset(line) + stackEntry = CrashInfo.sanitizeStackFrame(stackEntry) + self.backtrace.append(stackEntry) @staticmethod - def removeFilenameAndOffset(stackEntry): + def removeFilenameAndOffset(stackEntry: str) -> str: # Extract only the function name if "0x" in stackEntry: result = ( @@ -1809,7 +1875,13 @@ class RustCrashInfo(CrashInfo): r"(::h[0-9a-f]{16})?|\s+at ([A-Za-z]:)?(/[A-Za-z0-9_ .]+)+:\d+)$" ) - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -1837,24 +1909,31 @@ def __init__(self, stdout, stderr, configuration, crashData=None): inAssertion = False inBacktrace = False - for line in rustOutput: - # Start of stack - if not inAssertion: - if AssertionHelper.RE_RUST_ASSERT.match(line) is not None: - inAssertion = True - continue - - frame = self.RE_FRAME.match(line) - if frame is None and inBacktrace: - break - elif frame is not None: - inBacktrace = True - if frame.group("symbol"): - self.backtrace.append(frame.group("symbol")) + if rustOutput is not None: + for line in rustOutput: + # Start of stack + if not inAssertion: + if AssertionHelper.RE_RUST_ASSERT.match(line) is not None: + inAssertion = True + continue + + frame = self.RE_FRAME.match(line) + if frame is None and inBacktrace: + break + elif frame is not None: + inBacktrace = True + if frame.group("symbol"): + self.backtrace.append(frame.group("symbol")) class TSanCrashInfo(CrashInfo): - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. 
@@ -1887,60 +1966,61 @@ def __init__(self, stdout, stderr, configuration, crashData=None): isDataRace = False backtraces = [] currentBacktrace = None - for traceLine in tsanOutput: - if not reportFound: - match = re.search(tsanWarningPattern, traceLine) - if match is not None: - self.tsanWarnLine = traceLine.strip() - reportFound = True - isDataRace = "data race" in self.tsanWarnLine - continue - - if "[failed to restore the stack]" in traceLine: - # TSan failed to symbolize at least one stack - brokenStack = True - - index, parts = ASanCrashInfo.split_frame(traceLine) - if index is None: - continue - - # We may see multiple traces in TSAN - if index == 0: - if currentBacktrace is not None: - backtraces.append(currentBacktrace) - currentBacktrace = [] - expectedIndex = 0 - - if not expectedIndex == index: - raise RuntimeError( - f"Fatal error parsing TSan trace (Index mismatch, got index {index}" - f" but expected {expectedIndex})" - ) + if tsanOutput is not None: + for traceLine in tsanOutput: + if not reportFound: + match = re.search(tsanWarningPattern, traceLine) + if match is not None: + self.tsanWarnLine = traceLine.strip() + reportFound = True + isDataRace = "data race" in self.tsanWarnLine + continue - component = None - if len(parts) > 2: - # TSan has a different trace style than other sanitizers: - # TSan uses: - # #0 function name filename:line:col (bin+0xaddr) - # ASan uses: - # #0 0xaddr in function name filename:line - component = " ".join(parts[1:-2]) - - if component == "" and len(parts) > 3: - # TSan uses to indicate missing symbols, e.g. - # #1 (libXext.so.6+0xcc89) - # Remove parentheses around component - component = parts[3][1:-1] - else: - print( - f"Warning: Missing component in this line: {traceLine}", - file=sys.stderr, - ) - component = "" + if "[failed to restore the stack]" in traceLine: + # TSan failed to symbolize at least one stack + brokenStack = True - currentBacktrace.append(CrashInfo.sanitizeStackFrame(component)) + index, parts = ASanCrashInfo.split_frame(traceLine) + if index is None: + continue + + # We may see multiple traces in TSAN + if index == 0: + if currentBacktrace is not None: + backtraces.append(currentBacktrace) + currentBacktrace = [] + expectedIndex = 0 + + if not expectedIndex == index: + raise RuntimeError( + "Fatal error parsing TSan trace (Index mismatch, got index " + f"{index} but expected {expectedIndex})" + ) - expectedIndex += 1 + component = None + if len(parts) > 2: + # TSan has a different trace style than other sanitizers: + # TSan uses: + # #0 function name filename:line:col (bin+0xaddr) + # ASan uses: + # #0 0xaddr in function name filename:line + component = " ".join(parts[1:-2]) + + if component == "" and len(parts) > 3: + # TSan uses to indicate missing symbols, e.g. 
+ # #1 (libXext.so.6+0xcc89) + # Remove parentheses around component + component = parts[3][1:-1] + else: + print( + f"Warning: Missing component in this line: {traceLine}", + file=sys.stderr, + ) + component = "" + + currentBacktrace.append(CrashInfo.sanitizeStackFrame(component)) + + expectedIndex += 1 if currentBacktrace is not None: backtraces.append(currentBacktrace) @@ -1957,9 +2037,8 @@ def __init__(self, stdout, stderr, configuration, crashData=None): self.tsanIndexZero.append(backtrace[0]) self.backtrace.extend(backtrace) - def createShortSignature(self): + def createShortSignature(self) -> str: """ - @rtype: String @return: A string representing this crash (short signature) """ if self.tsanWarnLine: @@ -2011,7 +2090,13 @@ class ValgrindCrashInfo(CrashInfo): re.VERBOSE, ) - def __init__(self, stdout, stderr, configuration, crashData=None): + def __init__( + self, + stdout: list[str] | None, + stderr: list[str] | None, + configuration: ProgramConfiguration, + crashData: list[str] | None = None, + ): """ Private constructor, called by L{CrashInfo.fromRawCrashData}. Do not use directly. @@ -2043,40 +2128,40 @@ def __init__(self, stdout, stderr, configuration, crashData=None): ) foundStart = False - for traceLine in vgdOutput: - if not traceLine.startswith("=="): - # skip unrelated noise - continue - elif not foundStart: - if re.match(self.MSG_REGEX, traceLine) is not None: - # skip other lines that are not part of a recognized trace - foundStart = True - # continue search for the beginning of the stack trace - continue - - lineInfo = re.match(stackPattern, traceLine) - if lineInfo is not None: - lineFunc = lineInfo.group("func") - # if function name is not available use the file name instead - if lineFunc == "???": - lineFunc = lineInfo.group("file") - self.backtrace.append(CrashInfo.sanitizeStackFrame(lineFunc)) - - elif self.backtrace: - # check if address info is available - addr = re.match( - r"^==\d+==\s+Address\s(?P0x[0-9A-Fa-f]+)\s", traceLine - ) - if addr: - self.crashAddress = int(addr.group("addr"), 16) - # look for '==PID== \n' to indicate the end of a trace - if re.match(r"^==\d+==\s+$", traceLine) is not None: - # done parsing - break + if vgdOutput is not None: + for traceLine in vgdOutput: + if not traceLine.startswith("=="): + # skip unrelated noise + continue + elif not foundStart: + if re.match(self.MSG_REGEX, traceLine) is not None: + # skip other lines that are not part of a recognized trace + foundStart = True + # continue search for the beginning of the stack trace + continue + + lineInfo = re.match(stackPattern, traceLine) + if lineInfo is not None: + lineFunc = lineInfo.group("func") + # if function name is not available use the file name instead + if lineFunc == "???": + lineFunc = lineInfo.group("file") + self.backtrace.append(CrashInfo.sanitizeStackFrame(lineFunc)) + + elif self.backtrace: + # check if address info is available + addr = re.match( + r"^==\d+==\s+Address\s(?P0x[0-9A-Fa-f]+)\s", traceLine + ) + if addr: + self.crashAddress = int(addr.group("addr"), 16) + # look for '==PID== \n' to indicate the end of a trace + if re.match(r"^==\d+==\s+$", traceLine) is not None: + # done parsing + break - def createShortSignature(self): + def createShortSignature(self) -> str: """ - @rtype: String @return: A string representing this crash (short signature) """ diff --git a/FTB/Signatures/CrashSignature.py b/FTB/Signatures/CrashSignature.py index 5328d6316..ede833b60 100644 --- a/FTB/Signatures/CrashSignature.py +++ b/FTB/Signatures/CrashSignature.py @@ -15,10 
+15,17 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + import difflib import json +from pathlib import Path +from typing import Sequence + +from typing_extensions import NotRequired, TypedDict from FTB.Signatures import JSONHelper +from FTB.Signatures.CrashInfo import CrashInfo from FTB.Signatures.Symptom import ( OutputSymptom, StackFramesSymptom, @@ -27,12 +34,19 @@ ) +class SymptomsDiffType(TypedDict): + """Type information for SymptomsDiff""" + + offending: bool + proposed: NotRequired[Symptom] + symptom: Symptom + + class CrashSignature: - def __init__(self, rawSignature): + def __init__(self, rawSignature: str) -> None: """ Constructor - @type rawSignature: string @param rawSignature: A JSON-formatted string representing the crash signature """ @@ -66,23 +80,22 @@ def __init__(self, rawSignature): self.products = JSONHelper.getArrayChecked(obj, "products") @staticmethod - def fromFile(signatureFile): + def fromFile(signatureFile: Path | str) -> CrashSignature: with open(signatureFile) as sigFd: return CrashSignature(sigFd.read()) - def __str__(self): + def __str__(self) -> str: return str(self.rawSignature) - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Match this signature against the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash info to match the signature against - @rtype: bool @return: True if the signature matches, False otherwise """ + assert crashInfo.configuration is not None if ( self.platforms is not None and crashInfo.configuration.platform not in self.platforms @@ -121,14 +134,13 @@ def matches(self, crashInfo): return True - def matchRequiresTest(self): + def matchRequiresTest(self) -> bool: """ Check if the signature requires a testcase to match. This method can be used to avoid attaching a testcase to the crashInfo before matching, avoiding unnecessary I/O on testcase files. - @rtype: bool @return: True if the signature requires a testcase to match """ for symptom in self.symptoms: @@ -137,14 +149,13 @@ def matchRequiresTest(self): return False - def getRequiredOutputSources(self): + def getRequiredOutputSources(self) -> list[str]: """ Return a list of output sources required by this signature for matching. This method can be used to avoid loading raw output fields from the database if they are not required for the particular signature. - @rtype: list(str) @return: A list of output identifiers (e.g. stdout, stderr or crashdata) required by this signature. 
""" @@ -160,7 +171,7 @@ def getRequiredOutputSources(self): return ret - def getDistance(self, crashInfo): + def getDistance(self, crashInfo: CrashInfo) -> int: distance = 0 for symptom in self.symptoms: @@ -175,6 +186,7 @@ def getDistance(self, crashInfo): if not symptom.matches(crashInfo): distance += 1 + assert crashInfo.configuration is not None if ( self.platforms is not None and crashInfo.configuration.platform not in self.platforms @@ -195,9 +207,9 @@ def getDistance(self, crashInfo): return distance - def fit(self, crashInfo): + def fit(self, crashInfo: CrashInfo) -> CrashSignature | None: sigObj = {} - sigSymptoms = [] + sigSymptoms: list[dict[str, object]] = [] sigObj["symptoms"] = sigSymptoms @@ -224,8 +236,8 @@ def fit(self, crashInfo): return CrashSignature(json.dumps(sigObj, indent=2, sort_keys=True)) - def getSymptomsDiff(self, crashInfo): - symptomsDiff = [] + def getSymptomsDiff(self, crashInfo: CrashInfo) -> list[SymptomsDiffType]: + symptomsDiff: list[SymptomsDiffType] = [] for symptom in self.symptoms: if symptom.matches(crashInfo): symptomsDiff.append({"offending": False, "symptom": symptom}) @@ -249,7 +261,9 @@ def getSymptomsDiff(self, crashInfo): symptomsDiff.append({"offending": True, "symptom": symptom}) return symptomsDiff - def getSignatureUnifiedDiffTuples(self, crashInfo): + def getSignatureUnifiedDiffTuples( + self, crashInfo: CrashInfo + ) -> Sequence[tuple[str, list[str] | str]]: diffTuples = [] # go through dumps(loads()) to standardize the format. diff --git a/FTB/Signatures/JSONHelper.py b/FTB/Signatures/JSONHelper.py index f2224a856..79a08bdc7 100644 --- a/FTB/Signatures/JSONHelper.py +++ b/FTB/Signatures/JSONHelper.py @@ -13,20 +13,20 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import numbers -def getArrayChecked(obj, key, mandatory=False): +def getArrayChecked(obj, key: str, mandatory: bool = False): """ Retrieve a list from the given object using the given key @type obj: map @param obj: Source object - @type key: string @param key: Key to retrieve from obj - - @type mandatory: bool @param mandatory: If True, throws an exception if the key is not found @rtype: list @@ -35,17 +35,14 @@ def getArrayChecked(obj, key, mandatory=False): return __getTypeChecked(obj, key, [list], mandatory) -def getStringChecked(obj, key, mandatory=False): +def getStringChecked(obj, key: str, mandatory: bool = False): """ Retrieve a string from the given object using the given key @type obj: map @param obj: Source object - @type key: string @param key: Key to retrieve from obj - - @type mandatory: bool @param mandatory: If True, throws an exception if the key is not found @rtype: string @@ -54,17 +51,14 @@ def getStringChecked(obj, key, mandatory=False): return __getTypeChecked(obj, key, [str, bytes], mandatory) -def getNumberChecked(obj, key, mandatory=False): +def getNumberChecked(obj, key: str, mandatory: bool = False): """ Retrieve an integer from the given object using the given key @type obj: map @param obj: Source object - @type key: string @param key: Key to retrieve from obj - - @type mandatory: bool @param mandatory: If True, throws an exception if the key is not found @rtype: int @@ -73,17 +67,14 @@ def getNumberChecked(obj, key, mandatory=False): return __getTypeChecked(obj, key, [numbers.Integral], mandatory) -def getObjectOrStringChecked(obj, key, mandatory=False): +def getObjectOrStringChecked(obj, key: str, mandatory: bool = False): """ Retrieve an object or string from the given object using the given key @type obj: map 
@param obj: Source object - @type key: string @param key: Key to retrieve from obj - - @type mandatory: bool @param mandatory: If True, throws an exception if the key is not found @rtype: string or dict @@ -92,17 +83,14 @@ def getObjectOrStringChecked(obj, key, mandatory=False): return __getTypeChecked(obj, key, [str, bytes, dict], mandatory) -def getNumberOrStringChecked(obj, key, mandatory=False): +def getNumberOrStringChecked(obj, key: str, mandatory: bool = False): """ Retrieve a number or string from the given object using the given key @type obj: map @param obj: Source object - @type key: string @param key: Key to retrieve from obj - - @type mandatory: bool @param mandatory: If True, throws an exception if the key is not found @rtype: string or number @@ -111,7 +99,7 @@ def getNumberOrStringChecked(obj, key, mandatory=False): return __getTypeChecked(obj, key, [str, bytes, numbers.Integral], mandatory) -def __getTypeChecked(obj, key, valTypes, mandatory=False): +def __getTypeChecked(obj, key: str, valTypes, mandatory: bool = False): if key not in obj: if mandatory: raise RuntimeError(f'Expected key "{key}" in object') diff --git a/FTB/Signatures/Matchers.py b/FTB/Signatures/Matchers.py index 450023a81..2987b60c8 100644 --- a/FTB/Signatures/Matchers.py +++ b/FTB/Signatures/Matchers.py @@ -13,6 +13,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import numbers import re from abc import ABCMeta, abstractmethod @@ -22,14 +25,14 @@ class Match(metaclass=ABCMeta): @abstractmethod - def matches(self, value): + def matches(self, value: bytes | int | str | None) -> re.Match[str] | bool | None: pass class StringMatch(Match): - def __init__(self, obj): + def __init__(self, obj: bytes | int | numbers.Integral | str) -> None: self.isPCRE = False - self.compiledValue = None + self.compiledValue: re.Pattern[str] | None = None self.patternContainsSlash = False if isinstance(obj, bytes): @@ -63,28 +66,34 @@ def __init__(self, obj): else: raise RuntimeError(f"Unknown match operator specified: {matchType}") - def matches(self, value, windowsSlashWorkaround=False): + def matches( + self, value: bytes | int | str | None, windowsSlashWorkaround: bool = False + ) -> re.Match[str] | bool | None: if isinstance(value, bytes): # If the input is not already unicode, try to interpret it as UTF-8 # If there are errors, replace them with U+FFFD so we neither raise nor # false positive. 
- value = value.decode("utf-8", errors="replace") + value_decoded = value.decode("utf-8", errors="replace") if self.isPCRE: - if self.compiledValue.search(value) is not None: + assert self.compiledValue is not None + if self.compiledValue.search(value_decoded) is not None: return True elif windowsSlashWorkaround and self.patternContainsSlash: # NB this will fail if the pattern is supposed to match a backslash and # a windows-style path in the same line - return self.compiledValue.search(value.replace("\\", "/")) is not None + return ( + self.compiledValue.search(value_decoded.replace("\\", "/")) + is not None + ) return False else: - return self.value in value + return self.value in value_decoded - def __str__(self): + def __str__(self) -> str: return self.value - def __repr__(self): + def __repr__(self) -> str: if self.isPCRE: return f"/{self.value}/" @@ -96,8 +105,9 @@ class NumberMatchType: class NumberMatch(Match): - def __init__(self, obj): - self.matchType = None + def __init__(self, obj: bytes | int | numbers.Integral | str) -> None: + self.matchType: int | None = None + self.value: int | numbers.Integral | None if isinstance(obj, bytes): obj = obj.decode("utf-8") @@ -144,10 +154,12 @@ def __init__(self, obj): else: raise RuntimeError(f"Invalid type {type(obj)} in NumberMatch.") - def matches(self, value): + def matches(self, value: bytes | int | str | None) -> bool: if value is None: return self.value is None + assert isinstance(value, int) + assert isinstance(self.value, int) if self.matchType == NumberMatchType.GE: return value >= self.value elif self.matchType == NumberMatchType.GT: diff --git a/FTB/Signatures/RegisterHelper.py b/FTB/Signatures/RegisterHelper.py index 23e745002..06a264bc8 100644 --- a/FTB/Signatures/RegisterHelper.py +++ b/FTB/Signatures/RegisterHelper.py @@ -12,6 +12,8 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + x86Registers = ["eax", "ebx", "ecx", "edx", "esi", "edi", "ebp", "esp", "eip"] x64Registers = [ @@ -77,7 +79,7 @@ } -def getRegisterPattern(): +def getRegisterPattern() -> str: """ Return a pattern including all register names that are considered valid """ @@ -89,14 +91,12 @@ def getRegisterPattern(): ) -def getStackPointer(registerMap): +def getStackPointer(registerMap: dict[str, int]) -> int: """ Return the stack pointer value from the given register map - @type registerMap: map @param registerMap: Map of register names to value - @rtype: int @return: The value of the stack pointer """ @@ -107,14 +107,12 @@ def getStackPointer(registerMap): raise RuntimeError("Register map does not contain a usable stack pointer!") -def getInstructionPointer(registerMap): +def getInstructionPointer(registerMap: dict[str, int]) -> int: """ Return the instruction pointer value from the given register map - @type registerMap: map @param registerMap: Map of register names to value - @rtype: int @return: The value of the instruction pointer """ @@ -125,18 +123,15 @@ def getInstructionPointer(registerMap): raise RuntimeError("Register map does not contain a usable instruction pointer!") -def getRegisterValue(register, registerMap): +def getRegisterValue(register: str, registerMap: dict[str, int]) -> int | None: """ Return the value of the specified register using the provided register map. This method also works for getting lower register parts out of higher ones. 
- @type register: string @param register: The register to get the value for - @type registerMap: map @param registerMap: Map of register names to values - @rtype: int @return: The register value """ @@ -196,14 +191,12 @@ def getRegisterValue(register, registerMap): return None -def getBitWidth(registerMap): +def getBitWidth(registerMap: dict[str, int]) -> int: """ Return the bit width (32 or 64 bit) given the registers - @type registerMap: map @param registerMap: Map of register names to value - @rtype: int @return: The bit width """ if "rax" in registerMap or "x0" in registerMap: @@ -212,15 +205,13 @@ def getBitWidth(registerMap): return 32 -def isX86Compatible(registerMap): +def isX86Compatible(registerMap: dict[str, int]) -> bool: """ Return true, if the the given registers are X86 compatible, such as x86 or x86-64. ARM, PPC and your PDP-15 will fail this check and we don't support it right now. - @type registerMap: map @param registerMap: Map of register names to value - @rtype: bool @return: True if the architecture is X86 compatible, False otherwise """ for register in x86OnlyRegisters: @@ -229,14 +220,12 @@ def isX86Compatible(registerMap): return False -def isARMCompatible(registerMap): +def isARMCompatible(registerMap: dict[str, int]) -> bool: """ Return true, if the the given registers are either ARM or ARM64. - @type registerMap: map @param registerMap: Map of register names to value - @rtype: bool @return: True if the architecture is ARM/ARM64 compatible, False otherwise """ for register in armOnlyRegisters: diff --git a/FTB/Signatures/Symptom.py b/FTB/Signatures/Symptom.py index c886e93fa..f41832795 100644 --- a/FTB/Signatures/Symptom.py +++ b/FTB/Signatures/Symptom.py @@ -13,10 +13,15 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import json from abc import ABCMeta, abstractmethod +from typing import Sequence from FTB.Signatures import JSONHelper +from FTB.Signatures.CrashInfo import CrashInfo from FTB.Signatures.Matchers import NumberMatch, StringMatch @@ -26,23 +31,22 @@ class Symptom(metaclass=ABCMeta): It also supports generating a CrashSignature based on the stored information. """ - def __init__(self, jsonObj): + def __init__(self, jsonObj: dict[str, object]) -> None: # Store the original source so we can return it if someone wants to stringify us self.jsonsrc = json.dumps(jsonObj, indent=2) self.jsonobj = jsonObj - def __str__(self): + def __str__(self) -> str: return self.jsonsrc @staticmethod - def fromJSONObject(obj): + def fromJSONObject(obj: dict[str, object]) -> Symptom: """ Create the appropriate Symptom based on the given object (decoded from JSON) @type obj: map @param obj: Object as decoded from JSON - @rtype: Symptom @return: Symptom subclass instance matching the given object """ if "type" not in obj: @@ -68,21 +72,19 @@ def fromJSONObject(obj): raise RuntimeError(f"Unknown symptom type: {stype}") @abstractmethod - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ return class OutputSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, object]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. 
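The FTB/Signatures/RegisterHelper.py hunks above leave behaviour untouched and only annotate the register map as `dict[str, int]`. A hypothetical usage sketch (register values invented for illustration; assumes the FuzzManager package is importable):

    from __future__ import annotations

    from FTB.Signatures import RegisterHelper

    # Invented 32-bit register values, for illustration only.
    registers: dict[str, int] = {"eax": 0x02EFFF01, "esp": 0x016FDC2C, "eip": 0x00404C59}

    print(RegisterHelper.getBitWidth(registers))                 # 32 (no rax/x0 key)
    print(hex(RegisterHelper.getStackPointer(registers)))        # 0x16fdc2c
    print(hex(RegisterHelper.getInstructionPointer(registers)))  # 0x404c59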
""" @@ -101,14 +103,12 @@ def __init__(self, obj): ): raise RuntimeError(f"Invalid source specified: {self.src}") - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ checkedOutput = [] @@ -124,6 +124,7 @@ def matches(self, crashInfo): else: checkedOutput = crashInfo.rawCrashData + assert crashInfo.configuration is not None windowsSlashWorkaround = crashInfo.configuration.os == "windows" for line in reversed(checkedOutput): if self.output.matches(line, windowsSlashWorkaround=windowsSlashWorkaround): @@ -133,7 +134,7 @@ def matches(self, crashInfo): class StackFrameSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, object]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ @@ -149,14 +150,12 @@ def __init__(self, obj): # Default to 0 self.frameNumber = NumberMatch(0) - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ @@ -170,7 +169,7 @@ def matches(self, crashInfo): class StackSizeSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, object]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ @@ -179,21 +178,19 @@ def __init__(self, obj): JSONHelper.getNumberOrStringChecked(obj, "size", True) ) - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ return self.stackSize.matches(len(crashInfo.backtrace)) class CrashAddressSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, object]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ @@ -202,14 +199,12 @@ def __init__(self, obj): JSONHelper.getNumberOrStringChecked(obj, "address", True) ) - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ # In case the crash address is not available, @@ -218,7 +213,7 @@ def matches(self, crashInfo): class InstructionSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, object]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. 
""" @@ -235,14 +230,12 @@ def __init__(self, obj): "Must provide at least instruction name or register names" ) - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ if crashInfo.crashInstruction is None: @@ -262,7 +255,7 @@ def matches(self, crashInfo): class TestcaseSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, object]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ @@ -271,14 +264,12 @@ def __init__(self, obj): JSONHelper.getObjectOrStringChecked(obj, "value", True) ) - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ @@ -296,7 +287,7 @@ def matches(self, crashInfo): class StackFramesSymptom(Symptom): - def __init__(self, obj): + def __init__(self, obj: dict[str, object]) -> None: """ Private constructor, called by L{Symptom.fromJSONObject}. Do not use directly. """ @@ -308,20 +299,20 @@ def __init__(self, obj): for fn in rawFunctionNames: self.functionNames.append(StringMatch(fn)) - def matches(self, crashInfo): + def matches(self, crashInfo: CrashInfo) -> bool: """ Check if the symptom matches the given crash information - @type crashInfo: CrashInfo @param crashInfo: The crash information to check against - @rtype: bool @return: True if the symptom matches, False otherwise """ return StackFramesSymptom._match(crashInfo.backtrace, self.functionNames) - def diff(self, crashInfo): + def diff( + self, crashInfo: CrashInfo + ) -> tuple[int | None, StackFramesSymptom | None]: if self.matches(crashInfo): return (0, None) @@ -330,6 +321,7 @@ def diff(self, crashInfo): crashInfo.backtrace, self.functionNames, 0, 1, depth ) if bestDepth is not None: + assert bestGuess is not None guessedFunctionNames = [repr(x) for x in bestGuess] # Remove trailing wildcards as they are of no use @@ -353,7 +345,13 @@ def diff(self, crashInfo): return (None, None) @staticmethod - def _diff(stack, signatureGuess, startIdx, depth, maxDepth): + def _diff( + stack: list[str], + signatureGuess: list[StringMatch], + startIdx: int, + depth: int, + maxDepth: int, + ) -> tuple[int | None, list[StringMatch] | None]: singleWildcardMatch = StringMatch("?") newSignatureGuess = [] @@ -448,7 +446,10 @@ def _diff(stack, signatureGuess, startIdx, depth, maxDepth): return (bestDepth, bestGuess) @staticmethod - def _match(partialStack, partialFunctionNames): + def _match( + partialStack: Sequence[StringMatch | str], + partialFunctionNames: Sequence[StringMatch | str], + ) -> bool: while True: @@ -459,6 +460,7 @@ def _match(partialStack, partialFunctionNames): and partialStack and str(partialFunctionNames[0]) not in {"?", "???"} ): + assert isinstance(partialFunctionNames[0], StringMatch) if not partialFunctionNames[0].matches(partialStack[0]): return False diff --git a/FTB/Signatures/tests/test_CrashInfo.py b/FTB/Signatures/tests/test_CrashInfo.py index 89fef24e5..382f781f9 100644 --- a/FTB/Signatures/tests/test_CrashInfo.py +++ b/FTB/Signatures/tests/test_CrashInfo.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations 
+ import json from pathlib import Path @@ -34,7 +37,7 @@ FIXTURE_PATH = Path(__file__).parent / "fixtures" -def test_ASanParserTestAccessViolation(): +def test_ASanParserTestAccessViolation() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfo = ASanCrashInfo( @@ -53,7 +56,7 @@ def test_ASanParserTestAccessViolation(): assert crashInfo.registers["bp"] == 0x00F9915F0A20 -def test_ASanParserTestCrash(): +def test_ASanParserTestCrash() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = ASanCrashInfo( @@ -75,7 +78,7 @@ def test_ASanParserTestCrash(): assert crashInfo.registers["bp"] == 0xFFC57F18 -def test_ASanParserTestCrashWithWarning(): +def test_ASanParserTestCrashWithWarning() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = ASanCrashInfo( @@ -94,7 +97,7 @@ def test_ASanParserTestCrashWithWarning(): assert crashInfo.registers["bp"] == 0xFFC57F18 -def test_ASanParserTestFailedAlloc(): +def test_ASanParserTestFailedAlloc() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -115,7 +118,7 @@ def test_ASanParserTestFailedAlloc(): ) == crashInfo.createShortSignature() -def test_ASanParserTestAllocSize(): +def test_ASanParserTestAllocSize() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -135,7 +138,7 @@ def test_ASanParserTestAllocSize(): ) == crashInfo.createShortSignature() -def test_ASanParserTestHeapCrash(): +def test_ASanParserTestHeapCrash() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = ASanCrashInfo( @@ -153,7 +156,7 @@ def test_ASanParserTestHeapCrash(): assert crashInfo.createShortSignature() == "[@ ??]" -def test_ASanParserTestUAF(): +def test_ASanParserTestUAF() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -171,7 +174,7 @@ def test_ASanParserTestUAF(): ) == crashInfo.createShortSignature() -def test_ASanParserTestInvalidFree(): +def test_ASanParserTestInvalidFree() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -190,7 +193,7 @@ def test_ASanParserTestInvalidFree(): ) == crashInfo.createShortSignature() -def test_ASanParserTestOOM(): +def test_ASanParserTestOOM() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -214,7 +217,7 @@ def test_ASanParserTestOOM(): ) == crashInfo.createShortSignature() -def test_ASanParserTestDebugAssertion(): +def test_ASanParserTestDebugAssertion() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -247,7 +250,9 @@ def test_ASanParserTestDebugAssertion(): (None, "trace_ubsan_generic_crash.txt"), ], ) -def test_ASanDetectionTest(stderr_path, crash_data_path): +def test_ASanDetectionTest( + stderr_path: str | None, crash_data_path: str | None +) -> None: config = ProgramConfiguration("test", "x86", "linux") stderr = "" if stderr_path is None else (FIXTURE_PATH / stderr_path).read_text() crash_data = ( @@ -262,7 +267,7 @@ def test_ASanDetectionTest(stderr_path, crash_data_path): assert isinstance(crashInfo, ASanCrashInfo) -def test_ASanParserTestParamOverlap(): +def test_ASanParserTestParamOverlap() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -293,7 +298,7 @@ def test_ASanParserTestParamOverlap(): ) -def test_ASanParserTestMultiTrace(): +def test_ASanParserTestMultiTrace() -> None: config = ProgramConfiguration("test", 
"x86-64", "linux") crashInfo = ASanCrashInfo( @@ -306,7 +311,7 @@ def test_ASanParserTestMultiTrace(): assert "[@ mozilla::ipc::Shmem::OpenExisting]" == crashInfo.createShortSignature() -def test_ASanParserTestTruncatedTrace(): +def test_ASanParserTestTruncatedTrace() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -320,10 +325,11 @@ def test_ASanParserTestTruncatedTrace(): # Confirm that generating a crash signature will fail crashSig = crashInfo.createCrashSignature() assert crashSig is None + assert crashInfo.failureReason is not None assert "Insufficient data" in crashInfo.failureReason -def test_ASanParserTestClang14(): +def test_ASanParserTestClang14() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = ASanCrashInfo( @@ -339,7 +345,7 @@ def test_ASanParserTestClang14(): assert "[@ raise]" == crashInfo.createShortSignature() -def test_GDBParserTestCrash(): +def test_GDBParserTestCrash() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = GDBCrashInfo( @@ -355,7 +361,7 @@ def test_GDBParserTestCrash(): assert crashInfo.registers["eip"] == 0x818BC33 -def test_GDBParserTestCrashAddress(): +def test_GDBParserTestCrashAddress() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo1 = GDBCrashInfo( @@ -391,7 +397,7 @@ def test_GDBParserTestCrashAddress(): assert crashInfo5.crashAddress == 0x87AFA014 -def test_GDBParserTestCrashAddressSimple(): +def test_GDBParserTestCrashAddressSimple() -> None: registerMap64 = {} registerMap64["rax"] = 0x0 registerMap64["rbx"] = -1 @@ -452,7 +458,7 @@ def test_GDBParserTestCrashAddressSimple(): ) -def test_GDBParserTestRegression1(): +def test_GDBParserTestRegression1() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo1 = GDBCrashInfo( @@ -465,7 +471,7 @@ def test_GDBParserTestRegression1(): assert crashInfo1.backtrace[1] == "js::SetPropertyIgnoringNamedGetter" -def test_GDBParserTestCrashAddressRegression2(): +def test_GDBParserTestCrashAddressRegression2() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo2 = GDBCrashInfo( @@ -477,7 +483,7 @@ def test_GDBParserTestCrashAddressRegression2(): assert crashInfo2.crashAddress == 0xFFFD579C -def test_GDBParserTestCrashAddressRegression3(): +def test_GDBParserTestCrashAddressRegression3() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo3 = GDBCrashInfo( @@ -489,7 +495,7 @@ def test_GDBParserTestCrashAddressRegression3(): assert crashInfo3.crashAddress == 0x7FFFFFFFFFFF -def test_GDBParserTestCrashAddressRegression4(): +def test_GDBParserTestCrashAddressRegression4() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo4 = GDBCrashInfo( @@ -501,7 +507,7 @@ def test_GDBParserTestCrashAddressRegression4(): assert crashInfo4.crashAddress == 0x0 -def test_GDBParserTestCrashAddressRegression5(): +def test_GDBParserTestCrashAddressRegression5() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo5 = GDBCrashInfo( @@ -513,7 +519,7 @@ def test_GDBParserTestCrashAddressRegression5(): assert crashInfo5.crashAddress == 0xFFFD573C -def test_GDBParserTestCrashAddressRegression6(): +def test_GDBParserTestCrashAddressRegression6() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo6 = GDBCrashInfo( @@ -525,7 +531,7 @@ def test_GDBParserTestCrashAddressRegression6(): assert crashInfo6.crashAddress == 0xF7673132 -def test_GDBParserTestCrashAddressRegression7(): +def 
test_GDBParserTestCrashAddressRegression7() -> None: config = ProgramConfiguration("test", "x86", "linux") # This used to fail because CrashInfo.fromRawCrashData fails to detect a GDB trace @@ -539,7 +545,7 @@ def test_GDBParserTestCrashAddressRegression7(): assert crashInfo7.backtrace[1] == "js::ScopeIter::settle" -def test_GDBParserTestCrashAddressRegression8(): +def test_GDBParserTestCrashAddressRegression8() -> None: config = ProgramConfiguration("test", "x86", "linux") # This used to fail because CrashInfo.fromRawCrashData fails to detect a GDB trace @@ -559,7 +565,7 @@ def test_GDBParserTestCrashAddressRegression8(): assert crashInfo8.backtrace[5] == "js::jit::CheckICacheLocked" -def test_GDBParserTestCrashAddressRegression9(): +def test_GDBParserTestCrashAddressRegression9() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo9 = CrashInfo.fromRawCrashData( @@ -571,7 +577,7 @@ def test_GDBParserTestCrashAddressRegression9(): assert crashInfo9.crashInstruction == "call 0x8120ca0" -def test_GDBParserTestCrashAddressRegression10(): +def test_GDBParserTestCrashAddressRegression10() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo10 = CrashInfo.fromRawCrashData( @@ -584,7 +590,7 @@ def test_GDBParserTestCrashAddressRegression10(): assert crashInfo10.crashAddress == 0x7FF7F20C1F81 -def test_GDBParserTestCrashAddressRegression11(): +def test_GDBParserTestCrashAddressRegression11() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo11 = CrashInfo.fromRawCrashData( @@ -597,7 +603,7 @@ def test_GDBParserTestCrashAddressRegression11(): assert crashInfo11.crashAddress == 0x7FF7F2091032 -def test_GDBParserTestCrashAddressRegression12(): +def test_GDBParserTestCrashAddressRegression12() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo12 = CrashInfo.fromRawCrashData( @@ -612,7 +618,7 @@ def test_GDBParserTestCrashAddressRegression12(): assert crashInfo12.backtrace[3] == "CaptureStack" -def test_GDBParserTestCrashAddressRegression13(): +def test_GDBParserTestCrashAddressRegression13() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo13 = CrashInfo.fromRawCrashData( @@ -630,7 +636,7 @@ def test_GDBParserTestCrashAddressRegression13(): assert crashInfo13.crashAddress == 0xE5E5E5F5 -def test_CrashSignatureOutputTest(): +def test_CrashSignatureOutputTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashSignature1 = '{ "symptoms" : [ { "type" : "output", "value" : "test" } ] }' @@ -646,7 +652,7 @@ def test_CrashSignatureOutputTest(): outputSignature1Neg = CrashSignature(crashSignature1Neg) outputSignature2 = CrashSignature(crashSignature2) - gdbOutput = [] + gdbOutput: list[str] = [] stdout = [] stderr = [] @@ -678,7 +684,7 @@ def test_CrashSignatureOutputTest(): assert outputSignature2.matches(crashInfo) -def test_CrashSignatureAddressTest(): +def test_CrashSignatureAddressTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashSignature1 = ( @@ -713,7 +719,7 @@ def test_CrashSignatureAddressTest(): assert not addressSig1Neg.matches(crashInfo3) -def test_CrashSignatureRegisterTest(): +def test_CrashSignatureRegisterTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashSignature1 = {"symptoms": [{"type": "instruction", "registerNames": ["r14"]}]} @@ -783,7 +789,7 @@ def test_CrashSignatureRegisterTest(): assert not instructionSig3.matches(crashInfo3) -def test_CrashSignatureStackFrameTest(): +def 
test_CrashSignatureStackFrameTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashSignature1 = { @@ -834,7 +840,7 @@ def test_CrashSignatureStackFrameTest(): assert not stackFrameSig2Neg.matches(crashInfo1) -def test_CrashSignatureStackSizeTest(): +def test_CrashSignatureStackSizeTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashSignature1 = '{ "symptoms" : [ { "type" : "stackSize", "size" : 8 } ] }' @@ -867,7 +873,7 @@ def test_CrashSignatureStackSizeTest(): assert not stackSizeSig2Neg.matches(crashInfo1) -def test_RegisterHelperValueTest(): +def test_RegisterHelperValueTest() -> None: registerMap = {"rax": 0xFFFFFFFFFFFFFE00, "rbx": 0x7FFFF79A7640} assert RegisterHelper.getRegisterValue("rax", registerMap) == 0xFFFFFFFFFFFFFE00 @@ -883,7 +889,7 @@ def test_RegisterHelperValueTest(): assert RegisterHelper.getRegisterValue("bl", registerMap) == 0x40 -def test_MinidumpParserTestCrash(): +def test_MinidumpParserTestCrash() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = MinidumpCrashInfo( @@ -899,7 +905,7 @@ def test_MinidumpParserTestCrash(): assert crashInfo.crashAddress == 0x3E800006ACB -def test_MinidumpSelectorTest(): +def test_MinidumpSelectorTest() -> None: config = ProgramConfiguration("test", "x86", "linux") crashData = (FIXTURE_PATH / "minidump-example.txt").read_text().splitlines() @@ -908,7 +914,7 @@ def test_MinidumpSelectorTest(): assert crashInfo.crashAddress == 0x3E800006ACB -def test_MinidumpFromMacOSTest(): +def test_MinidumpFromMacOSTest() -> None: config = ProgramConfiguration("test", "x86-64", "macosx") crashInfo = CrashInfo.fromRawCrashData( @@ -925,7 +931,7 @@ def test_MinidumpFromMacOSTest(): assert crashInfo.crashAddress == 0 -def test_AppleParserTestCrash(): +def test_AppleParserTestCrash() -> None: config = ProgramConfiguration("test", "x86-64", "macosx") crashInfo = AppleCrashInfo( @@ -955,7 +961,7 @@ def test_AppleParserTestCrash(): assert crashInfo.crashAddress == 0x00007FFF5F3FFF98 -def test_AppleSelectorTest(): +def test_AppleSelectorTest() -> None: config = ProgramConfiguration("test", "x86-64", "macosx") crashData = ( @@ -966,7 +972,7 @@ def test_AppleSelectorTest(): assert crashInfo.crashAddress == 0x00007FFF5F3FFF98 -def test_AppleLionParserTestCrash(): +def test_AppleLionParserTestCrash() -> None: config = ProgramConfiguration("test", "x86-64", "macosx64") crashInfo = AppleCrashInfo( @@ -1000,7 +1006,7 @@ def test_AppleLionParserTestCrash(): assert crashInfo.crashAddress == 0x0000000000000000 -def test_AppleLionSelectorTest(): +def test_AppleLionSelectorTest() -> None: config = ProgramConfiguration("test", "x86-64", "macosx64") crashData = ( @@ -1015,7 +1021,7 @@ def test_AppleLionSelectorTest(): # failure: # js_dbg_32_dm_windows_62f79d676e0e!js::GetBytecodeLength # 01814577 cc int 3 -def test_CDBParserTestCrash1a(): +def test_CDBParserTestCrash1a() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -1054,7 +1060,7 @@ def test_CDBParserTestCrash1a(): assert crashInfo.crashAddress == 0x01814577 -def test_CDBSelectorTest1a(): +def test_CDBSelectorTest1a() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-1a-crashlog.txt").read_text().splitlines() @@ -1067,7 +1073,7 @@ def test_CDBSelectorTest1a(): # failure: # js_dbg_32_dm_windows_62f79d676e0e!js::GetBytecodeLength+47 # 01344577 cc int 3 -def test_CDBParserTestCrash1b(): +def test_CDBParserTestCrash1b() -> None: config = 
ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -1106,7 +1112,7 @@ def test_CDBParserTestCrash1b(): assert crashInfo.crashAddress == 0x01344577 -def test_CDBSelectorTest1b(): +def test_CDBSelectorTest1b() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-1b-crashlog.txt").read_text().splitlines() @@ -1119,7 +1125,7 @@ def test_CDBSelectorTest1b(): # failure: # js_dbg_64_dm_windows_62f79d676e0e!js::GetBytecodeLength # 00000001`40144e62 cc int 3 -def test_CDBParserTestCrash2a(): +def test_CDBParserTestCrash2a() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfo = CDBCrashInfo( @@ -1178,7 +1184,7 @@ def test_CDBParserTestCrash2a(): assert crashInfo.crashAddress == 0x0000000140144E62 -def test_CDBSelectorTest2a(): +def test_CDBSelectorTest2a() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashData = (FIXTURE_PATH / "cdb-2a-crashlog.txt").read_text().splitlines() @@ -1191,7 +1197,7 @@ def test_CDBSelectorTest2a(): # failure: # js_dbg_64_dm_windows_62f79d676e0e!js::GetBytecodeLength+52 # 00007ff7`1e424e62 cc int 3 -def test_CDBParserTestCrash2b(): +def test_CDBParserTestCrash2b() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfo = CDBCrashInfo( @@ -1250,7 +1256,7 @@ def test_CDBParserTestCrash2b(): assert crashInfo.crashAddress == 0x00007FF71E424E62 -def test_CDBSelectorTest2b(): +def test_CDBSelectorTest2b() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashData = (FIXTURE_PATH / "cdb-2b-crashlog.txt").read_text().splitlines() @@ -1262,7 +1268,7 @@ def test_CDBSelectorTest2b(): # Test 3a is for Win7 with 32-bit js debug deterministic shell crashing: # js_dbg_32_dm_windows_62f79d676e0e!js::gc::TenuredCell::arena # 00f36a63 8b00 mov eax,dword ptr [eax] -def test_CDBParserTestCrash3a(): +def test_CDBParserTestCrash3a() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -1321,7 +1327,7 @@ def test_CDBParserTestCrash3a(): assert crashInfo.crashAddress == 0x00F36A63 -def test_CDBSelectorTest3a(): +def test_CDBSelectorTest3a() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-3a-crashlog.txt").read_text().splitlines() @@ -1333,7 +1339,7 @@ def test_CDBSelectorTest3a(): # Test 3b is for Win10 with 32-bit js debug deterministic shell crashing: # js_dbg_32_dm_windows_62f79d676e0e!js::gc::TenuredCell::arena+13 # 00ed6a63 8b00 mov eax,dword ptr [eax] -def test_CDBParserTestCrash3b(): +def test_CDBParserTestCrash3b() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -1392,7 +1398,7 @@ def test_CDBParserTestCrash3b(): assert crashInfo.crashAddress == 0x00ED6A63 -def test_CDBSelectorTest3b(): +def test_CDBSelectorTest3b() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-3b-crashlog.txt").read_text().splitlines() @@ -1404,7 +1410,7 @@ def test_CDBSelectorTest3b(): # Test 4a is for Win7 with 32-bit js opt deterministic shell crashing: # js_32_dm_windows_62f79d676e0e!JSObject::allocKindForTenure # 00d44c59 8b39 mov edi,dword ptr [ecx] -def test_CDBParserTestCrash4a(): +def test_CDBParserTestCrash4a() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -1481,7 +1487,7 @@ def test_CDBParserTestCrash4a(): assert crashInfo.crashAddress == 0x00D44C59 -def test_CDBSelectorTest4a(): +def test_CDBSelectorTest4a() -> None: 
config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-4a-crashlog.txt").read_text().splitlines() @@ -1493,7 +1499,7 @@ def test_CDBSelectorTest4a(): # Test 4b is for Win10 with 32-bit js opt deterministic shell crashing: # js_32_dm_windows_62f79d676e0e!JSObject::allocKindForTenure+9 # 00404c59 8b39 mov edi,dword ptr [ecx] -def test_CDBParserTestCrash4b(): +def test_CDBParserTestCrash4b() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -1554,7 +1560,7 @@ def test_CDBParserTestCrash4b(): assert crashInfo.crashAddress == 0x00404C59 -def test_CDBSelectorTest4b(): +def test_CDBSelectorTest4b() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-4b-crashlog.txt").read_text().splitlines() @@ -1566,7 +1572,7 @@ def test_CDBSelectorTest4b(): # Test 5a is for Win7 with 64-bit js debug deterministic shell crashing: # js_dbg_64_dm_windows_62f79d676e0e!js::gc::IsInsideNursery # 00000001`3f4975db 8b11 mov edx,dword ptr [rcx] -def test_CDBParserTestCrash5a(): +def test_CDBParserTestCrash5a() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfo = CDBCrashInfo( @@ -1631,7 +1637,7 @@ def test_CDBParserTestCrash5a(): assert crashInfo.crashAddress == 0x000000013F4975DB -def test_CDBSelectorTest5a(): +def test_CDBSelectorTest5a() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashData = (FIXTURE_PATH / "cdb-5a-crashlog.txt").read_text().splitlines() @@ -1643,7 +1649,7 @@ def test_CDBSelectorTest5a(): # Test 5b is for Win10 with 64-bit js debug deterministic shell crashing: # js_dbg_64_dm_windows_62f79d676e0e!js::gc::IsInsideNursery+1b # 00007ff7`1dcf75db 8b11 mov edx,dword ptr [rcx] -def test_CDBParserTestCrash5b(): +def test_CDBParserTestCrash5b() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfo = CDBCrashInfo( @@ -1708,7 +1714,7 @@ def test_CDBParserTestCrash5b(): assert crashInfo.crashAddress == 0x00007FF71DCF75DB -def test_CDBSelectorTest5b(): +def test_CDBSelectorTest5b() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashData = (FIXTURE_PATH / "cdb-5b-crashlog.txt").read_text().splitlines() @@ -1720,7 +1726,7 @@ def test_CDBSelectorTest5b(): # Test 6a is for Win7 with 64-bit js opt deterministic shell crashing: # js_64_dm_windows_62f79d676e0e!JSObject::allocKindForTenure # 00000001`3f869ff3 4c8b01 mov r8,qword ptr [rcx] -def test_CDBParserTestCrash6a(): +def test_CDBParserTestCrash6a() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfo = CDBCrashInfo( @@ -1809,7 +1815,7 @@ def test_CDBParserTestCrash6a(): assert crashInfo.crashAddress == 0x000000013F869FF3 -def test_CDBSelectorTest6a(): +def test_CDBSelectorTest6a() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashData = (FIXTURE_PATH / "cdb-6a-crashlog.txt").read_text().splitlines() @@ -1821,7 +1827,7 @@ def test_CDBSelectorTest6a(): # Test 6b is for Win10 with 64-bit js opt deterministic shell crashing: # js_64_dm_windows_62f79d676e0e!JSObject::allocKindForTenure+13 # 00007ff7`4d469ff3 4c8b01 mov r8,qword ptr [rcx] -def test_CDBParserTestCrash6b(): +def test_CDBParserTestCrash6b() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfo = CDBCrashInfo( @@ -1910,7 +1916,7 @@ def test_CDBParserTestCrash6b(): assert crashInfo.crashAddress == 0x00007FF74D469FF3 -def test_CDBSelectorTest6b(): +def test_CDBSelectorTest6b() -> None: config = 
ProgramConfiguration("test", "x86-64", "windows") crashData = (FIXTURE_PATH / "cdb-6b-crashlog.txt").read_text().splitlines() @@ -1922,7 +1928,7 @@ def test_CDBSelectorTest6b(): # Test 7 is for Windows Server 2012 R2 with 32-bit js debug deterministic shell: # +205 # 25d80b01 cc int 3 -def test_CDBParserTestCrash7(): +def test_CDBParserTestCrash7() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -1991,7 +1997,7 @@ def test_CDBParserTestCrash7(): assert crashInfo.crashAddress == 0x25D80B01 -def test_CDBSelectorTest7(): +def test_CDBSelectorTest7() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-7c-crashlog.txt").read_text().splitlines() @@ -2005,7 +2011,7 @@ def test_CDBSelectorTest7(): # js_dbg_32_prof_dm_windows_42c95d88aaaa!js::jit::Range::upper+3d [ # c:\users\administrator\trees\mozilla-central\js\src\jit\rangeanalysis.h @ 578] # 0142865d cc int 3 -def test_CDBParserTestCrash8(): +def test_CDBParserTestCrash8() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -2029,7 +2035,7 @@ def test_CDBParserTestCrash8(): assert crashInfo.crashAddress == 0x0142865D -def test_CDBSelectorTest8(): +def test_CDBSelectorTest8() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-8c-crashlog.txt").read_text().splitlines() @@ -2041,7 +2047,7 @@ def test_CDBSelectorTest8(): # Test 9 is for Windows Server 2012 R2 with 32-bit js opt profiling shell: # +1d8 # 0f2bb4f3 cc int 3 -def test_CDBParserTestCrash9(): +def test_CDBParserTestCrash9() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -2108,7 +2114,7 @@ def test_CDBParserTestCrash9(): assert crashInfo.crashAddress == 0x0F2BB4F3 -def test_CDBSelectorTest9(): +def test_CDBSelectorTest9() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-9c-crashlog.txt").read_text().splitlines() @@ -2120,7 +2126,7 @@ def test_CDBSelectorTest9(): # Test 10 is for Windows Server 2012 R2 with 32-bit js opt profiling shell: # +82 # 1c2fbbb0 cc int 3 -def test_CDBParserTestCrash10(): +def test_CDBParserTestCrash10() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -2148,7 +2154,7 @@ def test_CDBParserTestCrash10(): assert crashInfo.crashAddress == 0x1C2FBBB0 -def test_CDBSelectorTest10(): +def test_CDBSelectorTest10() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-10c-crashlog.txt").read_text().splitlines() @@ -2162,7 +2168,7 @@ def test_CDBSelectorTest10(): # js_dbg_32_prof_dm_windows_42c95d88aaaa!js::jit::Range::upper+3d [ # c:\users\administrator\trees\mozilla-central\js\src\jit\rangeanalysis.h @ 578] # 0156865d cc int 3 -def test_CDBParserTestCrash11(): +def test_CDBParserTestCrash11() -> None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -2186,7 +2192,7 @@ def test_CDBParserTestCrash11(): assert crashInfo.crashAddress == 0x0156865D -def test_CDBSelectorTest11(): +def test_CDBSelectorTest11() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-11c-crashlog.txt").read_text().splitlines() @@ -2198,7 +2204,7 @@ def test_CDBSelectorTest11(): # Test 12 is for Windows Server 2012 R2 with 32-bit js opt profiling deterministic shell # +1d8 # 1fa0b7f8 cc int 3 -def test_CDBParserTestCrash12(): +def test_CDBParserTestCrash12() -> 
None: config = ProgramConfiguration("test", "x86", "windows") crashInfo = CDBCrashInfo( @@ -2225,7 +2231,7 @@ def test_CDBParserTestCrash12(): assert crashInfo.crashAddress == 0x1FA0B7F8 -def test_CDBSelectorTest12(): +def test_CDBSelectorTest12() -> None: config = ProgramConfiguration("test", "x86", "windows") crashData = (FIXTURE_PATH / "cdb-12c-crashlog.txt").read_text().splitlines() @@ -2234,7 +2240,7 @@ def test_CDBSelectorTest12(): assert crashInfo.crashAddress == 0x1FA0B7F8 -def test_UBSanParserTestCrash1(): +def test_UBSanParserTestCrash1() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = CrashInfo.fromRawCrashData( [], @@ -2255,7 +2261,7 @@ def test_UBSanParserTestCrash1(): assert crashInfo.crashAddress is None -def test_UBSanParserTestCrash2(): +def test_UBSanParserTestCrash2() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], @@ -2272,7 +2278,7 @@ def test_UBSanParserTestCrash2(): assert crashInfo.crashAddress is None -def test_UBSanParserTestCrash3(): +def test_UBSanParserTestCrash3() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], @@ -2285,7 +2291,7 @@ def test_UBSanParserTestCrash3(): assert crashInfo.crashAddress is None -def test_UBSanParserTestCrash4(): +def test_UBSanParserTestCrash4() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], @@ -2304,7 +2310,7 @@ def test_UBSanParserTestCrash4(): assert crashInfo.registers["sp"] == 0x7F0662600680 -def test_RustParserTests1(): +def test_RustParserTests1() -> None: """test RUST_BACKTRACE=1 is parsed correctly""" config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2334,7 +2340,7 @@ def test_RustParserTests1(): assert crashInfo.crashAddress == 0 -def test_RustParserTests2(): +def test_RustParserTests2() -> None: """test RUST_BACKTRACE=full is parsed correctly""" config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2387,7 +2393,7 @@ def test_RustParserTests2(): assert crashInfo.crashAddress == 0 -def test_RustParserTests3(): +def test_RustParserTests3() -> None: """test rust backtraces are weakly found, ie. minidump output wins even if it comes after""" config = ProgramConfiguration("test", "x86-64", "win") @@ -2415,7 +2421,7 @@ def test_RustParserTests3(): assert crashInfo.crashAddress == 0x7FFC41F2F276 -def test_RustParserTests4(): +def test_RustParserTests4() -> None: """test another rust backtrace""" config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2440,7 +2446,7 @@ def test_RustParserTests4(): assert crashInfo.crashAddress == 0 -def test_RustParserTests5(): +def test_RustParserTests5() -> None: """test multi-line with minidump trace in sterror rust backtrace""" auxData = [ "OS|Linux|0.0.0 Linux ... 
x86_64", @@ -2480,7 +2486,7 @@ def test_RustParserTests5(): assert crashInfo.crashAddress == 0 -def test_RustParserTests6(): +def test_RustParserTests6() -> None: """test parsing rust assertion failure backtrace""" config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2499,7 +2505,7 @@ def test_RustParserTests6(): assert crashInfo.crashAddress == 0 -def test_MinidumpModuleInStackTest(): +def test_MinidumpModuleInStackTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2512,7 +2518,7 @@ def test_MinidumpModuleInStackTest(): assert crashInfo.backtrace[1] == "swrast_dri.so+0x470ecc" -def test_LSanParserTestLeakDetected(): +def test_LSanParserTestLeakDetected() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2530,7 +2536,7 @@ def test_LSanParserTestLeakDetected(): assert crashInfo.crashAddress is None -def test_TSanParserSimpleLeakTest(): +def test_TSanParserSimpleLeakTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2552,7 +2558,7 @@ def test_TSanParserSimpleLeakTest(): assert crashInfo.crashAddress is None -def test_TSanParserSimpleRaceTest(): +def test_TSanParserSimpleRaceTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") for fn in ["tsan-simple-race-report.txt", "tsan-simple-race-report-swapped.txt"]: @@ -2575,7 +2581,7 @@ def test_TSanParserSimpleRaceTest(): assert crashInfo.crashAddress is None -def test_TSanParserLockReportTest(): +def test_TSanParserLockReportTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2597,7 +2603,7 @@ def test_TSanParserLockReportTest(): assert crashInfo.crashAddress is None -def test_TSanParserTestCrash(): +def test_TSanParserTestCrash() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2618,7 +2624,7 @@ def test_TSanParserTestCrash(): assert crashInfo.registers["sp"] == 0x7FE1A51BCF00 -def test_TSanParserTest(): +def test_TSanParserTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2640,7 +2646,7 @@ def test_TSanParserTest(): assert crashInfo.crashAddress is None -def test_TSanParserTestClang14(): +def test_TSanParserTestClang14() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -2665,7 +2671,7 @@ def test_TSanParserTestClang14(): ] -def test_ValgrindCJMParser(): +def test_ValgrindCJMParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-cjm-01.txt").read_text().splitlines() @@ -2701,7 +2707,7 @@ def test_ValgrindCJMParser(): assert crashInfo.crashAddress is None -def test_ValgrindIRWParser(): +def test_ValgrindIRWParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-ir-01.txt").read_text().splitlines() @@ -2747,7 +2753,7 @@ def test_ValgrindIRWParser(): assert crashInfo.crashAddress == 0x41414141 -def test_ValgrindUUVParser(): +def test_ValgrindUUVParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-uuv-01.txt").read_text().splitlines() @@ -2768,7 
+2774,7 @@ def test_ValgrindUUVParser(): assert crashInfo.crashAddress is None -def test_ValgrindIFParser(): +def test_ValgrindIFParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-if-01.txt").read_text().splitlines() @@ -2817,7 +2823,7 @@ def test_ValgrindIFParser(): assert crashInfo.crashAddress == 0xBADF00D -def test_ValgrindSDOParser(): +def test_ValgrindSDOParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-sdo-01.txt").read_text().splitlines() @@ -2833,7 +2839,7 @@ def test_ValgrindSDOParser(): assert crashInfo.crashAddress is None -def test_ValgrindSCParser(): +def test_ValgrindSCParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-sc-01.txt").read_text().splitlines() @@ -2866,7 +2872,7 @@ def test_ValgrindSCParser(): assert crashInfo.crashAddress == 0x5E7B6B4 -def test_ValgrindNMParser(): +def test_ValgrindNMParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-nm-01.txt").read_text().splitlines() @@ -2883,7 +2889,7 @@ def test_ValgrindNMParser(): assert crashInfo.crashAddress is None -def test_ValgrindPTParser(): +def test_ValgrindPTParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-pt-01.txt").read_text().splitlines() @@ -2900,7 +2906,7 @@ def test_ValgrindPTParser(): assert crashInfo.crashAddress is None -def test_ValgrindLeakParser(): +def test_ValgrindLeakParser() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], [], config, (FIXTURE_PATH / "valgrind-leak-01.txt").read_text().splitlines() @@ -2930,7 +2936,7 @@ def test_ValgrindLeakParser(): assert crashInfo.crashAddress is None -def test_SanitizerSoftRssLimitHeapProfile(): +def test_SanitizerSoftRssLimitHeapProfile() -> None: """test that heap profile given after soft rss limit is exceeded is used in place of the (useless) SEGV stack""" config = ProgramConfiguration("test", "x86-64", "linux") @@ -2952,7 +2958,7 @@ def test_SanitizerSoftRssLimitHeapProfile(): assert crashInfo.crashAddress == 40 -def test_SanitizerHardRssLimitHeapProfile(): +def test_SanitizerHardRssLimitHeapProfile() -> None: """test that heap profile given after hard rss limit is exceeded is used in place of the (useless) SEGV stack""" config = ProgramConfiguration("test", "x86-64", "linux") diff --git a/FTB/Signatures/tests/test_CrashSignature.py b/FTB/Signatures/tests/test_CrashSignature.py index 6c0dc25cb..70b81442f 100644 --- a/FTB/Signatures/tests/test_CrashSignature.py +++ b/FTB/Signatures/tests/test_CrashSignature.py @@ -3,6 +3,9 @@ @author: decoder """ + +from __future__ import annotations + import json from pathlib import Path @@ -15,7 +18,7 @@ FIXTURE_PATH = Path(__file__).parent / "fixtures" -def test_SignatureCreateTest(): +def test_SignatureCreateTest() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -34,6 +37,9 @@ def test_SignatureCreateTest(): crashSig3 = crashInfo.createCrashSignature( forceCrashInstruction=True, maxFrames=2, minimumSupportedVersion=10 ) + assert crashSig1 is not None + assert crashSig2 is not None 
+ assert crashSig3 is not None # Check that all generated signatures match their originating crashInfo assert crashSig1.matches(crashInfo) @@ -52,7 +58,7 @@ def test_SignatureCreateTest(): assert json.loads(str(crashSig3)) == json.load(f) -def test_SignatureTestCaseMatchTest(): +def test_SignatureTestCaseMatchTest() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -87,7 +93,7 @@ def test_SignatureTestCaseMatchTest(): assert not testSig6.matches(crashInfo) -def test_SignatureStackFramesTest(): +def test_SignatureStackFramesTest() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -120,7 +126,7 @@ def test_SignatureStackFramesTest(): assert not testSig5.matches(crashInfo) -def test_SignatureStackFramesAlgorithmsTest(): +def test_SignatureStackFramesAlgorithmsTest() -> None: # Do some direct matcher tests on edge cases assert StackFramesSymptom._match([], [StringMatch("???")]) assert not StackFramesSymptom._match([], [StringMatch("???"), StringMatch("a")]) @@ -154,10 +160,11 @@ def test_SignatureStackFramesAlgorithmsTest(): stack, [StringMatch(x) for x in rawSig], 0, 1, maxDepth ) assert expectedDepth == actualDepth + assert actualSig is not None assert expectedSig == [str(x) for x in actualSig] -def test_SignaturePCREShortTest(): +def test_SignaturePCREShortTest() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -174,7 +181,7 @@ def test_SignaturePCREShortTest(): assert not testSig2.matches(crashInfo) -def test_SignatureStackFramesWildcardTailTest(): +def test_SignatureStackFramesWildcardTailTest() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfo = CrashInfo.fromRawCrashData( @@ -185,6 +192,7 @@ def test_SignatureStackFramesWildcardTailTest(): ) testSig = crashInfo.createCrashSignature() + assert testSig is not None # Ensure that the last frame with a symbol is at the right place and there is # nothing else, especially no wildcard, following afterwards. 
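For reference, a minimal sketch (illustrative only, not part of the patch) of the stack-frame wildcard semantics these tests exercise, assuming the usual FuzzManager convention that "?" optionally skips a single frame while "???" matches any number of frames; the first two assertions are taken from test_SignatureStackFramesAlgorithmsTest above, the third uses hypothetical frame names:

    from FTB.Signatures.Matchers import StringMatch
    from FTB.Signatures.Symptom import StackFramesSymptom

    # "???" matches any number of frames, including none, so it matches an
    # empty stack; a concrete name after it still needs a real frame.
    assert StackFramesSymptom._match([], [StringMatch("???")])
    assert not StackFramesSymptom._match([], [StringMatch("???"), StringMatch("a")])

    # Hypothetical deeper stack: the signature only has to match the top
    # frames, which is why trailing wildcards add nothing and are trimmed.
    assert StackFramesSymptom._match(
        ["js::RunScript", "js::Interpret"],
        [StringMatch("js::RunScript"), StringMatch("???")],
    )
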
@@ -196,7 +204,7 @@ def test_SignatureStackFramesWildcardTailTest(): assert len(testSig.symptoms[0].functionNames) == 7 -def test_SignatureStackFramesRegressionTest(): +def test_SignatureStackFramesRegressionTest() -> None: config = ProgramConfiguration("test", "x86", "linux") crashInfoNeg = CrashInfo.fromRawCrashData( [], @@ -223,7 +231,7 @@ def test_SignatureStackFramesRegressionTest(): assert not testSigEmptyCrashAddress.matches(crashInfoNeg) -def test_SignatureStackFramesAuxMessagesTest(): +def test_SignatureStackFramesAuxMessagesTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfoPos = CrashInfo.fromRawCrashData( [], @@ -244,6 +252,8 @@ def test_SignatureStackFramesAuxMessagesTest(): crashSignaturePos = crashInfoPos.createCrashSignature() crashSignatureNeg = crashInfoNeg.createCrashSignature() + assert crashSignaturePos is not None + assert crashSignatureNeg is not None # Check that the first crash signature has ASan symptoms but # the second does not because it has a program abort message @@ -259,7 +269,7 @@ def test_SignatureStackFramesAuxMessagesTest(): assert crashSignatureNeg.matches(crashInfoNeg) -def test_SignatureStackFramesNegativeSizeParamTest(): +def test_SignatureStackFramesNegativeSizeParamTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfoPos = CrashInfo.fromRawCrashData( [], @@ -271,13 +281,14 @@ def test_SignatureStackFramesNegativeSizeParamTest(): ) testSig = crashInfoPos.createCrashSignature() + assert testSig is not None assert "/ERROR: AddressSanitizer" in str(testSig) assert "negative-size-param" in str(testSig) assert isinstance(testSig.symptoms[1], StackFramesSymptom) -def test_SignatureAsanStackOverflowTest(): +def test_SignatureAsanStackOverflowTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfoPos = CrashInfo.fromRawCrashData( [], @@ -289,12 +300,13 @@ def test_SignatureAsanStackOverflowTest(): ) testSig = crashInfoPos.createCrashSignature() + assert testSig is not None # Check matches appropriately assert testSig.matches(crashInfoPos) -def test_SignatureAsanAccessViolationTest(): +def test_SignatureAsanAccessViolationTest() -> None: config = ProgramConfiguration("test", "x86-64", "windows") crashInfoPos = CrashInfo.fromRawCrashData( [], @@ -306,13 +318,14 @@ def test_SignatureAsanAccessViolationTest(): ) testSig = crashInfoPos.createCrashSignature() + assert testSig is not None assert "/ERROR: AddressSanitizer" not in str(testSig) assert "access-violation" not in str(testSig) assert isinstance(testSig.symptoms[0], StackFramesSymptom) -def test_SignatureStackSizeTest(): +def test_SignatureStackSizeTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfoPos = CrashInfo.fromRawCrashData( [], @@ -327,7 +340,7 @@ def test_SignatureStackSizeTest(): assert testSig.matches(crashInfoPos) -def test_SignatureAsanFailedAllocTest(): +def test_SignatureAsanFailedAllocTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfoPos = CrashInfo.fromRawCrashData( [], @@ -339,12 +352,13 @@ def test_SignatureAsanFailedAllocTest(): ) testSig = crashInfoPos.createCrashSignature() + assert testSig is not None assert "/AddressSanitizer failed to allocate" in str(testSig) assert testSig.matches(crashInfoPos) assert isinstance(testSig.symptoms[1], StackFramesSymptom) -def test_SignatureGenerationTSanLeakTest(): +def test_SignatureGenerationTSanLeakTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = 
CrashInfo.fromRawCrashData( [], @@ -355,6 +369,7 @@ def test_SignatureGenerationTSanLeakTest(): .splitlines(), ) testSignature = crashInfo.createCrashSignature() + assert testSignature is not None assert testSignature.matches(crashInfo) @@ -367,7 +382,7 @@ def test_SignatureGenerationTSanLeakTest(): assert found, "Expected correct OutputSymptom in signature" -def test_SignatureGenerationTSanRaceTest(): +def test_SignatureGenerationTSanRaceTest() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], @@ -378,6 +393,7 @@ def test_SignatureGenerationTSanRaceTest(): .splitlines(), ) testSignature = crashInfo.createCrashSignature() + assert testSignature is not None print(testSignature) @@ -410,7 +426,7 @@ def test_SignatureGenerationTSanRaceTest(): assert found, f"Couldn't find OutputSymptom with value '{stringMatchVal}'" -def test_SignatureGenerationTSanRaceTestComplex1(): +def test_SignatureGenerationTSanRaceTestComplex1() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], @@ -419,6 +435,7 @@ def test_SignatureGenerationTSanRaceTestComplex1(): auxCrashData=(FIXTURE_PATH / "tsan-report2.txt").read_text().splitlines(), ) testSignature = crashInfo.createCrashSignature() + assert testSignature is not None print(testSignature) @@ -451,7 +468,7 @@ def test_SignatureGenerationTSanRaceTestComplex1(): assert found, f"Couldn't find OutputSymptom with value '{stringMatchVal}'" -def test_SignatureGenerationTSanRaceTestAtomic(): +def test_SignatureGenerationTSanRaceTestAtomic() -> None: config = ProgramConfiguration("test", "x86-64", "linux") for fn in ["tsan-report-atomic.txt", "tsan-report-atomic-swapped.txt"]: crashInfo = CrashInfo.fromRawCrashData( @@ -465,6 +482,7 @@ def test_SignatureGenerationTSanRaceTestAtomic(): ) testSignature = crashInfo.createCrashSignature() + assert testSignature is not None assert testSignature.matches(crashInfo) @@ -495,7 +513,7 @@ def test_SignatureGenerationTSanRaceTestAtomic(): assert found, f"Couldn't find OutputSymptom with value '{stringMatchVal}'" -def test_SignatureMatchWithUnicode(): +def test_SignatureMatchWithUnicode() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( ["(«f => (generator.throw(f))», «undefined»)"], [], config @@ -506,7 +524,7 @@ def test_SignatureMatchWithUnicode(): assert not testSignature.matches(crashInfo) -def test_SignatureMatchAssertionSlashes(): +def test_SignatureMatchAssertionSlashes() -> None: # test that a forward slash assertion signature matches a backwards slash crash, but # only on windows cfg_linux = ProgramConfiguration("test", "x86-64", "linux") @@ -530,6 +548,7 @@ def test_SignatureMatchAssertionSlashes(): # test that signature generated from linux assertion matches both linux_sig = fs_linux.createCrashSignature() + assert linux_sig is not None assert linux_sig.matches(fs_linux) assert not linux_sig.matches(bs_linux) # this is invalid and should not match assert linux_sig.matches(fs_windows) @@ -537,13 +556,14 @@ def test_SignatureMatchAssertionSlashes(): # test that signature generated from windows assertion matches both windows_sig = bs_windows.createCrashSignature() + assert windows_sig is not None assert windows_sig.matches(fs_linux) assert not windows_sig.matches(bs_linux) # this is invalid and should not match assert windows_sig.matches(fs_windows) assert windows_sig.matches(bs_windows) -def test_SignatureSanitizerSoftRssLimitHeapProfile(): +def 
test_SignatureSanitizerSoftRssLimitHeapProfile() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], @@ -558,7 +578,7 @@ def test_SignatureSanitizerSoftRssLimitHeapProfile(): assert isinstance(testSig.symptoms[0], StackFramesSymptom) -def test_SignatureSanitizerHardRssLimitHeapProfile(): +def test_SignatureSanitizerHardRssLimitHeapProfile() -> None: config = ProgramConfiguration("test", "x86-64", "linux") crashInfo = CrashInfo.fromRawCrashData( [], diff --git a/FTB/tests/test_AssertionHelper.py b/FTB/tests/test_AssertionHelper.py index 4c1ef4ce6..58d38505e 100644 --- a/FTB/tests/test_AssertionHelper.py +++ b/FTB/tests/test_AssertionHelper.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import re from pathlib import Path @@ -19,7 +22,9 @@ FIXTURE_PATH = Path(__file__).parent / "fixtures" -def _check_regex_matches(error_lines, sanitized_message): +def _check_regex_matches( + error_lines: list[str] | str, sanitized_message: list[str] | str +) -> None: if isinstance(sanitized_message, (str, bytes)): sanitized_message = [sanitized_message] else: @@ -38,14 +43,14 @@ def _check_regex_matches(error_lines, sanitized_message): ) -def test_AssertionHelperTestASanFFAbort(): +def test_AssertionHelperTestASanFFAbort() -> None: err = (FIXTURE_PATH / "assert_asan_ff_abort.txt").read_text().splitlines() assert AssertionHelper.getAssertion(err) is None assert AssertionHelper.getAuxiliaryAbortMessage(err) is None -def test_AssertionHelperTestASanNegativeSize(): +def test_AssertionHelperTestASanNegativeSize() -> None: err = (FIXTURE_PATH / "assert_asan_negative_size.txt").read_text().splitlines() assert AssertionHelper.getAssertion(err) is None @@ -58,7 +63,7 @@ def test_AssertionHelperTestASanNegativeSize(): assert assertMsg == expectedAssertMsg -def test_AssertionHelperTestASanStackOverflow(): +def test_AssertionHelperTestASanStackOverflow() -> None: err = (FIXTURE_PATH / "assert_asan_stack_overflow.txt").read_text().splitlines() assert AssertionHelper.getAssertion(err) is None @@ -67,7 +72,7 @@ def test_AssertionHelperTestASanStackOverflow(): assert assertMsg == expectedAssertMsg -def test_AssertionHelperTestMozCrash(): +def test_AssertionHelperTestMozCrash() -> None: err = (FIXTURE_PATH / "assert_jsshell_moz_crash.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( @@ -81,7 +86,7 @@ def test_AssertionHelperTestMozCrash(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestMozCrashMultiLine(): +def test_AssertionHelperTestMozCrashMultiLine() -> None: err = (FIXTURE_PATH / "assert_moz_crash_multiline.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( @@ -98,7 +103,7 @@ def test_AssertionHelperTestMozCrashMultiLine(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestMozCrashWithPath(): +def test_AssertionHelperTestMozCrashWithPath() -> None: err = (FIXTURE_PATH / "assert_moz_crash_with_path.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( @@ -112,7 +117,7 @@ def test_AssertionHelperTestMozCrashWithPath(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestMultiMozCrash(): +def test_AssertionHelperTestMultiMozCrash() -> None: err = (FIXTURE_PATH / "assert_moz_crash_multi.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( @@ -126,7 +131,7 @@ def test_AssertionHelperTestMultiMozCrash(): 
_check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestJSSelfHosted(): +def test_AssertionHelperTestJSSelfHosted() -> None: err = ( (FIXTURE_PATH / "assert_jsshell_self_hosted_assert.txt") .read_text() @@ -145,7 +150,7 @@ def test_AssertionHelperTestJSSelfHosted(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestV8Abort(): +def test_AssertionHelperTestV8Abort() -> None: err = (FIXTURE_PATH / "assert_v8_abort.txt").read_text().splitlines() sanitizedMsgs = AssertionHelper.getSanitizedAssertionPattern( @@ -166,7 +171,7 @@ def test_AssertionHelperTestV8Abort(): _check_regex_matches(err, sanitizedMsgs) -def test_AssertionHelperTestChakraAssert(): +def test_AssertionHelperTestChakraAssert() -> None: err = (FIXTURE_PATH / "assert_chakra_assert.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( @@ -181,7 +186,7 @@ def test_AssertionHelperTestChakraAssert(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestWindowsPathSanitizing(): +def test_AssertionHelperTestWindowsPathSanitizing() -> None: err1 = ( (FIXTURE_PATH / "assert_windows_forward_slash_path.txt") .read_text() @@ -216,7 +221,7 @@ def test_AssertionHelperTestWindowsPathSanitizing(): # _check_regex_matches(err2, sanitizedMsg2) -def test_AssertionHelperTestAuxiliaryAbortASan(): +def test_AssertionHelperTestAuxiliaryAbortASan() -> None: err = ( (FIXTURE_PATH / "assert_asan_heap_buffer_overflow.txt").read_text().splitlines() ) @@ -233,7 +238,7 @@ def test_AssertionHelperTestAuxiliaryAbortASan(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestCPPUnhandledException(): +def test_AssertionHelperTestCPPUnhandledException() -> None: err = (FIXTURE_PATH / "assert_cpp_unhandled_exception.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( @@ -245,7 +250,7 @@ def test_AssertionHelperTestCPPUnhandledException(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestRustPanic01(): +def test_AssertionHelperTestRustPanic01() -> None: err = (FIXTURE_PATH / "assert_rust_panic1.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( AssertionHelper.getAssertion(err) @@ -260,7 +265,7 @@ def test_AssertionHelperTestRustPanic01(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestRustPanic02(): +def test_AssertionHelperTestRustPanic02() -> None: err = (FIXTURE_PATH / "assert_rust_panic2.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( AssertionHelper.getAssertion(err) @@ -274,7 +279,7 @@ def test_AssertionHelperTestRustPanic02(): _check_regex_matches(err, sanitizedMsg) -def test_AssertionHelperTestRustPanic03(): +def test_AssertionHelperTestRustPanic03() -> None: err = (FIXTURE_PATH / "assert_rust_panic3.txt").read_text().splitlines() sanitizedMsg = AssertionHelper.getSanitizedAssertionPattern( AssertionHelper.getAssertion(err) diff --git a/FTB/tests/test_CoverageHelper.py b/FTB/tests/test_CoverageHelper.py index 4a273afb6..d63f5a702 100644 --- a/FTB/tests/test_CoverageHelper.py +++ b/FTB/tests/test_CoverageHelper.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import json from FTB import CoverageHelper @@ -146,7 +149,7 @@ """ -def test_CoverageHelperFlattenNames(): +def test_CoverageHelperFlattenNames() -> None: node = json.loads(covdata) result = CoverageHelper.get_flattened_names(node, prefix="") @@ -165,7 +168,7 @@ def 
test_CoverageHelperFlattenNames(): assert result == set(expected_names) -def test_CoverageHelperApplyDirectivesMixed(): +def test_CoverageHelperApplyDirectivesMixed() -> None: node = json.loads(covdata) # Check that mixed directives work properly (exclude multiple paths, include some @@ -193,7 +196,7 @@ def test_CoverageHelperApplyDirectivesMixed(): assert result == set(expected_names) -def test_CoverageHelperApplyDirectivesPrune(): +def test_CoverageHelperApplyDirectivesPrune() -> None: node = json.loads(covdata) # Check that any empty childs are pruned (empty childs are not useful) @@ -208,7 +211,7 @@ def test_CoverageHelperApplyDirectivesPrune(): assert result == set(expected_names) -def test_CoverageHelperApplyDirectivesExcludeAll(): +def test_CoverageHelperApplyDirectivesExcludeAll() -> None: node = json.loads(covdata) # Check that excluding all paths works (specialized case) @@ -223,7 +226,7 @@ def test_CoverageHelperApplyDirectivesExcludeAll(): assert result == set(expected_names) -def test_CoverageHelperApplyDirectivesMakeEmpty(): +def test_CoverageHelperApplyDirectivesMakeEmpty() -> None: node = json.loads(covdata) # Check that making the set entirely empty doesn't crash things (tsmith mode) @@ -233,6 +236,6 @@ def test_CoverageHelperApplyDirectivesMakeEmpty(): result = CoverageHelper.get_flattened_names(node, prefix="") - expected_names = [] + expected_names: list[str] = [] assert result == set(expected_names) diff --git a/Reporter/Reporter.py b/Reporter/Reporter.py index af5d4fc4e..f42495412 100644 --- a/Reporter/Reporter.py +++ b/Reporter/Reporter.py @@ -11,12 +11,17 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import functools import logging import os import platform import time from abc import ABC +from collections.abc import Callable +from typing import Any, TypeVar import requests import requests.exceptions @@ -24,13 +29,14 @@ from FTB.ConfigurationFiles import ConfigurationFiles LOG = logging.getLogger(__name__) +RetType = TypeVar("RetType") -def remote_checks(wrapped): +def remote_checks(wrapped: Callable[..., RetType]) -> Callable[..., RetType]: """Decorator to perform error checks before using remote features""" @functools.wraps(wrapped) - def decorator(self, *args, **kwargs): + def decorator(self: Reporter, *args: str, **kwargs: str) -> RetType: if not self.serverHost: raise RuntimeError( "Must specify serverHost (configuration property: serverhost) to use " @@ -51,11 +57,11 @@ def decorator(self, *args, **kwargs): return decorator -def signature_checks(wrapped): +def signature_checks(wrapped: Callable[..., RetType]) -> Callable[..., RetType]: """Decorator to perform error checks before using signature features""" @functools.wraps(wrapped) - def decorator(self, *args, **kwargs): + def decorator(self: Reporter, *args: str, **kwargs: str) -> RetType: if not self.sigCacheDir: raise RuntimeError( "Must specify sigCacheDir (configuration property: sigdir) to use " @@ -66,12 +72,12 @@ def decorator(self, *args, **kwargs): return decorator -def requests_retry(wrapped): +def requests_retry(wrapped: Callable[..., Any]) -> Callable[..., Any]: """Wrapper around requests methods that retries up to 2 minutes if it's likely that the response codes indicate a temporary error""" @functools.wraps(wrapped) - def wrapper(*args, **kwds): + def wrapper(*args: str, **kwds: Any) -> Any: success = kwds.pop("expected") current_timeout = 2 while True: @@ -107,30 +113,24 @@ def wrapper(*args, **kwds): class Reporter(ABC): def __init__( self, - sigCacheDir=None, - 
serverHost=None, - serverPort=None, - serverProtocol=None, - serverAuthToken=None, - clientId=None, - tool=None, + sigCacheDir: str | None = None, + serverHost: str | None = None, + serverPort: int | None = None, + serverProtocol: str | None = None, + serverAuthToken: str | None = None, + clientId: str | None = None, + tool: str | None = None, ): """ Initialize the Reporter. This constructor will also attempt to read a configuration file to populate any missing properties that have not been passed to this constructor. - @type sigCacheDir: string @param sigCacheDir: Directory to be used for caching signatures - @type serverHost: string @param serverHost: Server host to contact for refreshing signatures - @type serverPort: int @param serverPort: Server port to use when contacting server - @type serverAuthToken: string @param serverAuthToken: Token for server authentication - @type clientId: string @param clientId: Client ID stored in the server when submitting issues - @type tool: string @param tool: Name of the tool that found this issue """ self.sigCacheDir = sigCacheDir @@ -192,7 +192,7 @@ def __init__( if self.serverHost is not None and self.clientId is None: self.clientId = platform.node() - def get(self, *args, **kwds): + def get(self, *args: Any, **kwds: Any) -> Any: """requests.get, with added support for FuzzManager authentication and retry on 5xx errors. @@ -206,7 +206,7 @@ def get(self, *args, **kwds): ) return requests_retry(self._session.get)(*args, **kwds) - def post(self, *args, **kwds): + def post(self, *args: Any, **kwds: Any) -> Any: """requests.post, with added support for FuzzManager authentication and retry on 5xx errors. @@ -220,7 +220,7 @@ def post(self, *args, **kwds): ) return requests_retry(self._session.post)(*args, **kwds) - def patch(self, *args, **kwds): + def patch(self, *args: Any, **kwds: Any) -> Any: """requests.patch, with added support for FuzzManager authentication and retry on 5xx errors. @@ -235,7 +235,7 @@ def patch(self, *args, **kwds): return requests_retry(self._session.patch)(*args, **kwds) @staticmethod - def serverError(response): + def serverError(response: requests.Response) -> RuntimeError: return RuntimeError( "Server unexpectedly responded with status code %s: %s" % (response.status_code, response.text) diff --git a/TaskStatusReporter/TaskStatusReporter.py b/TaskStatusReporter/TaskStatusReporter.py index ae6775e91..10296e5b8 100755 --- a/TaskStatusReporter/TaskStatusReporter.py +++ b/TaskStatusReporter/TaskStatusReporter.py @@ -15,12 +15,16 @@ @contact: jschwartzentruber@mozilla.com """ + +from __future__ import annotations + import argparse import functools import os import random import sys import time +from typing import Any import requests from fasteners import InterProcessLock @@ -28,7 +32,7 @@ from FTB.ConfigurationFiles import ConfigurationFiles # noqa from Reporter.Reporter import Reporter, remote_checks -__all__ = [] +__all__: list[str] = [] __version__ = 0.1 __date__ = "2014-10-01" __updated__ = "2014-10-01" @@ -36,18 +40,17 @@ class TaskStatusReporter(Reporter): @functools.wraps(Reporter.__init__) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.setdefault( "tool", "N/A" ) # tool is required by remote_checks, but unused by TaskStatusReporter super().__init__(*args, **kwargs) @remote_checks - def report(self, text): + def report(self, text: str) -> None: """ Send textual report to server, overwriting any existing reports. 
- @type text: string @param text: Report text to send """ url = "{}://{}:{}/taskmanager/rest/tasks/update_status/".format( @@ -64,7 +67,7 @@ def report(self, text): self.post(url, data, expected=requests.codes["ok"]) -def main(argv=None): +def main(argv: list[str] | None = None) -> int: """Command line options.""" # setup argparser diff --git a/TaskStatusReporter/tests/test_TaskStatusReporter.py b/TaskStatusReporter/tests/test_TaskStatusReporter.py index 954634152..dc2e34a35 100644 --- a/TaskStatusReporter/tests/test_TaskStatusReporter.py +++ b/TaskStatusReporter/tests/test_TaskStatusReporter.py @@ -1,7 +1,12 @@ +from __future__ import annotations + +from pathlib import Path from unittest.mock import Mock, patch from urllib.parse import urlsplit import pytest +from django.contrib.auth.models import User +from pytest_django.live_server_helper import LiveServer from taskmanager.models import Task from taskmanager.tests import create_pool, create_task @@ -11,7 +16,7 @@ pytest_plugins = "server.tests" -def test_taskstatusreporter_help(capsys): +def test_taskstatusreporter_help(capsys: pytest.CaptureFixture[str]) -> None: """Test that help prints without throwing""" with pytest.raises(SystemExit): main() @@ -22,7 +27,9 @@ def test_taskstatusreporter_help(capsys): # @pytest.mark.skipif(str is bytes, reason="TaskManager requires python3") @patch("os.path.expanduser") @patch("time.sleep", new=Mock()) -def test_taskstatusreporter_report(mock_expanduser, live_server, tmp_path, fm_user): +def test_taskstatusreporter_report( + mock_expanduser: Mock, live_server: LiveServer, tmp_path: Path, fm_user: User +) -> None: """Test report submission""" mock_expanduser.side_effect = lambda path: str( tmp_path diff --git a/conftest.py b/conftest.py index 52c8697c2..a00a7971f 100644 --- a/conftest.py +++ b/conftest.py @@ -11,8 +11,10 @@ import sys from pathlib import Path +from _pytest.config import Config -def pytest_ignore_collect(path, config): + +def pytest_ignore_collect(path: str, config: Config) -> bool: # Django 4.1 requires 3.8 # 3.11 causes an ImportError in vine (via celery) if sys.version_info < (3, 8) or sys.version_info >= (3, 11): diff --git a/misc/afl-libfuzzer/S3Manager.py b/misc/afl_libfuzzer/S3Manager.py similarity index 91% rename from misc/afl-libfuzzer/S3Manager.py rename to misc/afl_libfuzzer/S3Manager.py index e33e76930..683d299e6 100644 --- a/misc/afl-libfuzzer/S3Manager.py +++ b/misc/afl_libfuzzer/S3Manager.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import hashlib import os import platform @@ -29,16 +32,15 @@ class S3Manager: def __init__( - self, bucket_name, project_name, build_project_name=None, zip_name="build.zip" + self, + bucket_name: str, + project_name: str, + build_project_name: str | None = None, + zip_name: str = "build.zip", ): """ - @type bucket_name: String @param bucket_name: Name of the S3 bucket to use - - @type project_name: String @param project_name: Name of the project folder inside the S3 bucket - - @type cmdline_file: String @param cmdline_file: Path to the cmdline file to upload. """ self.bucket_name = bucket_name @@ -62,22 +64,19 @@ def __init__( # Memorize which files we have uploaded/downloaded before, so we never attempt # to re-upload them to a different queue or re-download them after a local # merge. 
- self.uploaded_files = set() - self.downloaded_files = set() + self.uploaded_files: set[str] = set() + self.downloaded_files: set[str] = set() - def upload_libfuzzer_queue_dir(self, base_dir, corpus_dir, original_corpus): + def upload_libfuzzer_queue_dir( + self, base_dir: str, corpus_dir: str, original_corpus: list[str] + ) -> None: """ Synchronize the corpus directory of the specified libFuzzer corpus directory to the specified S3 bucket. This method only uploads files that don't exist yet on the receiving side and excludes all files in the original corpus. - @type base_dir: String @param base_dir: libFuzzer base directory - - @type corpus_dir: String @param corpus_dir: libFuzzer corpus directory - - @type original_corpus: Set @param original_corpus: Set of original corpus files to exclude from synchronization """ @@ -96,12 +95,11 @@ def upload_libfuzzer_queue_dir(self, base_dir, corpus_dir, original_corpus): corpus_dir, upload_files, base_dir, cmdline_file ) - def download_libfuzzer_queues(self, corpus_dir): + def download_libfuzzer_queues(self, corpus_dir: str) -> None: """ Synchronize files from open libFuzzer queues directly back into the local corpus directory. - @type corpus_dir: String @param corpus_dir: libFuzzer corpus directory """ remote_keys = list(self.bucket.list(self.remote_path_queues)) @@ -142,16 +140,13 @@ def download_libfuzzer_queues(self, corpus_dir): self.downloaded_files.add(basename) - def upload_afl_queue_dir(self, base_dir, new_cov_only=True): + def upload_afl_queue_dir(self, base_dir: str, new_cov_only: bool = True) -> None: """ Synchronize the queue directory of the specified AFL base directory to the specified S3 bucket. This method only uploads files that don't exist yet on the receiving side. - @type base_dir: String @param base_dir: AFL base directory - - @type new_cov_only: Boolean @param new_cov_only: Only upload files that have new coverage """ queue_dir = os.path.join(base_dir, "queue") @@ -177,7 +172,7 @@ def upload_afl_queue_dir(self, base_dir, new_cov_only=True): cmdline_file = os.path.join(base_dir, "cmdline") return self.__upload_queue_files(queue_dir, queue_files, base_dir, cmdline_file) - def download_queue_dirs(self, work_dir): + def download_queue_dirs(self, work_dir: str) -> None: """ Downloads all queue files into the queues sub directory of the specified local work directory. The files are renamed to match their SHA1 hashes @@ -185,7 +180,6 @@ def download_queue_dirs(self, work_dir): This method marks all remote queues that have been downloaded as closed. - @type work_dir: String @param work_dir: Local work directory """ download_dir = os.path.join(work_dir, "queues") @@ -242,7 +236,7 @@ def download_queue_dirs(self, work_dir): os.rename(tmp_file, os.path.join(download_dir, hash_name)) - def clean_queue_dirs(self): + def clean_queue_dirs(self) -> None: """ Delete all closed remote queues. 
""" @@ -270,11 +264,10 @@ def clean_queue_dirs(self): self.bucket.delete_keys(remote_keys_for_deletion, quiet=True) - def get_queue_status(self): + def get_queue_status(self) -> dict[str, int]: """ Return status data for all queues in the specified S3 bucket/project - @rtype: dict @return: Dictionary containing queue size per queue """ remote_keys = list(self.bucket.list(self.remote_path_queues)) @@ -282,7 +275,7 @@ def get_queue_status(self): x.name.rsplit("/", 1)[0] for x in remote_keys if x.name.endswith("/closed") ] - status_data = {} + status_data: dict[str, int] = {} for remote_key in remote_keys: # Ignore any folders @@ -306,22 +299,15 @@ def get_queue_status(self): return status_data - def get_corpus_status(self): + def get_corpus_status(self) -> dict[str, int]: """ Return status data for the corpus of the specified S3 bucket/project - @type bucket_name: String - @param bucket_name: Name of the S3 bucket to use - - @type project_name: String - @param project_name: Name of the project folder inside the S3 bucket - - @rtype: dict @return: Dictionary containing corpus size per date modified """ remote_keys = list(self.bucket.list(self.remote_path_corpus)) - status_data = {} + status_data: dict[str, int] = {} for remote_key in remote_keys: # Ignore any folders @@ -338,18 +324,13 @@ def get_corpus_status(self): return status_data - def download_build(self, build_dir): + def download_build(self, build_dir: str) -> None: """ Downloads build.zip from the specified S3 bucket and unpacks it into the specified build directory. - @type base_dir: String @param base_dir: Build directory - - @type bucket_name: String @param bucket_name: Name of the S3 bucket to use - - @type project_name: String @param project_name: Name of the project folder inside the S3 bucket """ # Clear any previous builds @@ -366,33 +347,31 @@ def download_build(self, build_dir): subprocess.check_call(["unzip", zip_dest, "-d", build_dir]) - def upload_build(self, build_file): + def upload_build(self, build_file: str) -> None: """ Upload the given build zip file to the specified S3 bucket/project directory. - @type build_file: String @param build_file: (ZIP) file containing the build that should be uploaded """ if not os.path.exists(build_file) or not os.path.isfile(build_file): print("Error: Build must be a (zip) file.", file=sys.stderr) - return + return None remote_key = Key(self.bucket) remote_key.name = self.remote_path_build print(f"Uploading file {build_file} -> {remote_key.name}") remote_key.set_contents_from_filename(build_file) - def download_corpus(self, corpus_dir, random_subset_size=None): + def download_corpus( + self, corpus_dir: str, random_subset_size: int | None = None + ) -> None: """ Downloads the test corpus from the specified S3 bucket and project into the specified directory, without overwriting any files. - @type corpus_dir: String @param corpus_dir: Directory where to store test corpus files - - @type random_subset_size: int @param random_subset_size: If specified, only download a random subset of the corpus, with the specified size. """ @@ -436,15 +415,12 @@ def download_corpus(self, corpus_dir, random_subset_size=None): if not os.path.exists(dest_file): remote_key.get_contents_to_filename(dest_file) - def upload_corpus(self, corpus_dir, corpus_delete=False): + def upload_corpus(self, corpus_dir: str, corpus_delete: bool = False) -> None: """ Synchronize the specified test corpus directory to the specified S3 bucket. This method only uploads files that don't exist yet on the receiving side. 
- @type corpus_dir: String @param corpus_dir: Directory where the test corpus files are stored - - @type corpus_delete: bool @param corpus_delete: Delete all remote files that don't exist on our side """ test_files = [ @@ -496,20 +472,15 @@ def upload_corpus(self, corpus_dir, corpus_delete=False): if corpus_delete: self.bucket.delete_keys(delete_list, quiet=True) - def __get_machine_id(self, base_dir, refresh=False): + def __get_machine_id(self, base_dir: str, refresh: bool = False) -> str: """ Get (and if necessary generate) the machine id which is based on the current timestamp and the hostname of the machine. The generated ID is cached inside the base directory, so all future calls to this method return the same ID. - @type base_dir: String @param base_dir: Base directory - - @type refresh: bool @param refresh: Force generating a new machine ID - - @rtype: String @return: The generated/cached machine ID """ id_file = os.path.join(base_dir, "s3_id") @@ -529,7 +500,13 @@ def __get_machine_id(self, base_dir, refresh=False): with open(id_file) as id_fd: return id_fd.read() - def __upload_queue_files(self, queue_basedir, queue_files, base_dir, cmdline_file): + def __upload_queue_files( + self, + queue_basedir: str, + queue_files: list[str], + base_dir: str, + cmdline_file: str, + ) -> None: machine_id = self.__get_machine_id(base_dir) remote_path = f"{self.remote_path_queues}{machine_id}/" remote_files = [ diff --git a/misc/afl-libfuzzer/__init__.py b/misc/afl_libfuzzer/__init__.py similarity index 100% rename from misc/afl-libfuzzer/__init__.py rename to misc/afl_libfuzzer/__init__.py diff --git a/misc/afl-libfuzzer/afl-libfuzzer-daemon.py b/misc/afl_libfuzzer/afl-libfuzzer-daemon.py similarity index 95% rename from misc/afl-libfuzzer/afl-libfuzzer-daemon.py rename to misc/afl_libfuzzer/afl-libfuzzer-daemon.py index d81a012bf..92aea6f3d 100755 --- a/misc/afl-libfuzzer/afl-libfuzzer-daemon.py +++ b/misc/afl_libfuzzer/afl-libfuzzer-daemon.py @@ -12,6 +12,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import argparse import collections import os @@ -53,20 +56,26 @@ class LibFuzzerMonitor(threading.Thread): - def __init__(self, process, killOnOOM=True, mid=None, mqueue=None): + def __init__( + self, + process: subprocess.Popen[str], + killOnOOM: bool = True, + mid: int | None = None, + mqueue: queue.Queue[int] | None = None, + ) -> None: threading.Thread.__init__(self) self.process = process self.fd = process.stderr - self.trace = [] - self.stderr = collections.deque([], 128) - self.inTrace = False - self.testcase = None + self.trace: list[str] = [] + self.stderr: collections.deque[str] = collections.deque([], 128) + self.inTrace: bool = False + self.testcase: str | None = None self.killOnOOM = killOnOOM self.hadOOM = False self.hitThreadLimit = False self.inited = False - self.mid = mid + self.mid: int | None = mid self.mqueue = mqueue # Keep some statistics @@ -77,14 +86,15 @@ def __init__(self, process, killOnOOM=True, mid=None, mqueue=None): self.last_new_pc = 0 # Store potential exceptions - self.exc = None + self.exc: Exception | None = None - def run(self): + def run(self) -> None: assert not self.hitThreadLimit assert not self.hadOOM try: while True: + assert self.fd is not None line = self.fd.readline(4096) if not line: @@ -154,18 +164,19 @@ def run(self): self.exc = e finally: if self.mqueue is not None: + assert self.mid is not None self.mqueue.put(self.mid) - def getASanTrace(self): + def getASanTrace(self) -> list[str]: return self.trace - def 
getTestcase(self): + def getTestcase(self) -> str | None: return self.testcase - def getStderr(self): + def getStderr(self) -> list[str]: return list(self.stderr) - def terminate(self): + def terminate(self) -> None: print(f"[Job {self.mid}] Received terminate request...", file=sys.stderr) # Avoid sending anything through the queue when the run() loop exits @@ -173,6 +184,7 @@ def terminate(self): self.process.terminate() # Emulate a wait() with timeout through poll and sleep + maxSleepTime: int | float (maxSleepTime, pollInterval) = (10, 0.2) while self.process.poll() is None and maxSleepTime > 0: maxSleepTime -= pollInterval @@ -184,14 +196,11 @@ def terminate(self): self.process.wait() -def command_file_to_list(cmd_file): +def command_file_to_list(cmd_file: str) -> tuple[int | None, list[str]]: """ Open and parse custom command line file - @type cmd_file: String @param cmd_file: Command line file containing list of commands - - @rtype: Tuple @return: Test index in list and the command as a list of strings """ cmdline = list() @@ -207,20 +216,15 @@ def command_file_to_list(cmd_file): return test_idx, cmdline -def write_stats_file(outfile, fields, stats, warnings): +def write_stats_file( + outfile: str, fields: list[str], stats, warnings: list[str] +) -> None: """ Write the given stats data to the specified file - @type outfile: str @param outfile: Output file for statistics - - @type fields: list @param fields: The list of fields to write out (defines the order as well) - - @type stats: dict @param stats: The dictionary containing the actual data - - @type warnings: list @param warnings: Any textual warnings to write in addition to stats """ @@ -244,18 +248,15 @@ def write_stats_file(outfile, fields, stats, warnings): return -def write_aggregated_stats_afl(base_dirs, outfile, cmdline_path=None): +def write_aggregated_stats_afl( + base_dirs: list[str], outfile: str, cmdline_path: str | None = None +) -> None: """ Generate aggregated statistics from the given base directories and write them to the specified output file. - @type base_dirs: list @param base_dirs: List of AFL base directories - - @type outfile: str @param outfile: Output file for aggregated statistics - - @type cmdline_path: String @param cmdline_path: Optional command line file to use instead of the one found inside the base directory. """ @@ -288,7 +289,7 @@ def write_aggregated_stats_afl(base_dirs, outfile, cmdline_path=None): fields.extend(wanted_fields_max) # Warnings to include - warnings = list() + warnings: list[str] = list() aggregated_stats = {} @@ -301,7 +302,7 @@ def write_aggregated_stats_afl(base_dirs, outfile, cmdline_path=None): for field in wanted_fields_all: aggregated_stats[field] = [] - def convert_num(num): + def convert_num(num: str) -> float | int: if "." in num: return float(num) return int(num) @@ -352,6 +353,7 @@ def convert_num(num): aggregated_stats[field_name] = val # Verify fuzzmanagerconf exists and can be parsed + assert cmdline_path is not None _, cmdline = command_file_to_list(cmdline_path) target_binary = cmdline[0] if cmdline else None @@ -377,21 +379,16 @@ def convert_num(num): return write_stats_file(outfile, fields, aggregated_stats, warnings) -def write_aggregated_stats_libfuzzer(outfile, stats, monitors, warnings): +def write_aggregated_stats_libfuzzer( + outfile: str, stats, monitors: list[LibFuzzerMonitor], warnings: list[str] +) -> None: """ Generate aggregated statistics for the given overall libfuzzer stats and the individual monitors. 
Results are written to the specified output file. - @type outfile: str @param outfile: Output file for aggregated statistics - - @type stats: dict @param stats: Dictionary containing overall stats - - @type monitors: list @param monitors: A list of LibFuzzerMonitor instances - - @type warnings: list @param warnings: Any textual warnings to write in addition to stats """ @@ -408,10 +405,10 @@ def write_aggregated_stats_libfuzzer(outfile, stats, monitors, warnings): ] # Which fields to aggregate by mean - wanted_fields_mean = [] + wanted_fields_mean: list[str] = [] # Which fields should be displayed per fuzzer instance - wanted_fields_all = [] + wanted_fields_all: list[str] = [] # Which fields should be aggregated by max wanted_fields_max = ["last_new", "last_new_pc"] @@ -483,39 +480,28 @@ def write_aggregated_stats_libfuzzer(outfile, stats, monitors, warnings): def scan_crashes( - base_dir, - collector, - cmdline_path=None, - env_path=None, - test_path=None, - firefox=None, - firefox_prefs=None, - firefox_extensions=None, - firefox_testpath=None, - transform=None, -): + base_dir: str, + collector: Collector, + cmdline_path: str | None = None, + env_path: str | None = None, + test_path: str | None = None, + firefox: str | None = None, + firefox_prefs: str | None = None, + firefox_extensions: str | None = None, + firefox_testpath: str | None = None, + transform: str | None = None, +) -> int: """ Scan the base directory for crash tests and submit them to FuzzManager. - @type base_dir: String @param base_dir: AFL base directory - - @type cmdline_path: String @param cmdline_path: Optional command line file to use instead of the one found inside the base directory. - - @type env_path: String @param env_path: Optional file containing environment variables. - - @type test_path: String @param test_path: Optional filename where to copy the test before attempting to reproduce a crash. - - @type transform: String @param transform: Optional path to script for applying post-crash transformations. 
- - @rtype: int @return: Non-zero return code on failure """ crash_dir = os.path.join(base_dir, "crashes") @@ -571,6 +557,7 @@ def scan_crashes( return 2 if firefox: + assert firefox_testpath is not None (ffpInst, ffCmd, ffEnv) = setup_firefox( cmdline[0], firefox_prefs, firefox_extensions, firefox_testpath ) @@ -594,6 +581,7 @@ def scan_crashes( if test_idx is not None: cmdline[test_idx] = orig_test_arg.replace("@@", crash_file) elif test_in_env is not None: + assert env is not None env[test_in_env] = env[test_in_env].replace("@@", crash_file) elif test_path is not None: shutil.copy(crash_file, test_path) @@ -622,7 +610,9 @@ def scan_crashes( ffpInst.clean_up() -def setup_firefox(bin_path, prefs_path, ext_paths, test_path): +def setup_firefox( + bin_path: str, prefs_path: str | None, ext_paths: str | None, test_path: str +): ffp = FFPuppet(use_xvfb=True) # For now we support only one extension, but FFPuppet will handle @@ -647,7 +637,7 @@ def setup_firefox(bin_path, prefs_path, ext_paths, test_path): return (ffp, cmd, env) -def test_binary_asan(bin_path): +def test_binary_asan(bin_path: str) -> bool: process = subprocess.Popen( ["nm", "-g", bin_path], stdin=subprocess.PIPE, @@ -665,17 +655,12 @@ def test_binary_asan(bin_path): return False -def apply_transform(script_path, testcase_path): +def apply_transform(script_path: str, testcase_path: str) -> str: """ Apply a post-crash transformation to the testcase - @type script_path: String @param script_path: Path to the transformation script - - @type testcase_path: String @param testcase_path: Path to the testcase - - @rtype: String @return: Path to the archive containing the original and transformed testcase """ @@ -699,7 +684,7 @@ def apply_transform(script_path, testcase_path): return archive_path -def main(argv=None): +def main(argv: list[str] | None = None) -> int: """Command line options.""" program_name = os.path.basename(sys.argv[0]) @@ -1078,7 +1063,7 @@ def main(argv=None): ) aflGroup.add_argument("rargs", nargs=argparse.REMAINDER) - def warn_local(): + def warn_local() -> None: if not opts.fuzzmanager and not opts.local: # User didn't specify --fuzzmanager but also didn't specify --local # explicitly, so we should warn them that their crash results won't end up @@ -1163,6 +1148,7 @@ def warn_local(): ) if opts.s3_queue_status: + assert s3m is not None status_data = s3m.get_queue_status() total_queue_files = 0 @@ -1174,6 +1160,7 @@ def warn_local(): return 0 if opts.s3_corpus_status: + assert s3m is not None status_data = s3m.get_corpus_status() total_corpus_files = 0 @@ -1185,18 +1172,22 @@ def warn_local(): return 0 if opts.s3_queue_cleanup: + assert s3m is not None s3m.clean_queue_dirs() return 0 if opts.s3_build_download: + assert s3m is not None s3m.download_build(opts.s3_build_download) return 0 if opts.s3_build_upload: + assert s3m is not None s3m.upload_build(opts.s3_build_upload) return 0 if opts.s3_corpus_download: + assert s3m is not None if opts.s3_corpus_download_size is not None: opts.s3_corpus_download_size = int(opts.s3_corpus_download_size) @@ -1204,10 +1195,12 @@ def warn_local(): return 0 if opts.s3_corpus_upload: + assert s3m is not None s3m.upload_corpus(opts.s3_corpus_upload, opts.s3_corpus_replace) return 0 if opts.s3_corpus_refresh: + assert s3m is not None if opts.aflfuzz and not opts.aflbindir: print( "Error: Must specify --afl-binary-dir for refreshing the test corpus", @@ -1224,8 +1217,8 @@ def warn_local(): s3m.clean_queue_dirs() print( - "Downloading queues from s3://%s/%s/queues/ to %s" - % 
(opts.s3_bucket, opts.project, queues_dir) + f"Downloading queues from s3://{opts.s3_bucket}/{opts.project}/queues/ to " + f"{queues_dir}" ) s3m.download_queue_dirs(opts.s3_corpus_refresh) @@ -1288,8 +1281,8 @@ def warn_local(): # Download our current corpus into the queues directory as well print( - "Downloading corpus from s3://%s/%s/corpus/ to %s" - % (opts.s3_bucket, opts.project, queues_dir) + f"Downloading corpus from s3://{opts.s3_bucket}/{opts.project}/corpus/ to " + f"{queues_dir}" ) s3m.download_corpus(queues_dir) @@ -1405,6 +1398,7 @@ def warn_local(): return 2 if opts.libfuzzer: + assert s3m is not None if not opts.rargs: print("Error: No arguments specified", file=sys.stderr) return 2 @@ -1565,7 +1559,7 @@ def warn_local(): print(rarg, file=fd) monitors = [None] * opts.libfuzzer_instances - monitor_queue = queue.Queue() + monitor_queue: queue.Queue[int] = queue.Queue() # Keep track how often we crash to abort in certain situations crashes_per_minute_interval = 0 @@ -1628,6 +1622,7 @@ def warn_local(): # so we cache it here to avoid running listdir multiple times. corpus_size = len(os.listdir(corpus_dir)) + assert corpus_size is not None if ( corpus_auto_reduce_threshold is not None and corpus_size >= corpus_auto_reduce_threshold @@ -1697,6 +1692,7 @@ def warn_local(): corpus_size = len(os.listdir(corpus_dir)) # Update our auto-reduction target + assert corpus_auto_reduce_ratio is not None if corpus_size >= opts.libfuzzer_auto_reduce_min: corpus_auto_reduce_threshold = int( corpus_size * (1 + corpus_auto_reduce_ratio) @@ -1741,6 +1737,7 @@ def warn_local(): continue monitor = monitors[result] + assert monitor is not None monitor.join(20) if monitor.is_alive(): raise RuntimeError( @@ -1903,6 +1900,7 @@ def warn_local(): for i in range(len(monitors)): if monitors[i] is not None: monitor = monitors[i] + assert monitor is not None monitor.terminate() monitor.join(10) finally: @@ -2022,6 +2020,7 @@ def warn_local(): # Only upload queue files every 20 minutes if opts.s3_queue_upload and last_queue_upload < int(time.time()) - 1200: + assert s3m is not None for afl_out_dir in afl_out_dirs: s3m.upload_afl_queue_dir(afl_out_dir, new_cov_only=True) last_queue_upload = int(time.time()) diff --git a/misc/ec2prices/simulations/best_every_n_hours.py b/misc/ec2prices/simulations/best_every_n_hours.py index 3c3d2f6b5..f5518d0e5 100644 --- a/misc/ec2prices/simulations/best_every_n_hours.py +++ b/misc/ec2prices/simulations/best_every_n_hours.py @@ -14,10 +14,13 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + from .common import select_better -def run(data, sim_config, main_config): +def run(data, sim_config: dict[str, str], main_config: dict[str, str]) -> int | None: region = list(data.keys())[0] zone = list(data[region].keys())[0] instance_type = list(data[region][zone].keys())[0] diff --git a/misc/ec2prices/simulations/choose_once.py b/misc/ec2prices/simulations/choose_once.py index 433bdda6f..812b730f5 100644 --- a/misc/ec2prices/simulations/choose_once.py +++ b/misc/ec2prices/simulations/choose_once.py @@ -14,10 +14,13 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + from .common import select_better -def run(data, sim_config, main_config): +def run(data, sim_config: dict[str, str], main_config: dict[str, str]) -> int: fixed_region = None fixed_instance_type = None diff --git a/misc/ec2prices/simulations/common.py b/misc/ec2prices/simulations/common.py index 853673a60..a7b6deffa 100644 --- a/misc/ec2prices/simulations/common.py +++ 
b/misc/ec2prices/simulations/common.py @@ -13,27 +13,41 @@ @contact: choller@mozilla.com """ +from __future__ import annotations + +from typing_extensions import NotRequired, TypedDict + + +class RetType(TypedDict): + """Type information for ret.""" + + instance_type: NotRequired[str] + price: NotRequired[int] + region: NotRequired[str] + zone: NotRequired[str] + def select_better( data, - current_price=None, - region=None, - zone=None, - instance_type=None, - instance_time=None, - indent=1, - verbose=False, -): + current_price: int | None = None, + region: str | None = None, + zone: str | None = None, + instance_type: str | None = None, + instance_time: int | None = None, + indent: int = 1, + verbose: bool = False, +) -> RetType: best_region = region best_zone = zone best_instance_type = instance_type best_price = current_price - def print_indent(s): + def print_indent(s: RetType | str) -> None: if verbose: print(f"{'*' * indent}{s}") if region is None: + assert best_price is not None for region_name in data: ret = select_better( data, @@ -107,7 +121,11 @@ def print_indent(s): current_time ] - new_ret = {} + assert best_region is not None + assert best_zone is not None + assert best_instance_type is not None + assert best_price is not None + new_ret: RetType = {} new_ret["region"] = best_region new_ret["zone"] = best_zone new_ret["instance_type"] = best_instance_type @@ -116,9 +134,9 @@ def print_indent(s): return new_ret -def get_price_median(data): +def get_price_median(data: list[float]) -> float: sdata = sorted(data) n = len(sdata) if not n % 2: - return (sdata[n / 2] + sdata[n / 2 - 1]) / 2.0 - return sdata[n / 2] + return (sdata[int(n / 2)] + sdata[int(n / 2) - 1]) / 2.0 + return sdata[int(n / 2)] diff --git a/misc/ec2prices/simulator.py b/misc/ec2prices/simulator.py index 0a61fa8d9..61d973a81 100644 --- a/misc/ec2prices/simulator.py +++ b/misc/ec2prices/simulator.py @@ -17,6 +17,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import configparser import datetime import importlib @@ -33,13 +36,19 @@ # This function must be defined at the module level so it can be pickled # by the multiprocessing module when calling this asynchronously. 
def get_spot_price_per_region( - region_name, start_time, end_time, aws_key_id, aws_secret_key, instance_type + region_name: str, + start_time: datetime.datetime, + end_time: datetime.datetime, + aws_key_id: str, + aws_secret_key: str, + instance_type: str, ): """Gets spot prices of the specified region and instance type""" print( - "Region %s Instance Type %s Start %s End %s" - % (region_name, instance_type, start_time.isoformat(), end_time.isoformat()) + "Region {} Instance Type {} Start {} End {}".format( + region_name, instance_type, start_time.isoformat(), end_time.isoformat() + ) ) r = None @@ -69,15 +78,15 @@ def get_spot_price_per_region( def get_spot_prices( - regions, - start_time, - end_time, - aws_key_id, - aws_secret_key, - instance_types, - prices, - use_multiprocess=False, -): + regions: dict[str, str], + start_time: datetime.datetime, + end_time: datetime.datetime, + aws_key_id: str, + aws_secret_key: str, + instance_types: list[str], + prices: list[str], + use_multiprocess: bool = False, +) -> None: if use_multiprocess: from multiprocessing import Pool, cpu_count @@ -130,11 +139,7 @@ def get_spot_prices( ): prices[entry.region.name][zone][entry.instance_type][ start_time.isoformat() - ] = [ - end_time.isoformat(), - entry.price, - 1, - ] + ] = [end_time.isoformat(), entry.price, 1] else: cur = prices[entry.region.name][zone][entry.instance_type][ start_time.isoformat() @@ -144,17 +149,13 @@ def get_spot_prices( prices[entry.region.name][zone][entry.instance_type][ start_time.isoformat() - ] = [ - end_time.isoformat(), - mean_price, - cur[2] + 1, - ] + ] = [end_time.isoformat(), mean_price, cur[2] + 1] class ConfigurationFile: - def __init__(self, configFile): - self.simulations = OrderedDict() - self.main = {} + def __init__(self, configFile: str) -> None: + self.simulations: OrderedDict[str, str] = OrderedDict() + self.main: dict[str, str] = {} if configFile: self.parser = configparser.ConfigParser() @@ -199,7 +200,7 @@ def __init__(self, configFile): self.simulations[section] = sectionMap - def getSectionMap(self, section): + def getSectionMap(self, section: str) -> dict[str, str]: ret = OrderedDict() try: options = self.parser.options(section) @@ -210,7 +211,7 @@ def getSectionMap(self, section): return ret -def main(): +def main() -> int: """Command line options.""" # setup argparser @@ -229,7 +230,7 @@ def main(): print("Error: No simulations configured, exiting...") sys.exit(1) - results = OrderedDict() + results: OrderedDict[str, int] = OrderedDict() cacheFile = configFile.main["cache_file"] regions = configFile.main["regions"].split(",") @@ -239,6 +240,7 @@ def main(): aws_secret_key = configFile.main["aws_secret_key"] for (simulation_name, simulation) in configFile.simulations.items(): + assert isinstance(simulation, dict) sim_module = importlib.import_module(f"simulations.{simulation['handler']}") print(f"Performing simulation '{simulation_name}' ...") @@ -281,6 +283,7 @@ def main(): if col_len is None or col_len < len(simulation): col_len = len(simulation) + assert isinstance(col_len, int) col_len += 1 print("") @@ -305,6 +308,8 @@ def main(): sys.stdout.write(" " * (len(simulation_b) - len(p) + 2)) sys.stdout.write("\n") + return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/misc/libfuzzer/libfuzzer.py b/misc/libfuzzer/libfuzzer.py index 86fd7e20b..ea59d87b9 100755 --- a/misc/libfuzzer/libfuzzer.py +++ b/misc/libfuzzer/libfuzzer.py @@ -15,11 +15,15 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import argparse import 
os import subprocess import sys import threading +from typing import IO, cast from Collector.Collector import Collector from FTB.ProgramConfiguration import ProgramConfiguration @@ -27,17 +31,17 @@ class LibFuzzerMonitor(threading.Thread): - def __init__(self, fd): + def __init__(self, fd: IO[str]) -> None: assert callable(fd.readline) threading.Thread.__init__(self) self.fd = fd - self.trace = [] + self.trace: list[str] = [] self.inTrace = False - self.testcase = None + self.testcase: str | None = None - def run(self): + def run(self) -> None: while True: line = self.fd.readline(4096) @@ -60,20 +64,20 @@ def run(self): self.fd.close() - def getASanTrace(self): + def getASanTrace(self) -> list[str]: return self.trace - def getTestcase(self): + def getTestcase(self) -> str | None: return self.testcase -__all__ = [] +__all__: list[str] = [] __version__ = 0.1 __date__ = "2016-07-28" __updated__ = "2016-07-28" -def main(argv=None): +def main(argv: list[str] | None = None) -> int | None: """Command line options.""" program_name = os.path.basename(sys.argv[0]) @@ -281,7 +285,7 @@ def main(argv=None): universal_newlines=True, ) - monitor = LibFuzzerMonitor(process.stderr) + monitor = LibFuzzerMonitor(cast(IO[str], process.stderr)) monitor.start() monitor.join() @@ -294,7 +298,7 @@ def main(argv=None): [], [], configuration, auxCrashData=trace ) - (sigfile, metadata) = collector.search(crashInfo) + (sigfile, _metadata) = collector.search(crashInfo) if sigfile is not None: if last_signature == sigfile: @@ -322,6 +326,8 @@ def main(argv=None): ) break + return None + if __name__ == "__main__": sys.exit(main()) diff --git a/misc/update_prices.py b/misc/update_prices.py index d24f1271f..dd826ba36 100644 --- a/misc/update_prices.py +++ b/misc/update_prices.py @@ -10,6 +10,7 @@ p3.2xlarge,m4.2xlarge,f1.2xlarge,h1.2xlarge,x1e.2xlarge,m5d.2xlarge,t2.2xlarge """ +from __future__ import annotations import json import sys @@ -54,7 +55,9 @@ } -def get_instance_types(regions=True, index_json=None): +def get_instance_types( + regions: bool = True, index_json: dict[str, object] | None = None +) -> dict[str, object]: """Fetch instance type data from EC2 pricing API. regions: if True, this will add a "regions" field to each instance type, stating @@ -91,8 +94,9 @@ def get_instance_types(regions=True, index_json=None): ).json() data = index_json["products"] + assert isinstance(data, dict) - instance_types = {} + instance_types: dict[str, object] = {} for product in data.values(): if ( @@ -107,6 +111,7 @@ def get_instance_types(regions=True, index_json=None): instance_data = instance_types.setdefault( product["attributes"]["instanceType"], {} ) + assert isinstance(instance_data, dict) if instance_data: # assert that all fields are the same! 
new_data = { @@ -141,6 +146,7 @@ def get_instance_types(regions=True, index_json=None): # normalize units for instance_data in instance_types.values(): + assert isinstance(instance_data, dict) if regions: instance_data["regions"] = list(instance_data["regions"]) instance_data["vcpu"] = int(instance_data["vcpu"]) @@ -155,7 +161,7 @@ def get_instance_types(regions=True, index_json=None): return instance_types -def main(): +def main() -> None: index_json = None if len(sys.argv) > 1: with open(sys.argv[1]) as data_fp: diff --git a/pyproject.toml b/pyproject.toml index 202c4cd7f..ad873c915 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,9 @@ omit = [ "*/.eggs/*", ] +[tool.django-stubs] +django_settings_module = "server.settings_test" + [tool.isort] known_first_party = [ "Collector", @@ -29,6 +32,58 @@ known_first_party = [ ] profile = "black" +[tool.mypy] +plugins = [ + "mypy_django_plugin.main", + "mypy_drf_plugin.main", +] +strict = true +show_error_codes = true + +# Add Python modules to be ignored by mypy here +[[tool.mypy.overrides]] +module = [ + "celeryconf", + "chartjs.colors", # Package: django-chartjs + "chartjs.views.base", # Package: django-chartjs + "covmanager.SourceCodeProvider.GITSourceCodeProvider", + "covmanager.SourceCodeProvider.HGSourceCodeProvider", + "covmanager.SourceCodeProvider.SourceCodeProvider", + "covmanager.models", + "crashmanager", + "crashmanager.models", + "crispy_forms.helper", # Package: django-crispy-forms + "crispy_forms.layout", # Package: django-crispy-forms + "ec2spotmanager", + "ec2spotmanager.CloudProvider.CloudProvider", + "ec2spotmanager.cron", + "ec2spotmanager.models", + "ec2spotmanager.tasks", + "ec2spotmanager.tests", + "enumfields", # Package: django-enumfields + "enumfields.fields", # Package: django-enumfields + "fasteners", # Package: fasteners + "fuzzing_decision.common.pool", # Package: orion (Mozilla, needs FuzzManager types) + "laniakea.core.providers.ec2", # Package: laniakea (Mozilla, currently archived) + "laniakea.core.providers.gce", # Package: laniakea (Mozilla, currently archived) + "laniakea.core.userdata", # Package: laniakea (Mozilla, currently archived) + "mozilla_django_oidc.auth", # Package: mozilla-django-oidc + "mozillapulse.consumers", # Package: mozillapulse + "notifications", # Package: django-notifications-hq + "notifications.models", # Package: django-notifications-hq + "notifications.signals", # Package: django-notifications-hq + "server.auth", + "server.covmanager.models", # Located here in server/covmanager/models.py + "server.utils", + "server.views", + "taskcluster", # Package: taskcluster + "taskmanager.models", + "taskmanager.tasks", + "taskmanager.tests", + "S3Manager", # Located here in misc/afl_libfuzzer/S3Manager.py +] +ignore_missing_imports = true + [tool.pytest.ini_options] log_level = "DEBUG" DJANGO_SETTINGS_MODULE = "server.settings_test" diff --git a/server/conftest.py b/server/conftest.py index 7a94c2e25..f3eaaedae 100644 --- a/server/conftest.py +++ b/server/conftest.py @@ -8,12 +8,18 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import logging +from pathlib import Path import pytest from django.apps import apps from django.db import connection from django.db.migrations.executor import MigrationExecutor +from pytest_django.fixtures import SettingsWrapper +from rest_framework.request import Request from rest_framework.test import APIClient logging.getLogger("django").setLevel(logging.WARNING) @@ -24,19 +30,19 @@ @pytest.fixture(autouse=True) -def dj_static_tmp(tmp_path, settings): +def dj_static_tmp(tmp_path: Path, settings: SettingsWrapper) -> None: dj_static = tmp_path / "dj-static" dj_static.mkdir() settings.STATIC_ROOT = str(dj_static) @pytest.fixture -def api_client(): +def api_client() -> APIClient: return APIClient() @pytest.fixture -def migration_hook(request): +def migration_hook(request: Request): """ Pause migration at the migration named in @pytest.mark.migrate_from('0001-initial-migration') @@ -67,13 +73,17 @@ def migration_hook(request): assert len(migrate_to_mark.args) == 1, "migrate_to mark expects 1 arg" assert not migrate_to_mark.kwargs, "migrate_to mark takes no keywords" - app = apps.get_containing_app_config(request.module.__name__).name + apps_get_containing_app_config = apps.get_containing_app_config( + request.module.__name__ + ) + assert apps_get_containing_app_config is not None + app = apps_get_containing_app_config.name migrate_from = [(app, migrate_from_mark.args[0])] migrate_to = [(app, migrate_to_mark.args[0])] class migration_hook_result: - def __init__(self, _from, _to): + def __init__(self, _from, _to) -> None: self._to = _to executor = MigrationExecutor(connection) self.apps = executor.loader.project_state(_from).apps @@ -81,7 +91,7 @@ def __init__(self, _from, _to): # Reverse to the original migration executor.migrate(_from) - def __call__(self): + def __call__(self) -> None: # Run the migration to test executor = MigrationExecutor(connection) executor.loader.build_graph() # reload. 
diff --git a/server/contrib/create_user.py b/server/contrib/create_user.py index 064a1bd85..400d36c4d 100644 --- a/server/contrib/create_user.py +++ b/server/contrib/create_user.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import os import random @@ -7,7 +9,7 @@ import django -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--username", dest="username", type=str, required=True) parser.add_argument("--email", dest="email", type=str, required=True) diff --git a/server/contrib/fuzzmanager_setup_helper.py b/server/contrib/fuzzmanager_setup_helper.py index 1f9604acf..a359c09a2 100644 --- a/server/contrib/fuzzmanager_setup_helper.py +++ b/server/contrib/fuzzmanager_setup_helper.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import random import string @@ -7,7 +9,7 @@ import django -def create_fuzzmanager(): +def create_fuzzmanager() -> None: # This is a throw away password, that is soon reset to the auth_token for # fuzzmanager password = "".join(random.sample(string.letters, 20)) @@ -80,7 +82,7 @@ def create_fuzzmanager(): print("Something went wrong creating the fuzzmanager account") -def main(): +def main() -> None: os.environ.setdefault("DJANGO_SETTINGS_MODULE", "server.settings") from django.contrib.auth.models import User diff --git a/server/covmanager/SourceCodeProvider/GITSourceCodeProvider.py b/server/covmanager/SourceCodeProvider/GITSourceCodeProvider.py index 14be736dd..0200e0750 100644 --- a/server/covmanager/SourceCodeProvider/GITSourceCodeProvider.py +++ b/server/covmanager/SourceCodeProvider/GITSourceCodeProvider.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import subprocess from .SourceCodeProvider import ( @@ -21,10 +24,10 @@ class GITSourceCodeProvider(SourceCodeProvider): - def __init__(self, location): + def __init__(self, location: str) -> None: super().__init__(location) - def getSource(self, filename, revision): + def getSource(self, filename: str, revision: str) -> str: try: return subprocess.check_output( ["git", "show", f"{revision}:{filename}"], cwd=self.location @@ -37,7 +40,7 @@ def getSource(self, filename, revision): # Otherwise assume the file doesn't exist raise UnknownFilenameException - def testRevision(self, revision): + def testRevision(self, revision: str) -> bool: try: subprocess.check_output( ["git", "show", revision], cwd=self.location, stderr=subprocess.STDOUT @@ -46,11 +49,11 @@ def testRevision(self, revision): return False return True - def update(self): + def update(self) -> None: # TODO: This will fail without remotes subprocess.check_call(["git", "fetch"], cwd=self.location) - def getParents(self, revision): + def getParents(self, revision: str) -> list[str]: try: output = subprocess.check_output( ["git", "log", revision, "--format=%P"], cwd=self.location @@ -58,19 +61,19 @@ def getParents(self, revision): except subprocess.CalledProcessError: raise UnknownRevisionException - output = output.decode("utf-8").splitlines() + output_str = output.decode("utf-8").splitlines() # No parents - if not output[0]: + if not output_str[0]: return [] - return output[0].split(" ") + return output_str[0].split(" ") - def getUnifiedDiff(self, revision): + def getUnifiedDiff(self, revision: str) -> str: # TODO: Implement this method for GIT pass - def checkRevisionsEquivalent(self, revisionA, revisionB): + def checkRevisionsEquivalent(self, revisionA: str, revisionB: str) -> bool: # We do not implement any kind of revision 
equivalence # for GIT other than equality. return revisionA == revisionB diff --git a/server/covmanager/SourceCodeProvider/HGSourceCodeProvider.py b/server/covmanager/SourceCodeProvider/HGSourceCodeProvider.py index 6899d65a3..01e43fc04 100644 --- a/server/covmanager/SourceCodeProvider/HGSourceCodeProvider.py +++ b/server/covmanager/SourceCodeProvider/HGSourceCodeProvider.py @@ -11,6 +11,9 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import re import subprocess @@ -22,10 +25,10 @@ class HGSourceCodeProvider(SourceCodeProvider): - def __init__(self, location): + def __init__(self, location: str) -> None: super().__init__(location) - def getSource(self, filename, revision): + def getSource(self, filename: str, revision: str) -> str: revision = revision.replace("+", "") # Avoid passing in absolute filenames to HG @@ -44,7 +47,7 @@ def getSource(self, filename, revision): # Otherwise assume the file doesn't exist raise UnknownFilenameException - def testRevision(self, revision): + def testRevision(self, revision: str) -> bool: revision = revision.replace("+", "") try: @@ -57,11 +60,11 @@ def testRevision(self, revision): return False return True - def update(self): + def update(self) -> None: # TODO: This will fail without remotes subprocess.check_call(["hg", "pull"], cwd=self.location) - def getParents(self, revision): + def getParents(self, revision: str) -> list[str]: revision = revision.replace("+", "") try: @@ -72,18 +75,18 @@ def getParents(self, revision): except subprocess.CalledProcessError: raise UnknownRevisionException - output = output.splitlines() + output_str = output.splitlines() parents = [] - for line in output: + for line in output_str: result = re.match(r"\d+:([0-9a-f]+)\s+", line) if result: parents.append(result.group(1)) return parents - def getUnifiedDiff(self, revision): + def getUnifiedDiff(self, revision: str) -> str: revision = revision.replace("+", "") try: @@ -95,7 +98,7 @@ def getUnifiedDiff(self, revision): return output.decode("utf-8") - def checkRevisionsEquivalent(self, revisionA, revisionB): + def checkRevisionsEquivalent(self, revisionA: str, revisionB: str) -> bool: # Check if revisions are equal if revisionA == revisionB: return True diff --git a/server/covmanager/SourceCodeProvider/SourceCodeProvider.py b/server/covmanager/SourceCodeProvider/SourceCodeProvider.py index 4fadbd07a..28d857bc7 100644 --- a/server/covmanager/SourceCodeProvider/SourceCodeProvider.py +++ b/server/covmanager/SourceCodeProvider/SourceCodeProvider.py @@ -11,8 +11,13 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + from abc import ABCMeta, abstractmethod +from typing_extensions import NotRequired, TypedDict + class UnknownRevisionException(Exception): pass @@ -28,41 +33,33 @@ class SourceCodeProvider(metaclass=ABCMeta): implement """ - def __init__(self, location): + def __init__(self, location: str) -> None: self.location = location @abstractmethod - def getSource(self, filename, revision): + def getSource(self, filename: str, revision: str) -> str: """ Return the source code for the given filename on the given revision. - @ptype filename: string @param filename: The path to the requested file, relative to the root of the repository. - - @ptype revision: string @param revision: The revision to use when retrieving the source code. - - @rtype string @return The requested source code as a single string. 
""" - return + return "" @abstractmethod - def testRevision(self, revision): + def testRevision(self, revision: str) -> bool: """ Check if the given revision exists in the resource associated with this provider - @ptype revision: string @param revision: The revision to check for. - - @rtype bool @return True, if the revision exists, False otherwise. """ - return + return False @abstractmethod - def update(self): + def update(self) -> None: """ Update the resource associated with this provider. @@ -76,51 +73,49 @@ def update(self): return @abstractmethod - def getParents(self, revision): + def getParents(self, revision: str) -> list[str]: """ Gets the parent revisions of the specified revision. - @ptype revision: string @param revision: The revision to get parents for. - - @rtype list @return The list of parent revisions. """ - return + return [] @abstractmethod - def getUnifiedDiff(self, revision): + def getUnifiedDiff(self, revision: str) -> str: """ Return a GIT-style unified diff for the given revision. - @ptype revision: string @param revision: The revision to get the diff for. - - @rtype string @return The unified diff as a single string. """ - return + return "" @abstractmethod - def checkRevisionsEquivalent(self, revisionA, revisionB): + def checkRevisionsEquivalent(self, revisionA: str, revisionB: str) -> bool: """ Check if the given revisions are considered to be equivalent. - @ptype revisionA: string @param revisionA: The first revision to compare. - - @ptype revisionB: string @param revisionB: The second revision to compare. - - @rtype bool @return True, if the revisions are equivalent, False otherwise. """ - return + return False + + +class CObj(TypedDict): + """CObj type specification.""" + + filename: str | None + locations: list[int] + missed: NotRequired[list[int]] + not_coverable: NotRequired[list[int]] class Utils: @staticmethod - def getDiffLocations(diff): + def getDiffLocations(diff: str) -> list[CObj]: """ This method tries to return reasonable diff hunk locations for each changed file in the given unified diff, where the locations refer to the @@ -139,28 +134,25 @@ def getDiffLocations(diff): Note that the heuristics used here are far from perfect and are only meant to aid manual inspection. - @ptype diff: string @param diff: A GIT-style unified diff as a single string. - - @rtype list @return A list containing one object per file changed. Each object in the list has two attributes, "filename" and "locations", where "locations" is the list of diff hunk locations for that particular file. 
""" - ret = [] - diff = diff.splitlines() + ret: list[CObj] = [] + diff_list = diff.splitlines() - while diff: - cobj = {"filename": None, "locations": []} + while diff_list: + cobj: CObj = {"filename": None, "locations": []} skipDiff = False - line = diff.pop(0) + line = diff_list.pop(0) if line.startswith("diff --git "): - (mm, mmLine) = diff.pop(0).split(" ", 2) - (pp, ppLine) = diff.pop(0).split(" ", 2) + (mm, mmLine) = diff_list.pop(0).split(" ", 2) + (pp, ppLine) = diff_list.pop(0).split(" ", 2) if not mm == "---" or not pp == "+++": raise RuntimeError("Malformed trace") @@ -171,11 +163,11 @@ def getDiffLocations(diff): cobj["filename"] = mmLine[2:] skipHunk = False - lastHunkStart = None - hunkLineRemoveCount = None + lastHunkStart = 0 + hunkLineRemoveCount = 0 - while diff and not diff[0].startswith("diff --git "): - line = diff.pop(0) + while diff_list and not diff_list[0].startswith("diff --git "): + line = diff_list.pop(0) if not skipDiff: if line.startswith("@@ "): diff --git a/server/covmanager/SourceCodeProvider/tests/test_sourcecodeprovider.py b/server/covmanager/SourceCodeProvider/tests/test_sourcecodeprovider.py index e311fb29c..445af5b61 100644 --- a/server/covmanager/SourceCodeProvider/tests/test_sourcecodeprovider.py +++ b/server/covmanager/SourceCodeProvider/tests/test_sourcecodeprovider.py @@ -11,8 +11,13 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import os import shutil +from pathlib import Path +from typing import Iterator import pytest @@ -21,8 +26,8 @@ from covmanager.SourceCodeProvider.SourceCodeProvider import Utils -@pytest.fixture -def git_repo(tmp_path): +@pytest.fixture() +def git_repo(tmp_path: Path) -> Iterator[str]: shutil.copytree( os.path.join(os.path.dirname(os.path.abspath(__file__)), "test-git"), str(tmp_path / "test-git"), @@ -31,8 +36,8 @@ def git_repo(tmp_path): yield str(tmp_path / "test-git") -@pytest.fixture -def hg_repo(tmp_path): +@pytest.fixture() +def hg_repo(tmp_path: Path) -> Iterator[str]: shutil.copytree( os.path.join(os.path.dirname(os.path.abspath(__file__)), "test-hg"), str(tmp_path / "test-hg"), @@ -41,7 +46,7 @@ def hg_repo(tmp_path): yield str(tmp_path / "test-hg") -def test_GITSourceCodeProvider(git_repo): +def test_GITSourceCodeProvider(git_repo: str) -> None: provider = GITSourceCodeProvider(git_repo) tests = { @@ -74,7 +79,7 @@ def test_GITSourceCodeProvider(git_repo): assert len(parents) == 0 -def test_HGSourceCodeProvider(hg_repo): +def test_HGSourceCodeProvider(hg_repo: str) -> None: provider = HGSourceCodeProvider(hg_repo) tests = { @@ -111,14 +116,14 @@ def test_HGSourceCodeProvider(hg_repo): not os.path.isdir("/home/decoder/Mozilla/repos/mozilla-central-fm"), reason="not decoder", ) -def test_HGDiff(): +def test_HGDiff() -> None: provider = HGSourceCodeProvider("/home/decoder/Mozilla/repos/mozilla-central-fm") diff = provider.getUnifiedDiff("4f8e0cb21016") print(Utils.getDiffLocations(diff)) -def test_HGRevisionEquivalence(): +def test_HGRevisionEquivalence() -> None: provider = HGSourceCodeProvider("") # Simple equality for short and long revision formats diff --git a/server/covmanager/cron.py b/server/covmanager/cron.py index caaf0678e..ef8245ad2 100644 --- a/server/covmanager/cron.py +++ b/server/covmanager/cron.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import requests @@ -20,7 +22,7 @@ # a summarized report for your testing efforts. 
-def create_weekly_report_mc(revision): +def create_weekly_report_mc(revision: str) -> None: from crashmanager.models import Client from .models import Collection, Report, Repository @@ -39,7 +41,9 @@ def create_weekly_report_mc(revision): .exclude(description__contains="IGNORE_MERGE") ) - last_monday = collections.first().created + relativedelta(weekday=MO(-1)) + collections_first = collections.first() + assert collections_first is not None + last_monday = collections_first.created + relativedelta(weekday=MO(-1)) mergedCollection = Collection() mergedCollection.description = "Weekly Report (Week of {}, {} reports)".format( @@ -70,7 +74,7 @@ def create_weekly_report_mc(revision): @app.task(ignore_result=True) -def create_current_weekly_report_mc(): +def create_current_weekly_report_mc() -> None: COVERAGE_REVISION_URL = getattr(settings, "COVERAGE_REVISION_URL", None) if not COVERAGE_REVISION_URL: diff --git a/server/covmanager/management/commands/setup_repository.py b/server/covmanager/management/commands/setup_repository.py index 0ac2ad89f..e5848d408 100644 --- a/server/covmanager/management/commands/setup_repository.py +++ b/server/covmanager/management/commands/setup_repository.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import os +from argparse import ArgumentParser +from typing import Any from django.core.management.base import BaseCommand, CommandError @@ -8,12 +12,12 @@ class Command(BaseCommand): help = "Sets up a repository for CovManager" - def add_arguments(self, parser): + def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument("name", help="repository identifier") parser.add_argument("provider", help="SourceCodeProvider subclass") parser.add_argument("location", help="path to the repository root") - def handle(self, name, provider, location, **opts): + def handle(self, name, provider, location, **opts: Any) -> None: if not name: raise CommandError("Error: invalid repository name") @@ -52,6 +56,6 @@ def handle(self, name, provider, location, **opts): repository.save() print( - "Successfully created repository '%s' with provider '%s' located at %s" - % (name, provider, location) + f"Successfully created repository '{name}' with provider '{provider}' " + f"located at {location}" ) diff --git a/server/covmanager/migrations/0006_auto_20210429_0908.py b/server/covmanager/migrations/0006_auto_20210429_0908.py index 54c440dfd..1dd8e9f22 100644 --- a/server/covmanager/migrations/0006_auto_20210429_0908.py +++ b/server/covmanager/migrations/0006_auto_20210429_0908.py @@ -1,5 +1,7 @@ # Generated by Django 3.0.14 on 2021-04-29 09:08 +from __future__ import annotations + import django.core.files.storage from django.conf import settings from django.db import migrations, models diff --git a/server/covmanager/models.py b/server/covmanager/models.py index cf8a797ae..06a886d6a 100644 --- a/server/covmanager/models.py +++ b/server/covmanager/models.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import codecs import json +from datetime import datetime +from typing import Any from django.conf import settings from django.contrib.auth.models import User as DjangoUser # noqa from django.core.files.storage import FileSystemStorage from django.db import models +from django.db.models.query import QuerySet from django.db.models.signals import post_delete, post_save from django.dispatch.dispatcher import receiver from django.utils import timezone @@ -47,11 +52,11 @@ class Collection(models.Model): branch = models.CharField(max_length=255, blank=True) tools = 
models.ManyToManyField(Tool) client = models.ForeignKey(Client, on_delete=models.deletion.CASCADE) - coverage = models.ForeignKey( + coverage: CollectionFile | None = models.ForeignKey( CollectionFile, blank=True, null=True, on_delete=models.deletion.CASCADE ) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: # This variable can hold the deserialized contents of the coverage blob self.content = None @@ -72,19 +77,19 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def loadCoverage(self): + def loadCoverage(self) -> None: + assert self.coverage is not None self.coverage.file.open(mode="rb") self.content = json.load(codecs.getreader("utf-8")(self.coverage.file)) self.coverage.file.close() - def annotateSource(self, path, coverage): + def annotateSource(self, path: str, coverage) -> None: """ Annotate the source code to the given (leaf) coverage object by querying the SourceCodeProvider registered for the repository associated with this collection. The resulting source code is added to a "source" property in the object. - @type path: string @param path: The path to the source code that this coverage belongs to. @type coverage: dict @@ -95,12 +100,11 @@ def annotateSource(self, path, coverage): provider = self.repository.getInstance() coverage["source"] = provider.getSource(path, self.revision) - def subset(self, path, report_configuration=None): + def subset(self, path: str, report_configuration=None): """ Calculate a subset of the coverage stored in this collection based on the given path. - @type path: string @param path: The path to reduce to. It is expected to use forward slashes. The path is interpreted as relative to the root of the collection. @@ -128,6 +132,7 @@ def subset(self, path, report_configuration=None): names[0] = "" try: + assert self.content is not None ret = self.content["children"] for name in names[:-1]: ret = ret[name]["children"] @@ -160,7 +165,7 @@ def remove_childrens_children(coverage): coverage["children"][child]["children"] = True @staticmethod - def strip(coverage): + def strip(coverage) -> None: """ This method strips all detailed coverage information from the given coverage data. Only the summarized coverage fields are left intact. @@ -185,7 +190,7 @@ def strip(coverage): # This post_delete handler ensures that the corresponding coverage # file is deleted when the Collection is gone. 
@receiver(post_delete, sender=Collection) -def Collection_delete(sender, instance, **kwargs): +def Collection_delete(sender: Collection, instance: Collection, **kwargs: Any) -> None: if instance.coverage: instance.coverage.file.delete(False) instance.coverage.delete(False) @@ -195,7 +200,9 @@ def Collection_delete(sender, instance, **kwargs): if getattr(settings, "USE_CELERY", None): @receiver(post_save, sender=Collection) - def Collection_save(sender, instance, **kwargs): + def Collection_save( + sender: Collection, instance: Collection, **kwargs: Any + ) -> None: check_revision_update.delay(instance.pk) @@ -208,7 +215,7 @@ class ReportConfiguration(models.Model): "self", blank=True, null=True, on_delete=models.deletion.CASCADE ) - def apply(self, collection): + def apply(self, collection: Collection) -> None: CoverageHelper.apply_include_exclude_directives( collection, self.directives.splitlines() ) @@ -217,7 +224,7 @@ def apply(self, collection): class ReportSummary(models.Model): collection = models.OneToOneField(Collection, on_delete=models.deletion.CASCADE) - cached_result = models.TextField(null=True, blank=True) + cached_result: str | None = models.TextField(null=True, blank=True) class Report(models.Model): diff --git a/server/covmanager/serializers.py b/server/covmanager/serializers.py index 1d055fc69..e64c77bb9 100644 --- a/server/covmanager/serializers.py +++ b/server/covmanager/serializers.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import hashlib +from typing import Any from django.core.exceptions import MultipleObjectsReturned # noqa from django.core.files.base import ContentFile @@ -14,7 +17,7 @@ class InvalidArgumentException(APIException): status_code = 400 -class CollectionSerializer(serializers.ModelSerializer): +class CollectionSerializer(serializers.ModelSerializer[Collection]): # We need to redefine several fields explicitly because we flatten our # foreign keys into these fields instead of using primary keys, hyperlinks # or slug fields. All of the other solutions would require the client to @@ -54,7 +57,7 @@ def to_representation(self, obj): return serialized - def create(self, attrs): + def create(self, attrs) -> Collection: """ Create a Collection instance based on the given dictionary of values received. We need to unflatten foreign relationships like repository, @@ -104,14 +107,14 @@ def create(self, attrs): return super().create(attrs) -class RepositorySerializer(serializers.ModelSerializer): +class RepositorySerializer(serializers.ModelSerializer[Repository]): class Meta: model = Repository fields = ("name",) read_only_fields = ("name",) -class ReportConfigurationSerializer(serializers.ModelSerializer): +class ReportConfigurationSerializer(serializers.ModelSerializer[ReportConfiguration]): repository = serializers.CharField(source="repository.name", max_length=255) class Meta: @@ -126,7 +129,7 @@ class Meta: ) read_only_fields = ("id", "created") - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) request = self.context.get("request") @@ -139,7 +142,7 @@ def __init__(self, *args, **kwargs): for field in exclude_fields: self.fields.pop(field) - def handle_repository(self, attrs): + def handle_repository(self, attrs) -> None: """ When creating or updating a ReportConfiguration instance, we need to unflatten the foreign relationship to the repository and validate that it exists. 
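(Illustrative sketch, not part of the patch: the "unflatten" step that handle_repository() above describes, shown with a plain dict instead of a Django queryset. The nested {"repository": {"name": ...}} shape assumes DRF's handling of the source="repository.name" field declared earlier; the repository name is made up, and the real serializer's error handling is outside the hunks shown here.)

from __future__ import annotations

# Stand-in for the Repository table in this sketch.
known_repositories = {"mozilla-central": object()}


def handle_repository(attrs: dict) -> None:
    # The client sends the repository flattened to its name ...
    name = attrs["repository"]["name"]
    if name not in known_repositories:
        # ... unknown names are rejected ...
        raise ValueError(f"Invalid repository name: {name}")
    # ... and the flat value is replaced by the actual object before
    # create()/update() run.
    attrs["repository"] = known_repositories[name]


attrs = {"repository": {"name": "mozilla-central"}}
handle_repository(attrs)
assert attrs["repository"] is known_repositories["mozilla-central"]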
@@ -159,20 +162,20 @@ def handle_repository(self, attrs): attrs["repository"] = repository[0] - def update(self, instance, attrs): + def update(self, instance, attrs) -> ReportConfiguration: self.handle_repository(attrs) # Update our ReportConfiguration instance return super().update(instance, attrs) - def create(self, attrs): + def create(self, attrs) -> ReportConfiguration: self.handle_repository(attrs) # Create our ReportConfiguration instance return super().create(attrs) -class ReportSerializer(serializers.ModelSerializer): +class ReportSerializer(serializers.ModelSerializer[Report]): class Meta: model = Report fields = ( @@ -185,7 +188,7 @@ class Meta: ) read_only_fields = ("id", "data_created", "coverage") - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) request = self.context.get("request") diff --git a/server/covmanager/tasks.py b/server/covmanager/tasks.py index 0d6817817..748fc2ca6 100644 --- a/server/covmanager/tasks.py +++ b/server/covmanager/tasks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import hashlib import json @@ -10,7 +12,7 @@ @app.task(ignore_result=True) -def check_revision_update(pk): +def check_revision_update(pk: int) -> None: from covmanager.models import Collection, Repository # noqa collection = Collection.objects.get(pk=pk) @@ -32,7 +34,7 @@ def check_revision_update(pk): @app.task(ignore_result=True) -def aggregate_coverage_data(pk, pks): +def aggregate_coverage_data(pk: int, pks: list[int]) -> None: from covmanager.models import Collection, CollectionFile from FTB import CoverageHelper @@ -84,7 +86,7 @@ def aggregate_coverage_data(pk, pks): @app.task(ignore_result=True) -def calculate_report_summary(pk): +def calculate_report_summary(pk: int) -> None: from covmanager.models import ReportConfiguration, ReportSummary summary = ReportSummary.objects.get(pk=pk) @@ -142,6 +144,7 @@ def calculate_report_summary(pk): if waiting: # We shouldn't have orphaned reports + assert data is not None data["warning"] = "There are orphaned reports that won't be displayed." 
summary.cached_result = json.dumps(data) diff --git a/server/covmanager/templatetags/recurseroot.py b/server/covmanager/templatetags/recurseroot.py index 3e65c3568..520588f37 100644 --- a/server/covmanager/templatetags/recurseroot.py +++ b/server/covmanager/templatetags/recurseroot.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django import template from django.utils.safestring import mark_safe @@ -5,11 +7,15 @@ class RecurseReportSummaryTree(template.Node): - def __init__(self, template_nodes, config_var): + def __init__( + self, template_nodes: template.NodeList, config_var: template.Variable + ) -> None: self.template_nodes = template_nodes self.config_var = config_var - def _render_node(self, context, node): + def _render_node( + self, context: template.context.Context, node: template.Node + ) -> str: context.push() context["node"] = node if "children" in node: @@ -19,12 +25,14 @@ def _render_node(self, context, node): context.pop() return rendered - def render(self, context): + def render(self, context: template.context.Context) -> str: return self._render_node(context, self.config_var.resolve(context)) @register.tag -def recurseroot(parser, token): +def recurseroot( + parser: template.base.Parser, token: template.base.Token +) -> RecurseReportSummaryTree: bits = token.contents.split() if len(bits) != 2: raise template.TemplateSyntaxError(f"{bits[0]} tag requires a root") diff --git a/server/covmanager/tests/__init__.py b/server/covmanager/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/covmanager/tests/conftest.py b/server/covmanager/tests/conftest.py index 1c6a1ced9..6beeab07b 100644 --- a/server/covmanager/tests/conftest.py +++ b/server/covmanager/tests/conftest.py @@ -8,15 +8,23 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import logging import os import shutil import subprocess import tempfile +from typing import cast +import py as py_package import pytest from django.contrib.auth.models import Permission, User from django.contrib.contenttypes.models import ContentType +from django.core.files.storage import Storage +from pytest_django.fixtures import SettingsWrapper +from typing_extensions import TypedDict from covmanager.models import Collection, CollectionFile, Repository from crashmanager.models import Client, Tool @@ -25,7 +33,7 @@ LOG = logging.getLogger("fm.covmanager.tests") -def _check_git(): +def _check_git() -> bool: try: proc = subprocess.Popen( ["git"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT @@ -38,7 +46,7 @@ def _check_git(): return False -def _check_hg(): +def _check_hg() -> bool: try: proc = subprocess.Popen( ["hg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT @@ -56,7 +64,7 @@ def _check_hg(): @pytest.fixture -def covmanager_test(db): # pylint: disable=invalid-name,unused-argument +def covmanager_test(db: None) -> None: # pylint: disable=invalid-name,unused-argument """Common setup/teardown tasks for all server unittests""" user = User.objects.create_user("test", "test@mozilla.com", "test") user.user_permissions.clear() @@ -67,14 +75,69 @@ def covmanager_test(db): # pylint: disable=invalid-name,unused-argument user_np.user_permissions.clear() +class covType(TypedDict): + """Type information for cov""" + + children: dict[str, str] + coveragePercent: float + linesCovered: int + linesMissed: int + linesTotal: int + name: str | None + + +class _result: # pylint: disable=invalid-name + have_git: bool + have_hg: bool + + @classmethod + def create_repository(cls, repotype: str, name: str = "testrepo") -> Repository: + ... + + @staticmethod + def create_collection_file(data: str) -> CollectionFile: + ... + + @classmethod + def create_collection( + cls, + created: bool | None = None, + description: str = "", + repository: Repository | None = None, + revision: str = "", + branch: str = "", + tools: tuple[str] = ("testtool",), + client: str = "testclient", + coverage: str = '{"linesTotal":0,' + '"name":null,' + '"coveragePercent":0.0,' + '"children":{},' + '"linesMissed":0,' + '"linesCovered":0}', + ) -> Collection: + ... + + @staticmethod + def git(repo: Repository, *args: str) -> str: + ... + + @staticmethod + def hg(repo: Repository, *args: str) -> str: + ... 
+ + @pytest.fixture -def cm(request, settings, tmpdir): +def cm( + request: pytest.FixtureRequest, + settings: SettingsWrapper, + tmpdir: py_package.path.local, +): class _result: have_git = HAVE_GIT have_hg = HAVE_HG @classmethod - def create_repository(cls, repotype, name="testrepo"): + def create_repository(cls, repotype: str, name: str = "testrepo") -> Repository: location = tempfile.mkdtemp( prefix="testrepo", dir=os.path.dirname(__file__) ) @@ -91,8 +154,11 @@ def create_repository(cls, repotype, name="testrepo"): raise Exception( f"unknown repository type: {repotype} (expecting git or hg)" ) - result = Repository.objects.create( - classname=classname, name=name, location=location + result = cast( + Repository, + Repository.objects.create( + classname=classname, name=name, location=location + ), ) LOG.debug("Created Repository pk=%d", result.pk) if repotype == "git": @@ -102,66 +168,73 @@ def create_repository(cls, repotype, name="testrepo"): return result @staticmethod - def create_collection_file(data): + def create_collection_file(data: str) -> CollectionFile: # Use a specific temporary directory to upload covmanager files. This is # required as Django now needs a path relative to that folder in FileField location = str(tmpdir) + assert isinstance(CollectionFile.file.field.storage, Storage) CollectionFile.file.field.storage.location = location tmp_fd, path = tempfile.mkstemp(suffix=".data", dir=location) os.close(tmp_fd) with open(path, "w") as fp: fp.write(data) - result = CollectionFile.objects.create(file=os.path.basename(path)) + result = cast( + CollectionFile, + CollectionFile.objects.create(file=os.path.basename(path)), + ) LOG.debug("Created CollectionFile pk=%d", result.pk) return result @classmethod def create_collection( cls, - created=None, - description="", - repository=None, - revision="", - branch="", - tools=("testtool",), - client="testclient", - coverage='{"linesTotal":0,' + created: bool | None = None, + description: str = "", + repository: Repository | None = None, + revision: str = "", + branch: str = "", + tools: tuple[str] = ("testtool",), + client: str = "testclient", + coverage: str = '{"linesTotal":0,' '"name":null,' '"coveragePercent":0.0,' '"children":{},' '"linesMissed":0,' '"linesCovered":0}', - ): + ) -> Collection: # create collectionfile - coverage = cls.create_collection_file(coverage) + coverage_ = cls.create_collection_file(coverage) # create client - client, created = Client.objects.get_or_create(name=client) + client_, created = Client.objects.get_or_create(name=client) if created: - LOG.debug("Created Client pk=%d", client.pk) + LOG.debug("Created Client pk=%d", client_.pk) # create repository if repository is None: repository = cls.create_repository("git") - result = Collection.objects.create( - description=description, - repository=repository, - revision=revision, - branch=branch, - client=client, - coverage=coverage, + result = cast( + Collection, + Collection.objects.create( + description=description, + repository=repository, + revision=revision, + branch=branch, + client=client_, + coverage=coverage_, + ), ) LOG.debug("Created Collection pk=%d", result.pk) # create tools - for tool in tools: - tool, created = Tool.objects.get_or_create(name=tool) + for single_tool in tools: + tool, created = Tool.objects.get_or_create(name=single_tool) if created: LOG.debug("Created Tool pk=%d", tool.pk) result.tools.add(tool) return result @staticmethod - def git(repo, *args): + def git(repo: Repository, *args: str) -> str: path = os.getcwd() try: 
os.chdir(repo.location) @@ -170,7 +243,7 @@ def git(repo, *args): os.chdir(path) @staticmethod - def hg(repo, *args): + def hg(repo: Repository, *args: str) -> str: path = os.getcwd() try: os.chdir(repo.location) diff --git a/server/covmanager/tests/test_collections.py b/server/covmanager/tests/test_collections.py index 2a4a72e5a..1fa633e7e 100644 --- a/server/covmanager/tests/test_collections.py +++ b/server/covmanager/tests/test_collections.py @@ -8,15 +8,22 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import json import logging import os import re +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse +from .conftest import _result + LOG = logging.getLogger("fm.covmanager.tests.collections") pytestmark = pytest.mark.usefixtures("covmanager_test") # pylint: disable=invalid-name @@ -30,12 +37,15 @@ "covmanager:collections_patch", ], ) -def test_collections_no_login(name, client): +def test_collections_no_login(name: str, client: Client) -> None: """Request without login hits the login redirect""" path = reverse(name) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) @pytest.mark.parametrize( @@ -47,7 +57,7 @@ def test_collections_no_login(name, client): "covmanager:collections_patch", ], ) -def test_collections_view_simple_get(name, client): +def test_collections_view_simple_get(name: str, client: Client) -> None: """No errors are thrown in template""" client.login(username="test", password="test") response = client.get(reverse(name)) @@ -55,15 +65,18 @@ def test_collections_view_simple_get(name, client): assert response.status_code == requests.codes["ok"] -def test_collections_diff_no_login(client): +def test_collections_diff_no_login(client: Client) -> None: """Request without login hits the login redirect""" path = reverse("covmanager:collections_diff_api", kwargs={"path": ""}) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_collections_diff_simple_get(client, cm): +def test_collections_diff_simple_get(client: Client, cm: _result) -> None: """No errors are thrown in template""" repo = cm.create_repository("git") col1 = cm.create_collection(repository=repo, coverage=json.dumps({"children": []})) @@ -77,7 +90,7 @@ def test_collections_diff_simple_get(client, cm): assert response.status_code == requests.codes["ok"] -def test_collections_patch_no_login(client): +def test_collections_patch_no_login(client: Client) -> None: """Request without login hits the login redirect""" path = reverse( "covmanager:collections_patch_api", @@ -85,10 +98,13 @@ def test_collections_patch_no_login(client): ) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_collections_patch_simple_get(client, cm): +def test_collections_patch_simple_get(client: Client, cm: _result) -> None: """No errors 
are thrown in template""" client.login(username="test", password="test") repo = cm.create_repository("hg") @@ -112,7 +128,9 @@ def test_collections_patch_simple_get(client, cm): with open(os.path.join(repo.location, "test.c"), "w") as fp: fp.write("world") cm.hg(repo, "commit", "-m", "update") - rev = re.match(r"changeset: 1:([0-9a-f]+)", cm.hg(repo, "log")).group(1) + re_match = re.match(r"changeset: 1:([0-9a-f]+)", cm.hg(repo, "log")) + assert re_match is not None + rev = re_match.group(1) response = client.get( reverse( "covmanager:collections_patch_api", @@ -123,15 +141,18 @@ def test_collections_patch_simple_get(client, cm): assert response.status_code == requests.codes["ok"] -def test_collections_browse_no_login(client): +def test_collections_browse_no_login(client: Client) -> None: """Request without login hits the login redirect""" path = reverse("covmanager:collections_browse", kwargs={"collectionid": 0}) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_collections_browse_simple_get(client): +def test_collections_browse_simple_get(client: Client) -> None: """No errors are thrown in template""" client.login(username="test", password="test") response = client.get( @@ -141,17 +162,20 @@ def test_collections_browse_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_collections_browse_api_no_login(client): +def test_collections_browse_api_no_login(client: Client) -> None: """Request without login hits the login redirect""" path = reverse( "covmanager:collections_browse_api", kwargs={"collectionid": 0, "path": ""} ) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_collections_browse_api_simple_get(client, cm): +def test_collections_browse_api_simple_get(client: Client, cm: _result) -> None: """No errors are thrown in template""" client.login(username="test", password="test") repo = cm.create_repository("git") diff --git a/server/covmanager/tests/test_collections_rest.py b/server/covmanager/tests/test_collections_rest.py index 1e0c0d7cd..2c512ab47 100644 --- a/server/covmanager/tests/test_collections_rest.py +++ b/server/covmanager/tests/test_collections_rest.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import codecs import json import logging @@ -16,14 +19,17 @@ import requests from django.contrib.auth.models import User from django.utils import dateparse, timezone +from rest_framework.test import APIClient from covmanager.models import Collection +from .conftest import _result, covType + LOG = logging.getLogger("fm.covmanager.tests.collections.rest") pytestmark = pytest.mark.usefixtures("covmanager_test") # pylint: disable=invalid-name -def test_rest_collections_no_auth(api_client): +def test_rest_collections_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/covmanager/rest/collections/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -33,7 +39,7 @@ def test_rest_collections_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_collections_no_perm(api_client): +def test_rest_collections_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -45,7 +51,7 @@ def test_rest_collections_no_perm(api_client): assert api_client.delete(url).status_code == requests.codes["forbidden"] -def test_rest_collections_patch(api_client): +def test_rest_collections_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -54,12 +60,13 @@ def test_rest_collections_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_collections_post(api_client, cm): +def test_rest_collections_post(api_client: APIClient, cm: _result) -> None: """post should be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) repo = cm.create_repository("git", name="testrepo") - cov = { + + cov: covType = { "linesTotal": 0, "name": None, "coveragePercent": 0.0, @@ -94,7 +101,7 @@ def test_rest_collections_post(api_client, cm): assert json.load(codecs.getreader("utf-8")(result.coverage.file)) == cov -def test_rest_collections_put(api_client): +def test_rest_collections_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -103,7 +110,7 @@ def test_rest_collections_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_collections_delete(api_client): +def test_rest_collections_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -112,7 +119,7 @@ def test_rest_collections_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_collections_get(api_client, cm): +def test_rest_collections_get(api_client: APIClient, cm: _result) -> None: """get should be allowed""" repo = cm.create_repository("git", name="testrepo") coll = cm.create_collection( @@ -146,15 +153,17 @@ def test_rest_collections_get(api_client, cm): assert resp["repository"] == "testrepo" created = dateparse.parse_datetime(resp["created"]) LOG.debug("time now: %s", timezone.now()) + assert created is not None assert (timezone.now() - created).total_seconds() < 60 assert resp["description"] == "testdesc" assert resp["client"] == "testclient" assert resp["tools"] == "testtool" assert resp["revision"] == "abc" + assert 
coll.coverage is not None assert resp["coverage"] == coll.coverage.file -def test_rest_collection_no_auth(api_client): +def test_rest_collection_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/covmanager/rest/collections/1/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -164,7 +173,7 @@ def test_rest_collection_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_collection_no_perm(api_client): +def test_rest_collection_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -176,7 +185,7 @@ def test_rest_collection_no_perm(api_client): assert api_client.delete(url).status_code == requests.codes["forbidden"] -def test_rest_collection_patch(api_client): +def test_rest_collection_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -185,7 +194,7 @@ def test_rest_collection_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_collection_post(api_client): +def test_rest_collection_post(api_client: APIClient) -> None: """post should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -194,7 +203,7 @@ def test_rest_collection_post(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_collection_put(api_client): +def test_rest_collection_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -203,7 +212,7 @@ def test_rest_collection_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_collection_delete(api_client): +def test_rest_collection_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -212,7 +221,7 @@ def test_rest_collection_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_collection_get(api_client, cm): +def test_rest_collection_get(api_client: APIClient, cm: _result) -> None: """get should not be allowed""" repo = cm.create_repository("git", name="testrepo") coll = cm.create_collection( @@ -240,9 +249,11 @@ def test_rest_collection_get(api_client, cm): assert resp["repository"] == "testrepo" created = dateparse.parse_datetime(resp["created"]) LOG.debug("time now: %s", timezone.now()) + assert created is not None assert (timezone.now() - created).total_seconds() < 60 assert resp["description"] == "testdesc" assert resp["client"] == "testclient" assert resp["tools"] == "testtool" assert resp["revision"] == "abc" + assert coll.coverage is not None assert resp["coverage"] == coll.coverage.file diff --git a/server/covmanager/tests/test_covmanager.py b/server/covmanager/tests/test_covmanager.py index 137842cbe..817eeaf0e 100644 --- a/server/covmanager/tests/test_covmanager.py +++ b/server/covmanager/tests/test_covmanager.py @@ -8,25 +8,32 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse LOG = logging.getLogger("fm.covmanager.tests.covmanager") pytestmark = pytest.mark.usefixtures("covmanager_test") # pylint: disable=invalid-name -def test_covmanager_index(client): +def test_covmanager_index(client: Client) -> None: """Request of root url redirects to pools view""" client.login(username="test", password="test") resp = client.get(reverse("covmanager:index")) assert resp.status_code == requests.codes["found"] - assert resp.url == reverse("covmanager:collections") + assert typing.cast(typing.Union[str, None], getattr(resp, "url", None)) == reverse( + "covmanager:collections" + ) -def test_covmanager_noperm(client): +def test_covmanager_noperm(client: Client) -> None: """Request without permission results in 403""" client.login(username="test-noperm", password="test") resp = client.get(reverse("covmanager:index")) diff --git a/server/covmanager/tests/test_mgmt_setup_repository.py b/server/covmanager/tests/test_mgmt_setup_repository.py index 3bfa94b79..000dd35af 100644 --- a/server/covmanager/tests/test_mgmt_setup_repository.py +++ b/server/covmanager/tests/test_mgmt_setup_repository.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import os import pytest @@ -18,7 +21,7 @@ pytestmark = pytest.mark.django_db() # pylint: disable=invalid-name -def test_bad_args(): +def test_bad_args() -> None: with pytest.raises(CommandError, match=r"Error: .*? arguments"): call_command("setup_repository") @@ -41,7 +44,7 @@ def test_bad_args(): call_command("setup_repository", "", "", "", "") -def test_repo_exists(): +def test_repo_exists() -> None: Repository.objects.create(name="test") with pytest.raises( CommandError, match=r"Error: repository with name '.*' already exists!" @@ -49,14 +52,14 @@ def test_repo_exists(): call_command("setup_repository", "test", "", "") -def test_bad_provider(): +def test_bad_provider() -> None: with pytest.raises( CommandError, match=r"Error: 'bad' is not a valid source code provider!" ): call_command("setup_repository", "test", "bad", ".") -def test_git_create(): +def test_git_create() -> None: call_command("setup_repository", "test", "git", ".") repo = Repository.objects.get(name="test") assert repo.classname == "GITSourceCodeProvider" @@ -68,7 +71,7 @@ def test_git_create(): assert repo.location == os.path.realpath(".") -def test_hg_create(): +def test_hg_create() -> None: call_command("setup_repository", "test", "hg", ".") repo = Repository.objects.get(name="test") assert repo.classname == "HGSourceCodeProvider" diff --git a/server/covmanager/tests/test_repositories.py b/server/covmanager/tests/test_repositories.py index 7b9983ae5..822a81c3e 100644 --- a/server/covmanager/tests/test_repositories.py +++ b/server/covmanager/tests/test_repositories.py @@ -8,13 +8,20 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import json import logging +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse +from .conftest import _result + LOG = logging.getLogger("fm.covmanager.tests.repos") pytestmark = pytest.mark.usefixtures("covmanager_test") # pylint: disable=invalid-name @@ -22,15 +29,18 @@ @pytest.mark.parametrize( "name", ["covmanager:repositories", "covmanager:repositories_search_api"] ) -def test_repositories_no_login(name, client): +def test_repositories_no_login(name: str, client: Client) -> None: """Request without login hits the login redirect""" path = reverse(name) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_repositories_view_simpleget(client): +def test_repositories_view_simpleget(client: Client) -> None: """No errors are thrown in template""" client.login(username="test", password="test") response = client.get(reverse("covmanager:repositories")) @@ -38,7 +48,7 @@ def test_repositories_view_simpleget(client): assert response.status_code == requests.codes["ok"] -def test_repositories_view_list(client, cm): +def test_repositories_view_list(client: Client, cm: _result) -> None: """Repositories are listed""" client.login(username="test", password="test") repos = [] @@ -56,7 +66,7 @@ def test_repositories_view_list(client, cm): assert set(response.context["repositories"]) == set(repos) -def test_repositories_search_view_simpleget(client): +def test_repositories_search_view_simpleget(client: Client) -> None: """No errors are thrown in template""" client.login(username="test", password="test") response = client.get(reverse("covmanager:repositories_search_api")) @@ -64,45 +74,45 @@ def test_repositories_search_view_simpleget(client): assert response.status_code == requests.codes["ok"] -def test_repositories_search_view_search_git(client, cm): +def test_repositories_search_view_search_git(client: Client, cm: _result) -> None: cm.create_repository("git", name="gittest1") cm.create_repository("git", name="gittest2") client.login(username="test", password="test") - response = client.get( + response_blah = client.get( reverse("covmanager:repositories_search_api"), {"name": "blah"} ) - LOG.debug(response) - assert response.status_code == requests.codes["ok"] - response = json.loads(response.content.decode("utf-8")) - assert set(response.keys()) == {"results"} - assert response["results"] == [] - response = client.get( + LOG.debug(response_blah) + assert response_blah.status_code == requests.codes["ok"] + response_blah_json = json.loads(response_blah.content.decode("utf-8")) + assert set(response_blah_json.keys()) == {"results"} + assert response_blah_json["results"] == [] + response_test = client.get( reverse("covmanager:repositories_search_api"), {"name": "test"} ) - LOG.debug(response) - assert response.status_code == requests.codes["ok"] - response = json.loads(response.content.decode("utf-8")) - assert set(response.keys()) == {"results"} - assert set(response["results"]) == {"gittest1", "gittest2"} + LOG.debug(response_test) + assert response_test.status_code == requests.codes["ok"] + response_test_json = json.loads(response_test.content.decode("utf-8")) + assert set(response_test_json.keys()) == {"results"} + assert set(response_test_json["results"]) == {"gittest1", "gittest2"} 
-def test_repositories_search_view_search_hg(client, cm): +def test_repositories_search_view_search_hg(client: Client, cm: _result) -> None: cm.create_repository("hg", name="hgtest1") cm.create_repository("hg", name="hgtest2") client.login(username="test", password="test") - response = client.get( + response_blah = client.get( reverse("covmanager:repositories_search_api"), {"name": "blah"} ) - LOG.debug(response) - assert response.status_code == requests.codes["ok"] - response = json.loads(response.content.decode("utf-8")) - assert set(response.keys()) == {"results"} - assert response["results"] == [] - response = client.get( + LOG.debug(response_blah) + assert response_blah.status_code == requests.codes["ok"] + response_blah_json = json.loads(response_blah.content.decode("utf-8")) + assert set(response_blah_json.keys()) == {"results"} + assert response_blah_json["results"] == [] + response_test = client.get( reverse("covmanager:repositories_search_api"), {"name": "test"} ) - LOG.debug(response) - assert response.status_code == requests.codes["ok"] - response = json.loads(response.content.decode("utf-8")) - assert set(response.keys()) == {"results"} - assert set(response["results"]) == {"hgtest1", "hgtest2"} + LOG.debug(response_test) + assert response_test.status_code == requests.codes["ok"] + response_test_json = json.loads(response_test.content.decode("utf-8")) + assert set(response_test_json.keys()) == {"results"} + assert set(response_test_json["results"]) == {"hgtest1", "hgtest2"} diff --git a/server/covmanager/tests/test_repositories_rest.py b/server/covmanager/tests/test_repositories_rest.py index 95e77c3ed..64c8db8e5 100644 --- a/server/covmanager/tests/test_repositories_rest.py +++ b/server/covmanager/tests/test_repositories_rest.py @@ -9,18 +9,23 @@ file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" +from __future__ import annotations + import json import logging import pytest import requests from django.contrib.auth.models import User +from rest_framework.test import APIClient + +from .conftest import _result LOG = logging.getLogger("fm.covmanager.tests.repos.rest") pytestmark = pytest.mark.usefixtures("covmanager_test") # pylint: disable=invalid-name -def test_rest_repositories_no_auth(api_client): +def test_rest_repositories_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/covmanager/rest/repositories/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -30,7 +35,7 @@ def test_rest_repositories_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_repositories_no_perm(api_client): +def test_rest_repositories_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -42,7 +47,7 @@ def test_rest_repositories_no_perm(api_client): assert api_client.delete(url).status_code == requests.codes["forbidden"] -def test_rest_repositories_patch(api_client): +def test_rest_repositories_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -50,7 +55,7 @@ def test_rest_repositories_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repositories_post(api_client): +def test_rest_repositories_post(api_client: APIClient) -> None: """post should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -58,7 +63,7 @@ def test_rest_repositories_post(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repositories_put(api_client): +def test_rest_repositories_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -66,7 +71,7 @@ def test_rest_repositories_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repositories_delete(api_client): +def test_rest_repositories_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -74,7 +79,7 @@ def test_rest_repositories_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repositories_get(api_client, cm): +def test_rest_repositories_get(api_client: APIClient, cm: _result) -> None: """get should be allowed""" cm.create_repository("git", name="testrepo") user = User.objects.get(username="test") @@ -92,7 +97,7 @@ def test_rest_repositories_get(api_client, cm): assert resp["name"] == "testrepo" -def test_rest_repository_no_auth(api_client): +def test_rest_repository_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/covmanager/rest/repositories/1/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -102,7 +107,7 @@ def test_rest_repository_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_repository_no_perm(api_client): +def test_rest_repository_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = 
User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -114,7 +119,7 @@ def test_rest_repository_no_perm(api_client): assert api_client.delete(url).status_code == requests.codes["forbidden"] -def test_rest_repository_patch(api_client): +def test_rest_repository_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -122,7 +127,7 @@ def test_rest_repository_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repository_post(api_client): +def test_rest_repository_post(api_client: APIClient) -> None: """post should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -130,7 +135,7 @@ def test_rest_repository_post(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repository_put(api_client): +def test_rest_repository_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -138,7 +143,7 @@ def test_rest_repository_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repository_delete(api_client): +def test_rest_repository_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -146,7 +151,7 @@ def test_rest_repository_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_repository_get(api_client, cm): +def test_rest_repository_get(api_client: APIClient, cm: _result) -> None: """get should be allowed""" repo = cm.create_repository("git", name="testrepo") user = User.objects.get(username="test") diff --git a/server/covmanager/urls.py b/server/covmanager/urls.py index 32b3739c3..e1243a456 100644 --- a/server/covmanager/urls.py +++ b/server/covmanager/urls.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.conf.urls import include from django.urls import re_path from rest_framework import routers diff --git a/server/covmanager/views.py b/server/covmanager/views.py index e2f4e018e..ff3c4d6a2 100644 --- a/server/covmanager/views.py +++ b/server/covmanager/views.py @@ -1,16 +1,28 @@ +from __future__ import annotations + import json import os +from typing import Any, TypeVar, cast from wsgiref.util import FileWrapper from django.conf import settings from django.core.exceptions import PermissionDenied, SuspiciousOperation -from django.db.models import Q +from django.db.models import Model, Q +from django.db.models.query import QuerySet from django.http import Http404 -from django.http.response import HttpResponse +from django.http.request import HttpRequest +from django.http.response import ( + HttpResponse, + HttpResponsePermanentRedirect, + HttpResponseRedirect, +) from django.shortcuts import get_object_or_404, redirect, render from django.views.decorators.csrf import csrf_exempt from rest_framework import filters, mixins, viewsets from rest_framework.authentication import SessionAuthentication, TokenAuthentication +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView from crashmanager.models import Tool from server.views import JsonQueryFilterBackend, SimpleQueryFilterBackend @@ -25,46 +37,50 @@ from .SourceCodeProvider import SourceCodeProvider from .tasks 
import aggregate_coverage_data, calculate_report_summary +MT = TypeVar("MT", bound=Model) -def index(request): + +def index(request: HttpRequest) -> HttpResponseRedirect | HttpResponsePermanentRedirect: return redirect( f"covmanager:{getattr(settings, 'COV_DEFAULT_PAGE', 'collections')}" ) -def reports(request): +def reports(request: HttpRequest) -> HttpResponse: return render(request, "collections/report.html", {}) -def repositories(request): +def repositories(request: HttpRequest) -> HttpResponse: repositories = Repository.objects.all() return render(request, "repositories/index.html", {"repositories": repositories}) -def reportconfigurations(request): +def reportconfigurations(request: HttpRequest) -> HttpResponse: return render(request, "reportconfigurations/index.html", {}) -def collections(request): +def collections(request: HttpRequest) -> HttpResponse: return render(request, "collections/index.html", {}) -def collections_browse(request, collectionid): +def collections_browse(request: HttpRequest, collectionid: str) -> HttpResponse: return render(request, "collections/browse.html", {"collectionid": collectionid}) -def collections_diff(request): +def collections_diff(request: HttpRequest) -> HttpResponse: return render(request, "collections/browse.html", {"diff_api": True}) -def collections_reportsummary(request, collectionid): +def collections_reportsummary(request: HttpRequest, collectionid: str) -> HttpResponse: return render( request, "reportconfigurations/summary.html", {"collectionid": collectionid} ) -def collections_reportsummary_html_list(request, collectionid): - collection = get_object_or_404(Collection, pk=collectionid) +def collections_reportsummary_html_list( + request: HttpRequest, collectionid: str +) -> HttpResponse: + collection: Collection = get_object_or_404(Collection, pk=collectionid) if not collection.coverage: return HttpResponse( @@ -131,7 +147,7 @@ def collections_reportsummary_html_list(request, collectionid): root["diffid"] = diff_collection.pk - def annotate_delta(a, b): + def annotate_delta(a, b) -> None: delta = round(a["coveragePercent"] - b["coveragePercent"], 2) if delta >= 1.0: @@ -140,7 +156,7 @@ def annotate_delta(a, b): a["coveragePercentDelta"] = f"{delta} %" if "children" not in a or "children" not in b: - return + return None # Map children to their ids so we can iterate them side-by-side a_child_dict = {c["id"]: c for c in a["children"]} @@ -161,7 +177,7 @@ def annotate_delta(a, b): ) -def collections_download(request, collectionid): +def collections_download(request: HttpRequest, collectionid: str) -> HttpResponse: collection = get_object_or_404(Collection, pk=collectionid) if not collection.coverage: @@ -183,7 +199,9 @@ def collections_download(request, collectionid): return response -def collections_browse_api(request, collectionid, path): +def collections_browse_api( + request: HttpRequest, collectionid: str, path: str +) -> HttpResponse: collection = get_object_or_404(Collection, pk=collectionid) if not collection.coverage: @@ -218,16 +236,16 @@ def collections_browse_api(request, collectionid, path): return HttpResponse(json.dumps(data), content_type="application/json") -def collections_diff_api(request, path): +def collections_diff_api(request: HttpRequest, path: str) -> HttpResponse: - collections = None + collections: list[Collection] coverages = [] if "ids" in request.GET: ids = request.GET["ids"].split(",") - collections = Collection.objects.filter(pk__in=ids) + collections = cast(list[Collection], 
Collection.objects.filter(pk__in=ids)) - if len(collections) < 2: + if collections and len(collections) < 2: raise Http404("Need at least two collections") report_configuration = None @@ -331,11 +349,13 @@ def collections_diff_api(request, path): return HttpResponse(json.dumps(data), content_type="application/json") -def collections_patch(request): +def collections_patch(request: HttpRequest) -> HttpResponse: return render(request, "collections/patch.html", {}) -def collections_patch_api(request, collectionid, patch_revision): +def collections_patch_api( + request: HttpRequest, collectionid: str, patch_revision: str +) -> HttpResponse: collection = get_object_or_404(Collection, pk=collectionid) if not collection.coverage: @@ -372,6 +392,7 @@ def collections_patch_api(request, collectionid, patch_revision): prepatch_source = provider.getSource(filename, diff_revision) coll_source = provider.getSource(filename, collection.revision) + assert filename is not None if prepatch_source != coll_source: response = {"error": "Source code mismatch."} response["filename"] = filename @@ -460,7 +481,9 @@ def collections_patch_api(request, collectionid, patch_revision): return HttpResponse(json.dumps(results), content_type="application/json") -def collections_reportsummary_api(request, collectionid): +def collections_reportsummary_api( + request: HttpRequest, collectionid: str +) -> HttpResponse: collection = get_object_or_404(Collection, pk=collectionid) if not collection.coverage: @@ -500,36 +523,38 @@ def collections_reportsummary_api(request, collectionid): return HttpResponse(summary.cached_result, content_type="application/json") -def repositories_search_api(request): - results = [] +def repositories_search_api(request: HttpRequest) -> HttpResponse: + results: list[Repository] = [] if "name" in request.GET: name = request.GET["name"] - results = Repository.objects.filter(name__contains=name).values_list( - "name", flat=True + results = list( + Repository.objects.filter(name__contains=name).values_list( + "name", flat=True + ) ) return HttpResponse( - json.dumps({"results": list(results)}), content_type="application/json" + json.dumps({"results": results}), content_type="application/json" ) -def tools_search_api(request): - results = [] +def tools_search_api(request: HttpRequest) -> HttpResponse: + results: list[Tool] = [] if "name" in request.GET: name = request.GET["name"] - results = Tool.objects.filter(name__contains=name).values_list( - "name", flat=True + results = list( + Tool.objects.filter(name__contains=name).values_list("name", flat=True) ) return HttpResponse( - json.dumps({"results": list(results)}), content_type="application/json" + json.dumps({"results": results}), content_type="application/json" ) @csrf_exempt -def collections_aggregate_api(request): +def collections_aggregate_api(request: HttpRequest) -> HttpResponse: if request.method != "POST": return HttpResponse( content=json.dumps({"error": "This API only supports POST."}), @@ -659,7 +684,9 @@ class CollectionFilterBackend(filters.BaseFilterBackend): Accepts filtering with several collection-specific fields from the URL """ - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """ Return a filtered queryset. 
""" @@ -723,7 +750,9 @@ class ReportFilterBackend(filters.BaseFilterBackend): Accepts broad filtering by q parameter to search multiple fields """ - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """ Return a filtered queryset. """ @@ -762,7 +791,7 @@ class ReportViewSet( paginate_by_param = "limit" filter_backends = [ReportFilterBackend] - def partial_update(self, request, *args, **kwargs): + def partial_update(self, request: Request, *args: Any, **kwargs: Any) -> Response: if ( not request.user or not request.user.is_authenticated @@ -791,7 +820,9 @@ class ReportConfigurationFilterBackend(filters.BaseFilterBackend): Accepts broad filtering by q parameter to search multiple fields """ - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """ Return a filtered queryset. """ diff --git a/server/crashmanager/Bugtracker/BugzillaProvider.py b/server/crashmanager/Bugtracker/BugzillaProvider.py index e26f7305c..f1e657be1 100644 --- a/server/crashmanager/Bugtracker/BugzillaProvider.py +++ b/server/crashmanager/Bugtracker/BugzillaProvider.py @@ -11,20 +11,25 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + +from django.db.models.query import QuerySet from django.forms.models import model_to_dict from django.shortcuts import get_object_or_404 from django.utils import dateparse +from rest_framework.request import Request -from ..models import BugzillaTemplate, User +from ..models import BugzillaTemplate, CrashEntry, User from .BugzillaREST import BugzillaREST from .Provider import Provider class BugzillaProvider(Provider): - def __init__(self, pk, hostname): + def __init__(self, pk: int, hostname: str) -> None: super().__init__(pk, hostname) - def getTemplateForUser(self, request, crashEntry): + def getTemplateForUser(self, request: Request, crashEntry: CrashEntry): if "template" in request.GET: obj = get_object_or_404(BugzillaTemplate, pk=request.GET["template"]) template = model_to_dict(obj) @@ -32,15 +37,15 @@ def getTemplateForUser(self, request, crashEntry): else: user = User.get_or_create_restricted(request.user)[0] - obj = BugzillaTemplate.objects.filter(name__contains=crashEntry.tool.name) - if not obj: + obj_ = BugzillaTemplate.objects.filter(name__contains=crashEntry.tool.name) + if not obj_: defaultTemplateId = user.defaultTemplateId if not defaultTemplateId: defaultTemplateId = 1 - obj = BugzillaTemplate.objects.filter(pk=defaultTemplateId) + obj_ = BugzillaTemplate.objects.filter(pk=defaultTemplateId) - if not obj: + if not obj_: template = {} else: template = model_to_dict(obj[0]) @@ -48,18 +53,31 @@ def getTemplateForUser(self, request, crashEntry): return template - def getTemplateList(self): + def getTemplateList(self) -> QuerySet[BugzillaTemplate]: return BugzillaTemplate.objects.all() - def getBugData(self, bugId, username=None, password=None, api_key=None): + def getBugData( + self, + bugId: str, + username: str | None = None, + password: str | None = None, + api_key: str | None = None, + ) -> str | None: bz = BugzillaREST(self.hostname, username, password, api_key) return bz.getBug(bugId) - def getBugStatus(self, bugIds, username=None, password=None, api_key=None): + def getBugStatus( + self, + bugIds: list[str], + username: str | None = None, + password: str | None = None, + api_key: str | None = None, + ): ret = {} bz = 
BugzillaREST(self.hostname, username, password, api_key) bugs = bz.getBugStatus(bugIds) + assert bugs is not None for bugId in bugs: if bugs[bugId]["is_open"]: ret[bugId] = None diff --git a/server/crashmanager/Bugtracker/BugzillaREST.py b/server/crashmanager/Bugtracker/BugzillaREST.py index 237d1e890..9256fc17e 100644 --- a/server/crashmanager/Bugtracker/BugzillaREST.py +++ b/server/crashmanager/Bugtracker/BugzillaREST.py @@ -11,11 +11,20 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + import requests class BugzillaREST: - def __init__(self, hostname, username=None, password=None, api_key=None): + def __init__( + self, + hostname: str, + username: str | None = None, + password: str | None = None, + api_key: str | None = None, + ) -> None: self.hostname = hostname self.baseUrl = f"https://{self.hostname}/rest" self.username = username @@ -35,7 +44,7 @@ def __init__(self, hostname, username=None, password=None, api_key=None): # it in the URI for additional security. self.request_headers["X-BUGZILLA-API-KEY"] = self.api_key - def login(self, loginRequired=True, forceLogin=False): + def login(self, loginRequired: bool = True, forceLogin: bool = False) -> bool: if (self.username is None or self.password is None) and self.api_key is None: if loginRequired: raise RuntimeError("Need username/password or API key to login.") @@ -65,7 +74,7 @@ def login(self, loginRequired=True, forceLogin=False): self.authToken = json["token"] return True - def getBug(self, bugId): + def getBug(self, bugId: str) -> str | None: bugs = self.getBugs([bugId]) if not bugs: @@ -73,7 +82,7 @@ def getBug(self, bugId): return bugs[int(bugId)] - def getBugStatus(self, bugIds): + def getBugStatus(self, bugIds: list[str]): return self.getBugs( bugIds, include_fields=[ @@ -85,7 +94,12 @@ def getBugStatus(self, bugIds): ], ) - def getBugs(self, bugIds, include_fields=None, exclude_fields=None): + def getBugs( + self, + bugIds: list[str] | str, + include_fields: list[str] | None = None, + exclude_fields: list[str] | None = None, + ): if not isinstance(bugIds, list): bugIds = [bugIds] diff --git a/server/crashmanager/Bugtracker/Provider.py b/server/crashmanager/Bugtracker/Provider.py index 8f5f5b84e..88d759016 100644 --- a/server/crashmanager/Bugtracker/Provider.py +++ b/server/crashmanager/Bugtracker/Provider.py @@ -11,8 +11,15 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + from abc import ABCMeta, abstractmethod +from django.db.models.query import QuerySet + +from ..models import BugzillaTemplate + class Provider(metaclass=ABCMeta): """ @@ -20,18 +27,25 @@ class Provider(metaclass=ABCMeta): class that defines what interfaces Bug Providers must implement """ - def __init__(self, pk, hostname): + def __init__(self, pk: int, hostname: str) -> None: self.pk = pk self.hostname = hostname @abstractmethod - def getTemplateList(self): + def getTemplateList(self) -> QuerySet[BugzillaTemplate]: return @abstractmethod - def getBugData(self, bugId, username=None, password=None): + def getBugData( + self, bugId: str, username: str | None = None, password: str | None = None + ) -> str | None: return @abstractmethod - def getBugStatus(self, bugIds, username=None, password=None): + def getBugStatus( + self, + bugIds: list[str], + username: str | None = None, + password: str | None = None, + ) -> None: return diff --git a/server/crashmanager/__init__.py b/server/crashmanager/__init__.py index ba00c6d1d..2e82c2ba6 100644 --- a/server/crashmanager/__init__.py +++ 
b/server/crashmanager/__init__.py @@ -1 +1,3 @@ +from __future__ import annotations + from . import tasks # noqa diff --git a/server/crashmanager/admin.py b/server/crashmanager/admin.py index c07b44441..25a9cdb9b 100644 --- a/server/crashmanager/admin.py +++ b/server/crashmanager/admin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.contrib import admin from crashmanager.models import ( # @UnresolvedImport diff --git a/server/crashmanager/cron.py b/server/crashmanager/cron.py index 92254defe..9c00e6667 100644 --- a/server/crashmanager/cron.py +++ b/server/crashmanager/cron.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil from tempfile import mkstemp @@ -12,22 +14,22 @@ @app.task(ignore_result=True) -def bug_update_status(): +def bug_update_status() -> None: call_command("bug_update_status") @app.task(ignore_result=True) -def cleanup_old_crashes(): +def cleanup_old_crashes() -> None: call_command("cleanup_old_crashes") @app.task(ignore_result=True) -def triage_new_crashes(): +def triage_new_crashes() -> None: call_command("triage_new_crashes") @app.task(ignore_result=True) -def export_signatures(): +def export_signatures() -> None: fd, tmpf = mkstemp(prefix="fm-sigs-", suffix=".zip") os.close(fd) try: @@ -39,5 +41,5 @@ def export_signatures(): @app.task(ignore_result=True) -def notify_by_email(): +def notify_by_email() -> None: call_command("notify_by_email") diff --git a/server/crashmanager/forms.py b/server/crashmanager/forms.py index 3a7395b99..5c86dc2f9 100644 --- a/server/crashmanager/forms.py +++ b/server/crashmanager/forms.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from crispy_forms.helper import FormHelper from crispy_forms.layout import HTML, Div, Field, Layout, Submit from django.conf import settings @@ -21,7 +25,7 @@ class Row(Div): css_class = "row" -class BugzillaTemplateBugForm(ModelForm): +class BugzillaTemplateBugForm(ModelForm[BugzillaTemplate]): helper = FormHelper() helper.layout = Layout( HTML("""
"""), @@ -143,7 +147,7 @@ class Meta: widgets["attrs"] = Textarea(attrs={"rows": 2}) -class BugzillaTemplateCommentForm(ModelForm): +class BugzillaTemplateCommentForm(ModelForm[BugzillaTemplate]): helper = FormHelper() helper.layout = Layout( HTML("""
"""), @@ -177,7 +181,7 @@ class Meta: } -class UserSettingsForm(ModelForm): +class UserSettingsForm(ModelForm[User]): helper = FormHelper() helper.layout = Layout( "defaultToolsFilter", @@ -215,7 +219,7 @@ class Meta: "tasks_failed", ] - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.user = kwargs.pop("user", None) super().__init__(*args, **kwargs) @@ -253,7 +257,7 @@ def clean_defaultProviderId(self): data = self.cleaned_data["defaultProviderId"].id return data - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> User: self.instance.user.email = self.cleaned_data["email"] self.instance.user.save() return super().save(*args, **kwargs) diff --git a/server/crashmanager/management/commands/add_permission.py b/server/crashmanager/management/commands/add_permission.py index 0212fb726..c96d63a53 100644 --- a/server/crashmanager/management/commands/add_permission.py +++ b/server/crashmanager/management/commands/add_permission.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +from argparse import ArgumentParser +from typing import Any + from django.contrib.auth.models import Permission, User from django.contrib.contenttypes.models import ContentType from django.core.management import BaseCommand @@ -6,7 +11,7 @@ class Command(BaseCommand): help = "Adds permissions to the specified user." - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: user = User.objects.get(username=options["user"]) for perm in options["permission"]: @@ -22,7 +27,7 @@ def handle(self, *args, **options): print("done") - def add_arguments(self, parser): + def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument("user") parser.add_argument( "permission", diff --git a/server/crashmanager/management/commands/bug_update_status.py b/server/crashmanager/management/commands/bug_update_status.py index 81115d20f..ec43ddbe7 100644 --- a/server/crashmanager/management/commands/bug_update_status.py +++ b/server/crashmanager/management/commands/bug_update_status.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import logging +from typing import Any from django.conf import settings from django.contrib.contenttypes.models import ContentType @@ -14,7 +17,7 @@ class Command(BaseCommand): help = "Check the status of all bugs we have" - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: if args: raise CommandError("Command doesn't accept any arguments") diff --git a/server/crashmanager/management/commands/cleanup_old_crashes.py b/server/crashmanager/management/commands/cleanup_old_crashes.py index 677cc854b..54582de6a 100644 --- a/server/crashmanager/management/commands/cleanup_old_crashes.py +++ b/server/crashmanager/management/commands/cleanup_old_crashes.py @@ -1,5 +1,8 @@ -import logging +from __future__ import annotations + from datetime import timedelta +import logging +from typing import Any from django.conf import settings from django.core.management import BaseCommand, CommandError # noqa @@ -14,7 +17,7 @@ class Command(BaseCommand): help = "Cleanup old crash entries." 
- def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: cleanup_crashes_after_days = getattr(settings, "CLEANUP_CRASHES_AFTER_DAYS", 14) cleanup_fixed_buckets_after_days = getattr( diff --git a/server/crashmanager/management/commands/export_signatures.py b/server/crashmanager/management/commands/export_signatures.py index d9bbd23a5..715b7fc9e 100644 --- a/server/crashmanager/management/commands/export_signatures.py +++ b/server/crashmanager/management/commands/export_signatures.py @@ -1,4 +1,8 @@ +from __future__ import annotations + +import argparse import json +from typing import Any from zipfile import ZipFile from django.core.management.base import BaseCommand @@ -10,12 +14,12 @@ class Command(BaseCommand): help = "Export signatures and their metadata." - def add_arguments(self, parser): + def add_arguments(self, parser: argparse.ArgumentParser) -> None: parser.add_argument( "filename", help="output filename to write signatures zip to" ) - def handle(self, filename, **options): + def handle(self, filename: str, **options: Any) -> None: with ZipFile(filename, "w") as zipFile: for bucket in Bucket.objects.annotate( diff --git a/server/crashmanager/management/commands/get_auth_token.py b/server/crashmanager/management/commands/get_auth_token.py index 6a8df7063..889a9a73b 100644 --- a/server/crashmanager/management/commands/get_auth_token.py +++ b/server/crashmanager/management/commands/get_auth_token.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.contrib.auth.models import User from django.core.management.base import LabelCommand from rest_framework.authtoken.models import Token @@ -7,7 +9,7 @@ class Command(LabelCommand): help = "Provides the REST interface authentication token for the specified user(s)." - def handle_label(self, label, **options): + def handle_label(self, label: str, **options: str) -> None: user = User.objects.get(username=label) (token, created) = Token.objects.get_or_create(user=user) diff --git a/server/crashmanager/management/commands/notify_by_email.py b/server/crashmanager/management/commands/notify_by_email.py index ed8e6b88f..cfe5b8eca 100644 --- a/server/crashmanager/management/commands/notify_by_email.py +++ b/server/crashmanager/management/commands/notify_by_email.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from django.core.mail import send_mail from django.core.management import BaseCommand from django.template.loader import render_to_string @@ -9,7 +13,7 @@ class Command(BaseCommand): help = "Send notifications by email." 
- def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: # Select all notifications that haven't been sent by email for now notifications = Notification.objects.filter(emailed=False) for notification in notifications: diff --git a/server/crashmanager/management/commands/triage_new_crash.py b/server/crashmanager/management/commands/triage_new_crash.py index a312ab078..50885ad79 100644 --- a/server/crashmanager/management/commands/triage_new_crash.py +++ b/server/crashmanager/management/commands/triage_new_crash.py @@ -1,4 +1,8 @@ +from __future__ import annotations + +from argparse import ArgumentParser from collections import OrderedDict +from typing import Any from django.conf import settings from django.core.management import BaseCommand @@ -11,20 +15,20 @@ # although this cache looks pointless within this command, # the command is called in a loop from triage_new_crashes.py # and may be called multiple times in one process by celery -TRIAGE_CACHE = OrderedDict() +TRIAGE_CACHE: OrderedDict[str, list[int]] = OrderedDict() class Command(BaseCommand): help = "Triage a crash entry into an existing bucket." - def add_arguments(self, parser): + def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument( "id", type=int, help="Crash ID", ) - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: entry = CrashEntry.objects.get(pk=options["id"]) crashInfo = entry.getCrashInfo(attachTestcase=True) diff --git a/server/crashmanager/management/commands/triage_new_crashes.py b/server/crashmanager/management/commands/triage_new_crashes.py index 222796ad5..39c0c2c26 100644 --- a/server/crashmanager/management/commands/triage_new_crashes.py +++ b/server/crashmanager/management/commands/triage_new_crashes.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from django.core.management import BaseCommand, call_command from crashmanager.models import CrashEntry @@ -9,7 +13,7 @@ class Command(BaseCommand): "before to assign them into the existing buckets." 
) - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: entries = CrashEntry.objects.filter(triagedOnce=False, bucket=None).values_list( "id", flat=True ) diff --git a/server/crashmanager/migrations/0003_auto_20210429_0908.py b/server/crashmanager/migrations/0003_auto_20210429_0908.py index c66d5ad25..1a8661d7c 100644 --- a/server/crashmanager/migrations/0003_auto_20210429_0908.py +++ b/server/crashmanager/migrations/0003_auto_20210429_0908.py @@ -1,5 +1,7 @@ # Generated by Django 3.0.14 on 2021-04-29 09:08 +from __future__ import annotations + import django.core.files.storage from django.conf import settings from django.db import migrations, models diff --git a/server/crashmanager/migrations/0004_bugzillatemplate_mode.py b/server/crashmanager/migrations/0004_bugzillatemplate_mode.py index 176fdab92..3af7e951e 100644 --- a/server/crashmanager/migrations/0004_bugzillatemplate_mode.py +++ b/server/crashmanager/migrations/0004_bugzillatemplate_mode.py @@ -1,5 +1,7 @@ # Generated by Django 2.2.20 on 2021-05-21 14:15 +from __future__ import annotations + import enumfields.fields from django.db import migrations diff --git a/server/crashmanager/migrations/0005_user_notification_booleans.py b/server/crashmanager/migrations/0005_user_notification_booleans.py index 590f09799..91a1f16aa 100644 --- a/server/crashmanager/migrations/0005_user_notification_booleans.py +++ b/server/crashmanager/migrations/0005_user_notification_booleans.py @@ -1,5 +1,7 @@ # Generated by Django 2.2.20 on 2021-06-18 13:28 +from __future__ import annotations + from django.db import migrations, models diff --git a/server/crashmanager/models.py b/server/crashmanager/models.py index 0854671e4..032f45d57 100644 --- a/server/crashmanager/models.py +++ b/server/crashmanager/models.py @@ -1,14 +1,20 @@ +from __future__ import annotations + +from datetime import datetime +from itertools import zip_longest import json import logging import re -from itertools import zip_longest +from typing import Any, TypeVar from django.conf import settings -from django.contrib.auth.models import Permission +from django.contrib.auth.base_user import AbstractBaseUser +from django.contrib.auth.models import AnonymousUser, Permission from django.contrib.auth.models import User as DjangoUser from django.contrib.contenttypes.models import ContentType from django.core.files.storage import FileSystemStorage from django.db import models +from django.db.models.query import QuerySet from django.db.models.signals import post_delete, post_save from django.dispatch.dispatcher import receiver from django.utils import timezone @@ -22,11 +28,13 @@ if getattr(settings, "USE_CELERY", None): from .tasks import triage_new_crash +MT = TypeVar("MT", bound=models.Model) + class Tool(models.Model): name = models.CharField(max_length=63) - def __str__(self): + def __str__(self) -> str: return self.name @@ -53,7 +61,7 @@ class TestCase(models.Model): quality = models.IntegerField(default=0) isBinary = models.BooleanField(default=False) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: # This variable can hold the testcase data temporarily self.content = None @@ -63,12 +71,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def loadTest(self): + def loadTest(self) -> None: self.test.open(mode="rb") self.content = self.test.read() self.test.close() - def storeTestAndSave(self): + def storeTestAndSave(self) -> None: + assert self.content is not None self.size = 
len(self.content) self.test.open(mode="w") self.test.write(self.content) @@ -95,7 +104,7 @@ def getInstance(self): providerClass = getattr(providerModule, self.classname) return providerClass(self.pk, self.hostname) - def __str__(self): + def __str__(self) -> str: return self.hostname @@ -105,7 +114,7 @@ class Bug(models.Model): closed = models.DateTimeField(blank=True, null=True) @property - def tools_filter_users(self): + def tools_filter_users(self) -> QuerySet[DjangoUser]: ids = User.objects.filter( defaultToolsFilter__crashentry__bucket__in=self.bucket_set.all(), inaccessible_bug=True, @@ -114,30 +123,31 @@ def tools_filter_users(self): class Bucket(models.Model): - bug = models.ForeignKey( + bug: Bug | None = models.ForeignKey( Bug, blank=True, null=True, on_delete=models.deletion.CASCADE ) signature = models.TextField() - optimizedSignature = models.TextField(blank=True, null=True) + optimizedSignature: str | None = models.TextField(blank=True, null=True) shortDescription = models.CharField(max_length=1023, blank=True) frequent = models.BooleanField(blank=False, default=False) permanent = models.BooleanField(blank=False, default=False) doNotReduce = models.BooleanField(blank=False, default=False) @property - def watchers(self): + def watchers(self) -> QuerySet[DjangoUser]: ids = User.objects.filter( bucketwatch__bucket=self, bucket_hit=True ).values_list("user_id", flat=True) return DjangoUser.objects.filter(id__in=ids).distinct() - def getSignature(self): + def getSignature(self) -> CrashSignature: return CrashSignature(self.signature) - def getOptimizedSignature(self): + def getOptimizedSignature(self) -> CrashSignature: + assert self.optimizedSignature is not None return CrashSignature(self.optimizedSignature) - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: # Sanitize signature line endings so we end up with the same hash # TODO: We might want to just parse the JSON here, and re-serialize # it to a canonical string representation. @@ -151,7 +161,7 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) - def reassign(self, submitSave): + def reassign(self, submitSave: bool) -> tuple[list[str], list[str], int, int]: """ Assign all unassigned issues that match our signature to this bucket. Furthermore, remove all non-matching issues from our bucket. @@ -247,7 +257,9 @@ def grouper(iterable, n, fillvalue=None): return inList, outList, inListCount, outListCount - def optimizeSignature(self, unbucketed_entries): + def optimizeSignature( + self, unbucketed_entries: QuerySet[CrashEntry] + ) -> tuple[CrashSignature | None, list[CrashEntry]]: buckets = Bucket.objects.all() signature = self.getSignature() @@ -267,7 +279,7 @@ def optimizeSignature(self, unbucketed_entries): for entry in entries: entry.crashinfo = entry.getCrashInfo( attachTestcase=signature.matchRequiresTest(), - requiredOutputSources=requiredOutputs, + requiredOutputSources=tuple(requiredOutputs), ) # For optimization, disregard any issues that directly match since those @@ -284,7 +296,7 @@ def optimizeSignature(self, unbucketed_entries): # than others). 
matchesInOtherBuckets = False nonMatchesInOtherBuckets = 0 # noqa - otherMatchingBucketIds = [] # noqa + otherMatchingBucketIds: list[int] = [] # noqa for otherBucket in buckets: if otherBucket.pk == self.pk: continue @@ -303,11 +315,13 @@ def optimizeSignature(self, unbucketed_entries): bucket=otherBucket ).select_related("product", "platform", "os") c = CrashEntry.deferRawFields(c, requiredOutputs) - c = c.first() - firstEntryPerBucketCache[otherBucket.pk] = c - if c: + c_first = c.first() + assert c_first is not None + c_ = c_first + firstEntryPerBucketCache[otherBucket.pk] = c_ + if c_: # Omit testcase for performance reasons for now - firstEntryPerBucketCache[otherBucket.pk] = c.getCrashInfo( + firstEntryPerBucketCache[otherBucket.pk] = c_.getCrashInfo( attachTestcase=False, requiredOutputSources=requiredOutputs, ) @@ -326,7 +340,8 @@ def optimizeSignature(self, unbucketed_entries): else: for otherEntry in entries: otherEntry.crashinfo = otherEntry.getCrashInfo( - attachTestcase=False, requiredOutputSources=requiredOutputs + attachTestcase=False, + requiredOutputSources=tuple(requiredOutputs), ) if optimizedSignature.matches(otherEntry.crashinfo): matchingEntries.append(otherEntry) @@ -341,7 +356,7 @@ def optimizeSignature(self, unbucketed_entries): return (optimizedSignature, matchingEntries) -def buckethit_default_range_begin(): +def buckethit_default_range_begin() -> datetime: return timezone.now().replace(microsecond=0, second=0, minute=0) @@ -352,7 +367,7 @@ class BucketHit(models.Model): count = models.IntegerField(default=0) @classmethod - def decrement_count(cls, bucket_id, tool_id, begin): + def decrement_count(cls, bucket_id: int, tool_id: int, begin: datetime) -> None: begin = begin.replace(microsecond=0, second=0, minute=0) counter = cls.objects.filter( bucket_id=bucket_id, @@ -364,7 +379,7 @@ def decrement_count(cls, bucket_id, tool_id, begin): counter.save() @classmethod - def increment_count(cls, bucket_id, tool_id, begin): + def increment_count(cls, bucket_id: int, tool_id: int, begin: datetime) -> None: begin = begin.replace(microsecond=0, second=0, minute=0) counter, _ = cls.objects.get_or_create( bucket_id=bucket_id, begin=begin, tool_id=tool_id @@ -383,7 +398,7 @@ class CrashEntry(models.Model): TestCase, blank=True, null=True, on_delete=models.deletion.CASCADE ) client = models.ForeignKey(Client, on_delete=models.deletion.CASCADE) - bucket = models.ForeignKey( + bucket: Bucket | None = models.ForeignKey( Bucket, blank=True, null=True, on_delete=models.deletion.CASCADE ) rawStdout = models.TextField(blank=True) @@ -395,20 +410,23 @@ class CrashEntry(models.Model): crashAddress = models.CharField(max_length=255, blank=True) crashAddressNumeric = models.BigIntegerField(blank=True, null=True) shortSignature = models.CharField(max_length=255, blank=True) - cachedCrashInfo = models.TextField(blank=True, null=True) + cachedCrashInfo: str | None = models.TextField(blank=True, null=True) triagedOnce = models.BooleanField(blank=False, default=False) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: # These variables can hold temporarily deserialized data - self.argsList = None - self.envList = None - self.metadataList = None + self.argsList: list[str] | None = None + self.envList: list[str] | None = None + self.metadataList: list[str] | None = None # For performance reasons we do not deserialize these fields # automatically here. You need to explicitly call the # deserializeFields method if you need this data. 
self._original_bucket = None + + self.crashinfo: CrashInfo + super().__init__(*args, **kwargs) @classmethod @@ -417,13 +435,13 @@ def from_db(cls, db, field_names, values): instance._original_bucket = instance.bucket_id return instance - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: if self.pk is None and not getattr(settings, "DB_ISUTF8MB4", False): # Replace 4-byte UTF-8 characters with U+FFFD if our database # doesn't support them. By default, MySQL utf-8 does not support these. utf8_4byte_re = re.compile("[^\u0000-\uD7FF\uE000-\uFFFF]", re.UNICODE) - def sanitize_utf8(s): + def sanitize_utf8(s: str) -> str: if not isinstance(s, str): s = str(s, "utf-8") @@ -465,7 +483,7 @@ def sanitize_utf8(s): super().save(*args, **kwargs) - def deserializeFields(self): + def deserializeFields(self) -> None: if self.args: self.argsList = json.loads(self.args) @@ -479,9 +497,9 @@ def deserializeFields(self): def getCrashInfo( self, - attachTestcase=False, - requiredOutputSources=("stdout", "stderr", "crashdata"), - ): + attachTestcase: bool = False, + requiredOutputSources: tuple[str, ...] = ("stdout", "stderr", "crashdata"), + ) -> CrashInfo: # TODO: This should be cached at some level # TODO: Need to include environment and program arguments here configuration = ProgramConfiguration( @@ -517,7 +535,7 @@ def getCrashInfo( return crashInfo - def reparseCrashInfo(self): + def reparseCrashInfo(self) -> None: # Purges cached crash information and then forces a reparsing # of the raw crash information. Based on the new crash information, # the depending fields are also repopulated. @@ -546,7 +564,10 @@ def reparseCrashInfo(self): return self.save() @staticmethod - def deferRawFields(queryset, requiredOutputSources=()): + def deferRawFields( + queryset: QuerySet[MT], + requiredOutputSources: tuple[str, str, str] | list[str] = ("", "", ""), + ) -> QuerySet[MT]: # This method calls defer() on the given query set for every raw field # that is not required as specified in requiredOutputSources. if "stdout" not in requiredOutputSources: @@ -562,7 +583,7 @@ def deferRawFields(queryset, requiredOutputSources=()): # is also deleted when the CrashEntry is gone. It also explicitly # deletes the file on the filesystem which would otherwise remain. 
@receiver(post_delete, sender=CrashEntry) -def CrashEntry_delete(sender, instance, **kwargs): +def CrashEntry_delete(sender: CrashEntry, instance: CrashEntry, **kwargs: Any) -> None: if instance.testcase: instance.testcase.delete(False) if instance.bucket_id is not None: @@ -572,13 +593,15 @@ def CrashEntry_delete(sender, instance, **kwargs): @receiver(post_delete, sender=TestCase) -def TestCase_delete(sender, instance, **kwargs): +def TestCase_delete(sender: TestCase, instance: TestCase, **kwargs: Any) -> None: if instance.test: instance.test.delete(False) @receiver(post_save, sender=CrashEntry) -def CrashEntry_save(sender, instance, created, **kwargs): +def CrashEntry_save( + sender: CrashEntry, instance: CrashEntry, created: bool, **kwargs: Any +) -> None: if getattr(settings, "USE_CELERY", None): if created and not instance.triagedOnce: triage_new_crash.delay(instance.pk) @@ -643,7 +666,7 @@ class BugzillaTemplate(models.Model): blocks = models.TextField(blank=True) dependson = models.TextField(blank=True) - def __str__(self): + def __str__(self) -> str: return self.name @@ -671,7 +694,9 @@ class Meta: tasks_failed = models.BooleanField(blank=False, default=False) @staticmethod - def get_or_create_restricted(request_user): + def get_or_create_restricted( + request_user: AbstractBaseUser | AnonymousUser, + ) -> tuple[DjangoUser, bool]: (user, created) = User.objects.get_or_create(user=request_user) if created and getattr(settings, "USERS_RESTRICTED_BY_DEFAULT", False): user.restricted = True @@ -680,7 +705,9 @@ def get_or_create_restricted(request_user): @receiver(post_save, sender=DjangoUser) -def add_default_perms(sender, instance, created, **kwargs): +def add_default_perms( + sender: DjangoUser, instance: DjangoUser, created: bool, **kwargs: Any +) -> None: if created: log = logging.getLogger("crashmanager") for perm in getattr(settings, "DEFAULT_PERMISSIONS", []): diff --git a/server/crashmanager/serializers.py b/server/crashmanager/serializers.py index 5eea278a3..e831a9a90 100644 --- a/server/crashmanager/serializers.py +++ b/server/crashmanager/serializers.py @@ -1,5 +1,9 @@ +from __future__ import annotations + import base64 import hashlib +from datetime import datetime +from typing import Any from django.conf import settings from django.core.exceptions import MultipleObjectsReturned # noqa @@ -32,7 +36,7 @@ class InvalidArgumentException(APIException): status_code = 400 -class CrashEntrySerializer(serializers.ModelSerializer): +class CrashEntrySerializer(serializers.ModelSerializer[CrashEntry]): # We need to redefine several fields explicitly because we flatten our # foreign keys into these fields instead of using primary keys, hyperlinks # or slug fields. All of the other solutions would require the client to @@ -57,7 +61,7 @@ class CrashEntrySerializer(serializers.ModelSerializer): source="testcase.isBinary", required=False, default=False ) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: include_raw = kwargs.pop("include_raw", True) @@ -97,7 +101,7 @@ class Meta: ordering = ["-id"] read_only_fields = ("bucket", "id", "shortSignature", "crashAddress") - def create(self, attrs): + def create(self, attrs) -> CrashEntry: """ Create a CrashEntry instance based on the given dictionary of values received. 
We need to unflatten foreign relationships like product, @@ -179,7 +183,7 @@ def create(self, attrs): raise -class BucketSerializer(serializers.ModelSerializer): +class BucketSerializer(serializers.ModelSerializer[Bucket]): signature = serializers.CharField( style={"base_template": "textarea.html"}, required=False ) @@ -258,28 +262,28 @@ class Meta(BucketSerializer.Meta): "view_url", ) - def get_bug_closed(self, sig): + def get_bug_closed(self, sig: Bucket) -> datetime | None: if sig.bug: return sig.bug.closed return None - def get_bug_hostname(self, sig): + def get_bug_hostname(self, sig: Bucket) -> str | None: if sig.bug and sig.bug.externalType: - return sig.bug.externalType.hostname + return str(sig.bug.externalType.hostname) return None - def get_bug_urltemplate(self, sig): + def get_bug_urltemplate(self, sig: Bucket) -> str | None: if sig.bug and sig.bug.externalType: try: - return sig.bug.externalType.urlTemplate % sig.bug.externalId + return str(sig.bug.externalType.urlTemplate % sig.bug.externalId) except Exception: return None return None - def get_opt_pre_url(self, sig): + def get_opt_pre_url(self, sig: Bucket) -> str: return reverse("crashmanager:sigoptpre", kwargs={"sigid": sig.id}) - def get_view_url(self, sig): + def get_view_url(self, sig: Bucket) -> str: return reverse("crashmanager:sigview", kwargs={"sigid": sig.id}) @@ -303,22 +307,22 @@ class Meta(CrashEntrySerializer.Meta): "find_sigs_url", ) - def get_view_url(self, entry): + def get_view_url(self, entry: CrashEntry) -> str: return reverse("crashmanager:crashview", kwargs={"crashid": entry.id}) - def get_sig_view_url(self, entry): + def get_sig_view_url(self, entry: CrashEntry) -> str | None: if entry.bucket: return reverse("crashmanager:sigview", kwargs={"sigid": entry.bucket.id}) return None - def get_sig_new_url(self, entry): + def get_sig_new_url(self, entry: CrashEntry) -> str: return f"{reverse('crashmanager:signew')}?crashid={entry.id}" - def get_find_sigs_url(self, entry): + def get_find_sigs_url(self, entry: CrashEntry) -> str: return reverse("crashmanager:findsigs", kwargs={"crashid": entry.id}) -class BugProviderSerializer(serializers.ModelSerializer): +class BugProviderSerializer(serializers.ModelSerializer[BugProvider]): class Meta: model = BugProvider fields = ( @@ -329,7 +333,7 @@ class Meta: ) -class BugzillaTemplateSerializer(serializers.ModelSerializer): +class BugzillaTemplateSerializer(serializers.ModelSerializer[BugzillaTemplate]): mode = serializers.SerializerMethodField() class Meta: @@ -368,7 +372,7 @@ def get_mode(self, obj): return obj.mode.value -class NotificationSerializer(serializers.ModelSerializer): +class NotificationSerializer(serializers.ModelSerializer[Notification]): actor_url = serializers.SerializerMethodField() target_url = serializers.SerializerMethodField() external_bug_url = serializers.SerializerMethodField() @@ -385,7 +389,7 @@ class Meta: "external_bug_url", ) - def get_actor_url(self, notification): + def get_actor_url(self, notification: Notification) -> str | None: if isinstance(notification.actor, Bucket): return reverse( "crashmanager:sigview", kwargs={"sigid": notification.actor.id} @@ -395,7 +399,7 @@ def get_actor_url(self, notification): return f"{settings.TC_ROOT_URL}tasks/{task.task_id}/runs/{task.run_id}" return None - def get_target_url(self, notification): + def get_target_url(self, notification: Notification) -> str | None: if isinstance(notification.target, CrashEntry): return reverse( "crashmanager:crashview", kwargs={"crashid": notification.target.id} @@ 
-406,7 +410,7 @@ def get_target_url(self, notification): ) return None - def get_external_bug_url(self, notification): + def get_external_bug_url(self, notification: Notification) -> str | None: if isinstance(notification.target, Bug): return ( f"https://{notification.target.externalType.hostname}" diff --git a/server/crashmanager/tasks.py b/server/crashmanager/tasks.py index c7c2ad0c2..316c1dce3 100644 --- a/server/crashmanager/tasks.py +++ b/server/crashmanager/tasks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from celeryconf import app from django.core.management import call_command @@ -5,5 +7,5 @@ @app.task(ignore_result=True) -def triage_new_crash(pk): +def triage_new_crash(pk: int) -> None: call_command("triage_new_crash", pk) diff --git a/server/crashmanager/templatetags/extratags.py b/server/crashmanager/templatetags/extratags.py index 3ea9731e9..84b7eb746 100644 --- a/server/crashmanager/templatetags/extratags.py +++ b/server/crashmanager/templatetags/extratags.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os @@ -7,12 +9,12 @@ @register.filter -def basename(value): +def basename(value: str) -> str: return os.path.basename(value) @register.filter -def linejoin(value): +def linejoin(value: list[str]) -> str: if value: return "\n".join(value) else: @@ -20,12 +22,12 @@ def linejoin(value): @register.filter -def varformat(arg, val): +def varformat(arg: int, val: int) -> int: return arg % val @register.filter -def listcsv(value): +def listcsv(value: list[str]) -> str: if value: return ", ".join(value) else: @@ -33,7 +35,7 @@ def listcsv(value): @register.filter -def dictcsv(value): +def dictcsv(value: dict[str, object]) -> str: if value: return ", ".join("%s=%s" % x for x in value.items()) else: @@ -41,7 +43,7 @@ def dictcsv(value): @register.filter -def jsonparse(value): +def jsonparse(value: str): if value: return json.loads(value) else: diff --git a/server/crashmanager/tests/__init__.py b/server/crashmanager/tests/__init__.py index 510dc3c7c..5904cbd09 100644 --- a/server/crashmanager/tests/__init__.py +++ b/server/crashmanager/tests/__init__.py @@ -8,24 +8,28 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + +from django.http.response import HttpResponse from django.test import SimpleTestCase as DjangoTestCase -def assert_contains(response, text): +def assert_contains(response: HttpResponse, text: str) -> None: """Assert that the response was successful, and contains the given text.""" class _(DjangoTestCase): - def runTest(self): + def runTest(self) -> None: pass _().assertContains(response, text) -def assert_not_contains(response, text): +def assert_not_contains(response: HttpResponse, text: str) -> None: """Assert that the response was successful, and does not contain the given text.""" class _(DjangoTestCase): - def runTest(self): + def runTest(self) -> None: pass _().assertNotContains(response, text) diff --git a/server/crashmanager/tests/conftest.py b/server/crashmanager/tests/conftest.py index 962019f31..d59eb1646 100644 --- a/server/crashmanager/tests/conftest.py +++ b/server/crashmanager/tests/conftest.py @@ -8,12 +8,17 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import logging +from typing import cast import pytest from django.contrib.auth.models import Permission, User from django.contrib.contenttypes.models import ContentType from django.core.files.base import ContentFile +from rest_framework.test import APIClient from crashmanager.models import ( OS, @@ -36,12 +41,12 @@ def _create_user( - username, - email="test@mozilla.com", - password="test", - restricted=False, - has_permission=True, -): + username: str, + email: str = "test@mozilla.com", + password: str = "test", + restricted: bool = False, + has_permission: bool = True, +) -> User: user = User.objects.create_user(username, email, password) user.user_permissions.clear() if has_permission: @@ -57,7 +62,7 @@ def _create_user( @pytest.fixture -def crashmanager_test(db): # pylint: disable=invalid-name,unused-argument +def crashmanager_test(db: None) -> None: # pylint: disable=invalid-name,unused-argument """Common testcase class for all crashmanager unittests""" # Create one unrestricted and one restricted test user _create_user("test") @@ -66,7 +71,9 @@ def crashmanager_test(db): # pylint: disable=invalid-name,unused-argument @pytest.fixture -def user_normal(db, api_client): # pylint: disable=invalid-name,unused-argument +def user_normal( + db: None, api_client: APIClient +) -> User: # pylint: disable=invalid-name,unused-argument """Create a normal, authenticated user""" user = _create_user("test") api_client.force_authenticate(user=user) @@ -74,7 +81,9 @@ def user_normal(db, api_client): # pylint: disable=invalid-name,unused-argument @pytest.fixture -def user_restricted(db, api_client): # pylint: disable=invalid-name,unused-argument +def user_restricted( + db: None, api_client: APIClient +) -> User: # pylint: disable=invalid-name,unused-argument """Create a restricted, authenticated user""" user = _create_user("test-restricted", restricted=True) api_client.force_authenticate(user=user) @@ -82,7 +91,9 @@ def user_restricted(db, api_client): # pylint: disable=invalid-name,unused-argu @pytest.fixture -def user_noperm(db, api_client): # pylint: disable=invalid-name,unused-argument +def user_noperm( + db: None, api_client: APIClient +) -> User: # pylint: disable=invalid-name,unused-argument """Create an authenticated user with no crashmanager ACL""" user = _create_user("test-noperm", has_permission=False) api_client.force_authenticate(user=user) @@ -90,104 +101,222 @@ def user_noperm(db, api_client): # pylint: disable=invalid-name,unused-argument @pytest.fixture -def user(request): +def user(request) -> User: assert request.param in {"normal", "restricted", "noperm"} return request.getfixturevalue("user_" + request.param) +class _cm_result: # pylint: disable=invalid-name + @staticmethod + def create_crash( + tool: str = "testtool", + platform: str = "testplatform", + product: str = "testproduct", + product_version: Product | str | None = None, + os: str = "testos", + testcase: cmTestCase | None = None, + client: str = "testclient", + bucket: Bucket | None = None, + stdout: str = "", + stderr: str = "", + crashdata: str = "", + metadata: str = "", + env: str = "", + args: str = "", + crashAddress: str = "", + crashAddressNumeric: int | None = None, + shortSignature: str = "", + cachedCrashInfo: str = "", + triagedOnce: bool = False, + ) -> CrashEntry: + ... + + @staticmethod + def create_bugprovider( + classname: str = "BugzillaProvider", hostname: str = "", urlTemplate: str = "%s" + ) -> BugProvider: + ... 
+ + @classmethod + def create_bug( + cls, + externalId: str, + externalType: BugProvider | None = None, + closed: bool | None = None, + ) -> Bug: + ... + + @staticmethod + def create_testcase( + filename: str, testdata: str = "", quality: int = 0, isBinary: bool = False + ) -> cmTestCase: + ... + + @staticmethod + def create_template( + mode: str = BugzillaTemplateMode.Bug, + name: str = "", + product: str = "", + component: str = "", + summary: str = "", + version: str = "", + description: str = "", + whiteboard: str = "", + keywords: str = "", + op_sys: str = "", + platform: str = "", + priority: str = "", + severity: str = "", + alias: str = "", + cc: str = "", + assigned_to: str = "", + qa_contact: str = "", + target_milestone: str = "", + attrs: str = "", + security: bool = False, + security_group: str = "", + comment: str = "", + testcase_filename: str = "", + blocks: str = "", + dependson: str = "", + ) -> BugzillaTemplate: + ... + + @staticmethod + def create_bucket( + bug: Bug | None = None, + signature: str = "", + shortDescription: str = "", + frequent: bool = False, + permanent: bool = False, + ) -> Bucket: + ... + + @staticmethod + def create_toolfilter(tool: str, user: str = "test") -> None: + ... + + @staticmethod + def create_bucketwatch(bucket: Bucket, crash: CrashEntry | int = 0) -> BucketWatch: + ... + + @pytest.fixture def cm(): class _cm_result: # pylint: disable=invalid-name @staticmethod def create_crash( - tool="testtool", - platform="testplatform", - product="testproduct", - product_version=None, - os="testos", - testcase=None, - client="testclient", - bucket=None, - stdout="", - stderr="", - crashdata="", - metadata="", - env="", - args="", - crashAddress="", - crashAddressNumeric=None, - shortSignature="", - cachedCrashInfo="", - triagedOnce=False, - ): + tool: str = "testtool", + platform: str = "testplatform", + product: str = "testproduct", + product_version: Product | str | None = None, + os: str = "testos", + testcase: cmTestCase | None = None, + client: str = "testclient", + bucket: Bucket | None = None, + stdout: str = "", + stderr: str = "", + crashdata: str = "", + metadata: str = "", + env: str = "", + args: str = "", + crashAddress: str = "", + crashAddressNumeric: int | None = None, + shortSignature: str = "", + cachedCrashInfo: str = "", + triagedOnce: bool = False, + ) -> CrashEntry: # create tool tool, created = Tool.objects.get_or_create(name=tool) + assert isinstance(tool, Tool) if created: LOG.debug("Created Tool pk=%d", tool.pk) # create platform platform, created = Platform.objects.get_or_create(name=platform) + assert isinstance(platform, Platform) if created: LOG.debug("Created Platform pk=%d", platform.pk) # create product product, created = Product.objects.get_or_create(name=product) + assert isinstance(product, Product) if created: LOG.debug("Created Product pk=%d", product.pk) if product_version is not None: - product.version = product_version + product.version = str(product_version) product.save() # create os os, created = OS.objects.get_or_create(name=os) + assert isinstance(os, OS) if created: LOG.debug("Created OS pk=%d", os.pk) # create client client, created = Client.objects.get_or_create(name=client) + assert isinstance(client, Client) if created: LOG.debug("Created Client pk=%d", client.pk) - result = CrashEntry.objects.create( - tool=tool, - platform=platform, - product=product, - os=os, - testcase=testcase, - client=client, - bucket=bucket, - rawStdout=stdout, - rawStderr=stderr, - rawCrashData=crashdata, - metadata=metadata, - 
env=env, - args=args, - crashAddress=crashAddress, - crashAddressNumeric=crashAddressNumeric, - shortSignature=shortSignature, - cachedCrashInfo=cachedCrashInfo, - triagedOnce=triagedOnce, + result = cast( + CrashEntry, + CrashEntry.objects.create( + tool=tool, + platform=platform, + product=product, + os=os, + testcase=testcase, + client=client, + bucket=bucket, + rawStdout=stdout, + rawStderr=stderr, + rawCrashData=crashdata, + metadata=metadata, + env=env, + args=args, + crashAddress=crashAddress, + crashAddressNumeric=crashAddressNumeric, + shortSignature=shortSignature, + cachedCrashInfo=cachedCrashInfo, + triagedOnce=triagedOnce, + ), ) LOG.debug("Created CrashEntry pk=%d", result.pk) return result @staticmethod def create_bugprovider( - classname="BugzillaProvider", hostname="", urlTemplate="%s" - ): - result = BugProvider.objects.create( - classname=classname, hostname=hostname, urlTemplate=urlTemplate + classname: str = "BugzillaProvider", + hostname: str = "", + urlTemplate: str = "%s", + ) -> BugProvider: + result = cast( + BugProvider, + BugProvider.objects.create( + classname=classname, hostname=hostname, urlTemplate=urlTemplate + ), ) LOG.debug("Created BugProvider pk=%d", result.pk) return result @classmethod - def create_bug(cls, externalId, externalType=None, closed=None): + def create_bug( + cls, + externalId: str, + externalType: BugProvider | None = None, + closed: bool | None = None, + ) -> Bug: if externalType is None: externalType = cls.create_bugprovider() - result = Bug.objects.create( - externalId=externalId, externalType=externalType, closed=closed + result = cast( + Bug, + Bug.objects.create( + externalId=externalId, externalType=externalType, closed=closed + ), ) LOG.debug("Created Bug pk=%d", result.pk) return result @staticmethod - def create_testcase(filename, testdata="", quality=0, isBinary=False): + def create_testcase( + filename: str, testdata: str = "", quality: int = 0, isBinary: bool = False + ) -> cmTestCase: result = cmTestCase(quality=quality, isBinary=isBinary, size=len(testdata)) result.test.save(filename, ContentFile(testdata)) result.save() @@ -195,88 +324,102 @@ def create_testcase(filename, testdata="", quality=0, isBinary=False): @staticmethod def create_template( - mode=BugzillaTemplateMode.Bug, - name="", - product="", - component="", - summary="", - version="", - description="", - whiteboard="", - keywords="", - op_sys="", - platform="", - priority="", - severity="", - alias="", - cc="", - assigned_to="", - qa_contact="", - target_milestone="", - attrs="", - security=False, - security_group="", - comment="", - testcase_filename="", - blocks="", - dependson="", - ): - result = BugzillaTemplate.objects.create( - mode=mode, - name=name, - product=product, - component=component, - summary=summary, - version=version, - description=description, - whiteboard=whiteboard, - keywords=keywords, - op_sys=op_sys, - platform=platform, - priority=priority, - severity=severity, - alias=alias, - cc=cc, - assigned_to=assigned_to, - qa_contact=qa_contact, - target_milestone=target_milestone, - attrs=attrs, - security=security, - security_group=security_group, - comment=comment, - testcase_filename=testcase_filename, + mode: str = BugzillaTemplateMode.Bug, + name: str = "", + product: str = "", + component: str = "", + summary: str = "", + version: str = "", + description: str = "", + whiteboard: str = "", + keywords: str = "", + op_sys: str = "", + platform: str = "", + priority: str = "", + severity: str = "", + alias: str = "", + cc: str = "", + 
assigned_to: str = "", + qa_contact: str = "", + target_milestone: str = "", + attrs: str = "", + security: bool = False, + security_group: str = "", + comment: str = "", + testcase_filename: str = "", + blocks: str = "", + dependson: str = "", + ) -> BugzillaTemplate: + result = cast( + BugzillaTemplate, + BugzillaTemplate.objects.create( + mode=mode, + name=name, + product=product, + component=component, + summary=summary, + version=version, + description=description, + whiteboard=whiteboard, + keywords=keywords, + op_sys=op_sys, + platform=platform, + priority=priority, + severity=severity, + alias=alias, + cc=cc, + assigned_to=assigned_to, + qa_contact=qa_contact, + target_milestone=target_milestone, + attrs=attrs, + security=security, + security_group=security_group, + comment=comment, + testcase_filename=testcase_filename, + ), ) LOG.debug("Created BugzillaTemplate pk=%d", result.pk) return result @staticmethod def create_bucket( - bug=None, signature="", shortDescription="", frequent=False, permanent=False - ): - result = Bucket.objects.create( - bug=bug, - signature=signature, - shortDescription=shortDescription, - frequent=frequent, - permanent=permanent, + bug: Bug | None = None, + signature: str = "", + shortDescription: str = "", + frequent: bool = False, + permanent: bool = False, + ) -> Bucket: + result = cast( + Bucket, + Bucket.objects.create( + bug=bug, + signature=signature, + shortDescription=shortDescription, + frequent=frequent, + permanent=permanent, + ), ) LOG.debug("Created Bucket pk=%d", result.pk) return result @staticmethod - def create_toolfilter(tool, user="test"): + def create_toolfilter(tool: str, user: str = "test") -> None: user = User.objects.get(username=user) cmuser, _ = cmUser.objects.get_or_create(user=user) cmuser.defaultToolsFilter.add(Tool.objects.get(name=tool)) @staticmethod - def create_bucketwatch(bucket, crash=0): + def create_bucketwatch( + bucket: Bucket, crash: CrashEntry | int = 0 + ) -> BucketWatch: user = User.objects.get(username="test") cmuser, _ = cmUser.objects.get_or_create(user=user) if crash: + assert isinstance(crash, CrashEntry) crash = crash.pk - result = BucketWatch.objects.create( - bucket=bucket, user=cmuser, lastCrash=crash + result = cast( + BucketWatch, + BucketWatch.objects.create(bucket=bucket, user=cmuser, lastCrash=crash), ) LOG.debug("Created BucketWatch pk=%d", result.pk) return result diff --git a/server/crashmanager/tests/test_bugproviders_rest.py b/server/crashmanager/tests/test_bugproviders_rest.py index d3cea8375..a3597acbe 100644 --- a/server/crashmanager/tests/test_bugproviders_rest.py +++ b/server/crashmanager/tests/test_bugproviders_rest.py @@ -1,13 +1,23 @@ +from __future__ import annotations + import logging import pytest import requests +from django.contrib.auth.models import User +from rest_framework.test import APIClient + +from crashmanager.models import BugProvider + +from .conftest import _cm_result LOG = logging.getLogger("fm.crashmanager.tests.bugproviders.rest") @pytest.mark.parametrize("method", ["delete", "get", "patch", "post", "put"]) -def test_rest_bugproviders_no_auth(db, api_client, method): +def test_rest_bugproviders_no_auth( + db: None, api_client: APIClient, method: str +) -> None: """must yield unauthorized without authentication""" assert ( getattr(api_client, method)("/crashmanager/rest/bugproviders/", {}).status_code @@ -16,7 +26,9 @@ def test_rest_bugproviders_no_auth(db, api_client, method): @pytest.mark.parametrize("method", ["delete", "get", "patch", "post", "put"]) -def 
test_rest_bugproviders_no_perm(user_noperm, api_client, method): +def test_rest_bugproviders_no_perm( + user_noperm: User, api_client: APIClient, method: str +) -> None: """must yield forbidden without permission""" assert ( getattr(api_client, method)("/crashmanager/rest/bugproviders/", {}).status_code @@ -38,7 +50,9 @@ def test_rest_bugproviders_no_perm(user_noperm, api_client, method): ], indirect=["user"], ) -def test_rest_bugproviders_methods(api_client, user, method, url): +def test_rest_bugproviders_methods( + api_client: APIClient, user: User, method: str, url: str +) -> None: """must yield method-not-allowed for unsupported methods""" assert ( getattr(api_client, method)(url, {}).status_code @@ -62,14 +76,18 @@ def test_rest_bugproviders_methods(api_client, user, method, url): ], indirect=["user"], ) -def test_rest_bugproviders_methods_not_found(api_client, user, method, url): +def test_rest_bugproviders_methods_not_found( + api_client: APIClient, user: User, method: str, url: str +) -> None: """must yield not-found for undeclared methods""" assert ( getattr(api_client, method)(url, {}).status_code == requests.codes["not_found"] ) -def _compare_rest_result_to_bugprovider(result, provider): +def _compare_rest_result_to_bugprovider( + result: dict[str, object], provider: BugProvider +) -> None: expected_fields = {"id", "classname", "hostname", "urlTemplate"} assert set(result) == expected_fields for key, value in result.items(): @@ -77,7 +95,9 @@ def _compare_rest_result_to_bugprovider(result, provider): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) -def test_rest_bugproviders_list(api_client, user, cm): +def test_rest_bugproviders_list( + api_client: APIClient, user: User, cm: _cm_result +) -> None: """test that list returns the right bug providers""" expected = 4 providers = [ diff --git a/server/crashmanager/tests/test_bugs.py b/server/crashmanager/tests/test_bugs.py index 2b1b8f715..50c98ccf6 100644 --- a/server/crashmanager/tests/test_bugs.py +++ b/server/crashmanager/tests/test_bugs.py @@ -8,14 +8,21 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse from crashmanager.models import BugzillaTemplate +from .conftest import _cm_result + LOG = logging.getLogger("fm.crashmanager.tests.bugs") pytestmark = pytest.mark.usefixtures("crashmanager_test") @@ -32,12 +39,17 @@ ("crashmanager:createbugcomment", {"crashid": 0}), ], ) -def test_bug_providers_no_login(client, name, kwargs): +def test_bug_providers_no_login( + client: Client, name: str, kwargs: dict[str, int] +) -> None: """Request without login hits the login redirect""" path = reverse(name, kwargs=kwargs) resp = client.get(path) assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=" + path + ) @pytest.mark.parametrize( @@ -51,12 +63,17 @@ def test_bug_providers_no_login(client, name, kwargs): ("crashmanager:templatedel", {"templateId": 0}), ], ) -def test_bugzilla_templates_no_login(client, name, kwargs): +def test_bugzilla_templates_no_login( + client: Client, name: str, kwargs: dict[str, int] +) -> None: """Request without login hits the login redirect""" path = reverse(name, kwargs=kwargs) resp = client.get(path) assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=" + path + ) @pytest.mark.parametrize( @@ -69,7 +86,9 @@ def test_bugzilla_templates_no_login(client, name, kwargs): ("crashmanager:bugproviderview", {"providerId": 0}), ], ) -def test_bug_providers_simple_get(client, cm, name, kwargs): +def test_bug_providers_simple_get( + client: Client, cm: _cm_result, name: str, kwargs: dict[str, int] +) -> None: """No errors are thrown in template""" client.login(username="test", password="test") if "providerId" in kwargs: @@ -89,7 +108,9 @@ def test_bug_providers_simple_get(client, cm, name, kwargs): ("crashmanager:templatedel", {"templateId": 0}), ], ) -def test_bugzilla_templates_simple_get(client, cm, name, kwargs): +def test_bugzilla_templates_simple_get( + client: Client, cm: _cm_result, name: str, kwargs: dict[str, int] +) -> None: """No errors are thrown in template""" client.login(username="test", password="test") if "templateId" in kwargs: @@ -99,7 +120,7 @@ def test_bugzilla_templates_simple_get(client, cm, name, kwargs): assert response.status_code == requests.codes["ok"] -def test_template_edit(client, cm): +def test_template_edit(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" pk = cm.create_template().pk assert len(BugzillaTemplate.objects.all()) == 1 @@ -143,7 +164,10 @@ def test_template_edit(client, cm): LOG.debug(response) # Redirecting to template list when the action is successful assert response.status_code == requests.codes["found"] - assert response.url == "/crashmanager/bugzilla/templates/" + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/crashmanager/bugzilla/templates/" + ) assert len(BugzillaTemplate.objects.all()) == 1 template = BugzillaTemplate.objects.get() assert template.mode.value == "bug" @@ -153,7 +177,7 @@ def test_template_edit(client, cm): assert template.version == "1.0" -def test_template_dup(client, cm): +def test_template_dup(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" pk = 
cm.create_template().pk assert len(BugzillaTemplate.objects.all()) == 1 @@ -164,7 +188,10 @@ def test_template_dup(client, cm): LOG.debug(response) # Redirecting to template list when the action is successful assert response.status_code == requests.codes["found"] - assert response.url == "/crashmanager/bugzilla/templates/" + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/crashmanager/bugzilla/templates/" + ) assert len(BugzillaTemplate.objects.all()) == 2 template = BugzillaTemplate.objects.get(pk=pk) clone = BugzillaTemplate.objects.get(pk=pk + 1) @@ -196,7 +223,7 @@ def test_template_dup(client, cm): assert getattr(template, field) == getattr(clone, field) -def test_template_del(client, cm): +def test_template_del(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" pk = cm.create_template().pk assert len(BugzillaTemplate.objects.all()) == 1 @@ -207,11 +234,14 @@ def test_template_del(client, cm): LOG.debug(response) # Redirecting to template list when the action is successful assert response.status_code == requests.codes["found"] - assert response.url == "/crashmanager/bugzilla/templates/" + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/crashmanager/bugzilla/templates/" + ) assert len(BugzillaTemplate.objects.all()) == 0 -def test_template_create_bug_post(client, cm): +def test_template_create_bug_post(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" assert len(BugzillaTemplate.objects.all()) == 0 client.login(username="test", password="test") @@ -227,7 +257,10 @@ def test_template_create_bug_post(client, cm): LOG.debug(response) # Redirecting to template list when the action is successful assert response.status_code == requests.codes["found"] - assert response.url == "/crashmanager/bugzilla/templates/" + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/crashmanager/bugzilla/templates/" + ) assert len(BugzillaTemplate.objects.all()) == 1 template = BugzillaTemplate.objects.get() assert template.mode.value == "bug" @@ -237,7 +270,7 @@ def test_template_create_bug_post(client, cm): assert template.version == "1.0" -def test_template_create_comment_post(client, cm): +def test_template_create_comment_post(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" assert len(BugzillaTemplate.objects.all()) == 0 client.login(username="test", password="test") @@ -248,7 +281,10 @@ def test_template_create_comment_post(client, cm): LOG.debug(response) # Redirecting to template list when the action is successful assert response.status_code == requests.codes["found"] - assert response.url == "/crashmanager/bugzilla/templates/" + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/crashmanager/bugzilla/templates/" + ) assert len(BugzillaTemplate.objects.all()) == 1 template = BugzillaTemplate.objects.get() assert template.mode.value == "comment" @@ -256,7 +292,7 @@ def test_template_create_comment_post(client, cm): assert template.comment == "A comment" -def test_create_external_bug_simple_get(client, cm): +def test_create_external_bug_simple_get(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" client.login(username="test", password="test") bucket = cm.create_bucket() @@ -270,7 +306,7 @@ def test_create_external_bug_simple_get(client, cm): assert response.status_code == requests.codes["ok"] -def 
test_create_external_bug_comment_simple_get(client, cm): +def test_create_external_bug_comment_simple_get(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" client.login(username="test", password="test") crash = cm.create_crash() diff --git a/server/crashmanager/tests/test_crashes.py b/server/crashmanager/tests/test_crashes.py index 6ebd02f3d..be6e78265 100644 --- a/server/crashmanager/tests/test_crashes.py +++ b/server/crashmanager/tests/test_crashes.py @@ -8,13 +8,20 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.http.response import HttpResponse +from django.test.client import Client from django.urls import reverse from . import assert_contains +from .conftest import _cm_result LOG = logging.getLogger("fm.crashmanager.tests.crashes") pytestmark = pytest.mark.usefixtures( @@ -22,14 +29,14 @@ ) # pylint: disable=invalid-name -def test_crashes_view(client): # pylint: disable=invalid-name +def test_crashes_view(client: Client) -> None: # pylint: disable=invalid-name """Check that the Vue component is called""" client.login(username="test", password="test") response = client.get(reverse("crashmanager:crashes")) LOG.debug(response) assert response.status_code == requests.codes["ok"] assert response.context["restricted"] is False - assert_contains(response, "crasheslist") + assert_contains(typing.cast(HttpResponse, response), "crasheslist") @pytest.mark.parametrize( @@ -41,19 +48,24 @@ def test_crashes_view(client): # pylint: disable=invalid-name ("crashmanager:crashview", {"crashid": 0}), ], ) -def test_crashes_no_login(client, name, kwargs): +def test_crashes_no_login(client: Client, name: str, kwargs: dict[str, int]) -> None: """Request without login hits the login redirect""" path = reverse(name, kwargs=kwargs) resp = client.get(path) assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=" + path + ) @pytest.mark.parametrize( "name", ["crashmanager:crashdel", "crashmanager:crashedit", "crashmanager:crashview"], ) -def test_crash_simple_get(client, cm, name): # pylint: disable=invalid-name +def test_crash_simple_get( + client: Client, cm: _cm_result, name: str +) -> None: # pylint: disable=invalid-name """No errors are thrown in template""" client.login(username="test", password="test") crash = cm.create_crash() @@ -62,7 +74,7 @@ def test_crash_simple_get(client, cm, name): # pylint: disable=invalid-name assert response.status_code == requests.codes["ok"] -def test_delete_testcase(cm): +def test_delete_testcase(cm: _cm_result) -> None: """Testcases should be delete when TestCase object is removed""" testcase = cm.create_testcase("test.txt", "hello world") test_file = testcase.test.name @@ -76,7 +88,7 @@ def test_delete_testcase(cm): ) -def test_delete_testcase_crash(cm): +def test_delete_testcase_crash(cm: _cm_result) -> None: """Testcases should be delete when CrashInfo object is removed""" testcase = cm.create_testcase("test.txt", "hello world") test_file = testcase.test.name diff --git a/server/crashmanager/tests/test_crashes_rest.py b/server/crashmanager/tests/test_crashes_rest.py index 730fd0c3b..581089475 100644 --- a/server/crashmanager/tests/test_crashes_rest.py +++ b/server/crashmanager/tests/test_crashes_rest.py @@ -9,6 
+9,8 @@ file, You can obtain one at http://mozilla.org/MPL/2.0/. """ +from __future__ import annotations + import json import logging import os.path @@ -18,10 +20,15 @@ import pytest import requests +from django.contrib.auth.models import User +from rest_framework.test import APIClient +from Collector.Collector import DataType from crashmanager.models import CrashEntry from crashmanager.models import TestCase as cmTestCase +from .conftest import _cm_result + # What should be allowed: # # +--------+------+----------+---------+---------+--------------+-------------------+ @@ -83,7 +90,9 @@ @pytest.mark.parametrize( "url", ["/crashmanager/rest/crashes/", "/crashmanager/rest/crashes/1/"] ) -def test_rest_crashes_no_auth(db, api_client, method, url): +def test_rest_crashes_no_auth( + db: None, api_client: APIClient, method: str, url: str +) -> None: """must yield unauthorized without authentication""" assert ( getattr(api_client, method)(url, {}).status_code @@ -95,7 +104,9 @@ def test_rest_crashes_no_auth(db, api_client, method, url): @pytest.mark.parametrize( "url", ["/crashmanager/rest/crashes/", "/crashmanager/rest/crashes/1/"] ) -def test_rest_crashes_no_perm(user_noperm, api_client, method, url): +def test_rest_crashes_no_perm( + user_noperm: User, api_client: APIClient, method: str, url: str +) -> None: """must yield forbidden without permission""" assert ( getattr(api_client, method)(url, {}).status_code == requests.codes["forbidden"] @@ -125,7 +136,9 @@ def test_rest_crashes_no_perm(user_noperm, api_client, method, url): ], indirect=["user"], ) -def test_rest_crashes_methods(api_client, user, method, url): +def test_rest_crashes_methods( + api_client: APIClient, user: User, method: str, url: str +) -> None: """must yield method-not-allowed for unsupported methods""" assert ( getattr(api_client, method)(url, {}).status_code @@ -133,7 +146,9 @@ def test_rest_crashes_methods(api_client, user, method, url): ) -def _compare_rest_result_to_crash(result, crash, raw=True): +def _compare_rest_result_to_crash( + result: dict[str, str], crash, raw: bool = True +) -> None: expected_fields = { "args", "bucket", @@ -180,8 +195,11 @@ def _compare_rest_result_to_crash(result, crash, raw=True): def _compare_created_data_to_crash( - data, crash, crash_address=None, short_signature=None -): + data: DataType, + crash: CrashEntry, + crash_address: str | None = None, + short_signature: str | None = None, +) -> None: for field in ("rawStdout", "rawStderr", "rawCrashData"): assert getattr(crash, field) == data[field].strip() if "testcase" in data: @@ -207,7 +225,13 @@ def _compare_created_data_to_crash( @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) @pytest.mark.parametrize("ignore_toolfilter", [True, False]) @pytest.mark.parametrize("include_raw", [True, False]) -def test_rest_crashes_list(api_client, user, cm, ignore_toolfilter, include_raw): +def test_rest_crashes_list( + api_client: APIClient, + user: User, + cm: _cm_result, + ignore_toolfilter: bool, + include_raw: bool, +) -> None: """test that list returns the right crashes""" # if restricted or normal, must only list crashes in toolfilter buckets = [cm.create_bucket(shortDescription="bucket #1"), None] @@ -255,7 +279,13 @@ def test_rest_crashes_list(api_client, user, cm, ignore_toolfilter, include_raw) @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) @pytest.mark.parametrize("ignore_toolfilter", [True, False]) @pytest.mark.parametrize("include_raw", [True, False]) -def 
test_rest_crashes_retrieve(api_client, user, cm, ignore_toolfilter, include_raw): +def test_rest_crashes_retrieve( + api_client: APIClient, + user: User, + cm: _cm_result, + ignore_toolfilter: bool, + include_raw: bool, +) -> None: """test that retrieve returns the right crash""" # if restricted or normal, must only list crashes in toolfilter buckets = [cm.create_bucket(shortDescription="bucket #1"), None] @@ -307,7 +337,13 @@ def test_rest_crashes_retrieve(api_client, user, cm, ignore_toolfilter, include_ ], indirect=["user"], ) -def test_rest_crashes_list_query(api_client, cm, user, expected, toolfilter): +def test_rest_crashes_list_query( + api_client: APIClient, + cm: _cm_result, + user: User, + expected: int | None, + toolfilter: str | None, +) -> None: """test that crashes can be queried""" buckets = [cm.create_bucket(shortDescription="bucket #1"), None, None, None] testcases = [ @@ -409,7 +445,9 @@ def test_rest_crashes_list_query(api_client, cm, user, expected, toolfilter): }, ], ) -def test_rest_crashes_report_crash(api_client, user, data): +def test_rest_crashes_report_crash( + api_client: APIClient, user: User, data: DataType +) -> None: """test that crash reporting works""" resp = api_client.post("/crashmanager/rest/crashes/", data=data) LOG.debug(resp) @@ -418,9 +456,11 @@ def test_rest_crashes_report_crash(api_client, user, data): _compare_created_data_to_crash(data, crash) -def test_rest_crashes_report_crash_long(api_client, user_normal): +def test_rest_crashes_report_crash_long( + api_client: APIClient, user_normal: User +) -> None: """test that crash reporting works with fields interpreted as `long` in python 2""" - data = { + data: DataType = { "rawStdout": "", "rawStderr": "", "rawCrashData": (FIXTURE_PATH / "gdb_crash_data.txt").read_text(), @@ -446,7 +486,9 @@ def test_rest_crashes_report_crash_long(api_client, user_normal): "crashmanager.models.CrashEntry.save", new=Mock(side_effect=RuntimeError("crashentry failing intentionally")), ) -def test_rest_crashes_report_bad_crash_removes_testcase(api_client, user_normal): +def test_rest_crashes_report_bad_crash_removes_testcase( + api_client: APIClient, user_normal: User +) -> None: """test that reporting a bad crash doesn't leave an orphaned testcase""" data = { "rawStdout": "data on\nstdout", @@ -469,7 +511,9 @@ def test_rest_crashes_report_bad_crash_removes_testcase(api_client, user_normal) assert not cmTestCase.objects.exists() -def test_rest_crashes_report_crash_long_sig(api_client, user_normal): +def test_rest_crashes_report_crash_long_sig( + api_client: APIClient, user_normal: User +) -> None: """test that crash reporting works with an assertion message too long for shortSignature""" data = { @@ -493,7 +537,9 @@ def test_rest_crashes_report_crash_long_sig(api_client, user_normal): _compare_created_data_to_crash(data, crash, short_signature=expected) -def test_rest_crash_update(api_client, cm, user_normal): +def test_rest_crash_update( + api_client: APIClient, cm: _cm_result, user_normal: User +) -> None: """test that only allowed fields of CrashEntry can be updated""" test = cm.create_testcase("test.txt", quality=0) bucket = cm.create_bucket(shortDescription="bucket #1") @@ -543,7 +589,9 @@ def test_rest_crash_update(api_client, cm, user_normal): assert test.quality == 5 -def test_rest_crash_update_restricted(api_client, cm, user_restricted): +def test_rest_crash_update_restricted( + api_client: APIClient, cm: _cm_result, user_restricted: User +) -> None: """test that restricted users cannot perform updates on 
CrashEntry""" test = cm.create_testcase("test.txt", quality=0) bucket = cm.create_bucket(shortDescription="bucket #1") diff --git a/server/crashmanager/tests/test_crashmanager.py b/server/crashmanager/tests/test_crashmanager.py index dd219c3c7..259538964 100644 --- a/server/crashmanager/tests/test_crashmanager.py +++ b/server/crashmanager/tests/test_crashmanager.py @@ -8,10 +8,15 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse LOG = logging.getLogger("fm.crashmanager.tests.crashmanager") @@ -20,22 +25,27 @@ ) # pylint: disable=invalid-name -def test_crashmanager_redirect(client): +def test_crashmanager_redirect(client: Client) -> None: """Request without login hits the login redirect""" resp = client.get("/") assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=/" + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=/" + ) -def test_crashmanager_no_login(client): +def test_crashmanager_no_login(client: Client) -> None: """Request of root url redirects to crashes view""" client.login(username="test", password="test") resp = client.get("/") assert resp.status_code == requests.codes["found"] - assert resp.url == reverse("crashmanager:index") + assert typing.cast(typing.Union[str, None], getattr(resp, "url", None)) == reverse( + "crashmanager:index" + ) -def test_crashmanager_logout(client): +def test_crashmanager_logout(client: Client) -> None: """Logout url actually logs us out""" client.login(username="test", password="test") assert ( @@ -46,10 +56,13 @@ def test_crashmanager_logout(client): response = client.get("/") LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=/" + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=/" + ) -def test_crashmanager_noperm(client): +def test_crashmanager_noperm(client: Client) -> None: """Request without permission results in 403""" client.login(username="test-noperm", password="test") resp = client.get(reverse("crashmanager:index")) diff --git a/server/crashmanager/tests/test_inbox_rest.py b/server/crashmanager/tests/test_inbox_rest.py index 8cfaea91c..15d93c037 100644 --- a/server/crashmanager/tests/test_inbox_rest.py +++ b/server/crashmanager/tests/test_inbox_rest.py @@ -9,16 +9,20 @@ file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" +from __future__ import annotations + import json import logging import pytest import requests from django.conf import settings +from django.contrib.auth.models import User from django.urls import reverse from django.utils import timezone from notifications.models import Notification from notifications.signals import notify +from rest_framework.test import APIClient from crashmanager.models import ( OS, @@ -33,11 +37,15 @@ ) from taskmanager.models import Pool, Task +from .conftest import _cm_result + LOG = logging.getLogger("fm.crashmanager.tests.inbox.rest") @pytest.mark.parametrize("method", ["delete", "get", "patch", "post", "put"]) -def test_rest_notifications_no_auth(db, api_client, method): +def test_rest_notifications_no_auth( + db: None, api_client: APIClient, method: str +) -> None: """must yield unauthorized without authentication""" assert ( getattr(api_client, method)("/crashmanager/rest/inbox/", {}).status_code @@ -46,7 +54,9 @@ def test_rest_notifications_no_auth(db, api_client, method): @pytest.mark.parametrize("method", ["delete", "get", "patch", "post", "put"]) -def test_rest_notifications_no_perm(user_noperm, api_client, method): +def test_rest_notifications_no_perm( + user_noperm: User, api_client: APIClient, method: str +) -> None: """must yield forbidden without permission""" assert ( getattr(api_client, method)("/crashmanager/rest/inbox/", {}).status_code @@ -68,7 +78,9 @@ def test_rest_notifications_no_perm(user_noperm, api_client, method): ], indirect=["user"], ) -def test_rest_notifications_methods(api_client, user, method, url): +def test_rest_notifications_methods( + api_client: APIClient, user: User, method: str, url: str +) -> None: """must yield method-not-allowed for unsupported methods""" assert ( getattr(api_client, method)(url, {}).status_code @@ -92,7 +104,9 @@ def test_rest_notifications_methods(api_client, user, method, url): ], indirect=["user"], ) -def test_rest_notifications_methods_not_found(api_client, user, method, url): +def test_rest_notifications_methods_not_found( + api_client: APIClient, user: User, method: str, url: str +) -> None: """must yield not-found for undeclared methods""" assert ( getattr(api_client, method)(url, {}).status_code == requests.codes["not_found"] @@ -100,7 +114,9 @@ def test_rest_notifications_methods_not_found(api_client, user, method, url): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) -def test_rest_notifications_list_unread(api_client, user, cm): +def test_rest_notifications_list_unread( + api_client: APIClient, user: User, cm: _cm_result +) -> None: """test that list returns the right notifications""" provider = BugProvider.objects.create( classname="BugzillaProvider", hostname="provider.com", urlTemplate="%s" @@ -209,7 +225,9 @@ def test_rest_notifications_list_unread(api_client, user, cm): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) -def test_rest_notifications_mark_as_read(api_client, user, cm): +def test_rest_notifications_mark_as_read( + api_client: APIClient, user: User, cm: _cm_result +) -> None: """test that mark_as_read only marks the targetted notification as read""" bucket = Bucket.objects.create( signature=json.dumps( @@ -252,7 +270,9 @@ def test_rest_notifications_mark_as_read(api_client, user, cm): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) -def test_rest_notifications_mark_all_as_read(api_client, user, cm): +def test_rest_notifications_mark_all_as_read( + api_client: APIClient, user: User, cm: _cm_result +) 
-> None: """test that mark_all_as_read marks all user notifications as read""" bucket = Bucket.objects.create( signature=json.dumps( diff --git a/server/crashmanager/tests/test_mgmt_add_permission.py b/server/crashmanager/tests/test_mgmt_add_permission.py index 82d4a9bb5..d287f2a98 100644 --- a/server/crashmanager/tests/test_mgmt_add_permission.py +++ b/server/crashmanager/tests/test_mgmt_add_permission.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import pytest from django.contrib.auth.models import User from django.core.management import CommandError, call_command @@ -15,12 +18,12 @@ pytestmark = pytest.mark.django_db() # pylint: disable=invalid-name -def test_args(): +def test_args() -> None: with pytest.raises(CommandError, match=r"Error: .*arguments.*"): call_command("add_permission") -def test_no_such_user(): +def test_no_such_user() -> None: with pytest.raises(User.DoesNotExist): call_command( "add_permission", @@ -29,20 +32,20 @@ def test_no_such_user(): ) -def test_no_perms(): +def test_no_perms() -> None: User.objects.create_user("test", "test@example.com", "test") with pytest.raises(CommandError, match=r"Error: .*arguments.*"): call_command("add_permission", "test") -def test_one_perm(): +def test_one_perm() -> None: user = User.objects.create_user("test", "test@example.com", "test") user.user_permissions.clear() # clear any default permissions call_command("add_permission", "test", "crashmanager.models.User:view_crashmanager") assert set(user.get_all_permissions()) == {"crashmanager.view_crashmanager"} -def test_two_perms(): +def test_two_perms() -> None: user = User.objects.create_user("test", "test@example.com", "test") user.user_permissions.clear() # clear any default permissions call_command( diff --git a/server/crashmanager/tests/test_mgmt_bug_update_status.py b/server/crashmanager/tests/test_mgmt_bug_update_status.py index 19b38c379..636c0f987 100644 --- a/server/crashmanager/tests/test_mgmt_bug_update_status.py +++ b/server/crashmanager/tests/test_mgmt_bug_update_status.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import json from unittest.mock import patch @@ -15,6 +18,7 @@ from django.contrib.auth.models import User from django.core.management import CommandError, call_command from notifications.models import Notification +from pytest_mock import MockerFixture from crashmanager.models import ( OS, @@ -33,12 +37,12 @@ pytestmark = pytest.mark.usefixtures("crashmanager_test") -def test_args(): +def test_args() -> None: with pytest.raises(CommandError, match=r"Error: unrecognized arguments: "): call_command("bug_update_status", "") -def test_none(): +def test_none() -> None: call_command("bug_update_status") @@ -46,7 +50,7 @@ def test_none(): "crashmanager.Bugtracker.BugzillaProvider.BugzillaProvider.getBugStatus", return_value={"0": None}, ) -def test_fake_with_notification(mock_get_bug_status): +def test_fake_with_notification(mock_get_bug_status: MockerFixture) -> None: provider = BugProvider.objects.create( classname="BugzillaProvider", hostname="provider.com", urlTemplate="%s" ) diff --git a/server/crashmanager/tests/test_mgmt_cleanup_old_crashes.py b/server/crashmanager/tests/test_mgmt_cleanup_old_crashes.py index 80ae63e5e..580f13baa 100644 --- a/server/crashmanager/tests/test_mgmt_cleanup_old_crashes.py +++ b/server/crashmanager/tests/test_mgmt_cleanup_old_crashes.py @@ -8,11 +8,16 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + from datetime import timedelta +from typing import Any, cast import pytest from django.core.management import CommandError, call_command from django.utils import timezone +from pytest_django.fixtures import SettingsWrapper from crashmanager.models import ( OS, @@ -29,7 +34,7 @@ pytestmark = pytest.mark.django_db() # pylint: disable=invalid-name -def _crashentry_create(**kwds): +def _crashentry_create(**kwds: Any) -> CrashEntry: defaults = { "client": Client.objects.create(), "os": OS.objects.create(), @@ -38,22 +43,22 @@ def _crashentry_create(**kwds): "tool": Tool.objects.create(), } defaults.update(kwds) - return CrashEntry.objects.create(**defaults) + return cast(CrashEntry, CrashEntry.objects.create(**defaults)) -def test_args(): +def test_args() -> None: with pytest.raises(CommandError, match=r"Error: unrecognized arguments: "): call_command("cleanup_old_crashes", "") -def test_bug_cleanup(): +def test_bug_cleanup() -> None: prov = BugProvider.objects.create() Bug.objects.create(externalType=prov) call_command("cleanup_old_crashes") assert Bug.objects.count() == 0 -def test_closed_bugs(settings): +def test_closed_bugs(settings: SettingsWrapper) -> None: """all buckets that have been closed for x days""" settings.CLEANUP_CRASHES_AFTER_DAYS = 4 settings.CLEANUP_FIXED_BUCKETS_AFTER_DAYS = 2 @@ -83,7 +88,7 @@ def test_closed_bugs(settings): } -def test_empty_bucket(settings): +def test_empty_bucket(settings: SettingsWrapper) -> None: """all buckets that are empty""" settings.CLEANUP_CRASHES_AFTER_DAYS = 4 settings.CLEANUP_FIXED_BUCKETS_AFTER_DAYS = 2 @@ -102,7 +107,7 @@ def test_empty_bucket(settings): assert Bug.objects.count() == 1 -def test_old_crashes(settings): +def test_old_crashes(settings: SettingsWrapper) -> None: """all entries that are older than x days and not in any bucket or bucket has no bug associated with it""" settings.CLEANUP_CRASHES_AFTER_DAYS = 3 diff --git a/server/crashmanager/tests/test_mgmt_export_signatures.py b/server/crashmanager/tests/test_mgmt_export_signatures.py index 
18b79e974..e6a72fab6 100644 --- a/server/crashmanager/tests/test_mgmt_export_signatures.py +++ b/server/crashmanager/tests/test_mgmt_export_signatures.py @@ -9,6 +9,8 @@ file, You can obtain one at http://mozilla.org/MPL/2.0/. """ +from __future__ import annotations + import json import os import re @@ -23,12 +25,12 @@ pytestmark = pytest.mark.django_db() # pylint: disable=invalid-name -def test_args(): +def test_args() -> None: with pytest.raises(CommandError, match=r"Error: .* arguments"): call_command("export_signatures") -def test_none(): +def test_none() -> None: fd, tmpf = tempfile.mkstemp() os.close(fd) try: @@ -40,7 +42,7 @@ def test_none(): os.unlink(tmpf) -def test_some(): +def test_some() -> None: sig1 = Bucket.objects.create(signature="sig1", frequent=True) sig2 = Bucket.objects.create(signature="sig2", shortDescription="desc") members = set() diff --git a/server/crashmanager/tests/test_mgmt_get_auth_token.py b/server/crashmanager/tests/test_mgmt_get_auth_token.py index f8d9a6041..04c152410 100644 --- a/server/crashmanager/tests/test_mgmt_get_auth_token.py +++ b/server/crashmanager/tests/test_mgmt_get_auth_token.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import re import pytest @@ -18,17 +21,17 @@ pytestmark = pytest.mark.django_db() # pylint: disable=invalid-name -def test_args(): +def test_args() -> None: with pytest.raises(CommandError, match=r"Error: Enter at least one label."): call_command("get_auth_token") -def test_no_such_user(): +def test_no_such_user() -> None: with pytest.raises(User.DoesNotExist): call_command("get_auth_token", "user") -def test_one_user(capsys): +def test_one_user(capsys: pytest.CaptureFixture[str]) -> None: user = User.objects.create_user("test", "test@example.com", "test") call_command("get_auth_token", "test") out, _ = capsys.readouterr() @@ -39,7 +42,7 @@ def test_one_user(capsys): assert len(key) > 32 # just check that it's reasonably long -def test_two_users(capsys): +def test_two_users(capsys: pytest.CaptureFixture[str]) -> None: users = ( User.objects.create_user("test", "test@example.com", "test"), User.objects.create_user("test2", "test2@example.com", "test2"), diff --git a/server/crashmanager/tests/test_mgmt_triage_new_crashes.py b/server/crashmanager/tests/test_mgmt_triage_new_crashes.py index 3cbb42192..293435447 100644 --- a/server/crashmanager/tests/test_mgmt_triage_new_crashes.py +++ b/server/crashmanager/tests/test_mgmt_triage_new_crashes.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
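Aside on the cast around objects.create() in the cleanup-tests helper above, sketched under the assumption that the Django manager is untyped for mypy here, which is presumably why the patch adds the cast: cast() only informs the type checker and changes nothing at runtime.

    from typing import Any, cast

    from crashmanager.models import CrashEntry


    def make_entry(**kwds: Any) -> CrashEntry:
        # Without typed manager stubs, objects.create(...) is Any to mypy;
        # cast() pins the declared return type and has no runtime effect.
        # (The real helper above also fills in the required foreign keys.)
        return cast(CrashEntry, CrashEntry.objects.create(**kwds))
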
""" + +from __future__ import annotations + import json import pytest @@ -31,16 +34,16 @@ pytestmark = pytest.mark.usefixtures("crashmanager_test") -def test_args(): +def test_args() -> None: with pytest.raises(CommandError, match=r"Error: unrecognized arguments: "): call_command("triage_new_crashes", "") -def test_none(): +def test_none() -> None: call_command("triage_new_crashes") -def test_some(): +def test_some() -> None: buckets = [ Bucket.objects.create( signature=json.dumps( @@ -80,7 +83,7 @@ def test_some(): assert crashes[2].bucket is None -def test_some_with_notification(): +def test_some_with_notification() -> None: buckets = [ Bucket.objects.create( signature=json.dumps( diff --git a/server/crashmanager/tests/test_rest_live.py b/server/crashmanager/tests/test_rest_live.py index 4bbf79c02..ac218ef0a 100644 --- a/server/crashmanager/tests/test_rest_live.py +++ b/server/crashmanager/tests/test_rest_live.py @@ -11,19 +11,27 @@ @contact: choller@mozilla.com """ + +from __future__ import annotations + from urllib.parse import urlsplit import pytest import requests +from django.contrib.auth.models import User +from pytest_django.live_server_helper import LiveServer pytestmark = pytest.mark.django_db() # pylint: disable=invalid-name pytest_plugins = "server.tests" # pylint: disable=invalid-name @pytest.mark.skip -def test_RESTCrashEntryInterface(live_server, fm_user): - url = urlsplit(live_server.url) - url = f"{url.scheme}://{url.hostname}:{url.port}/crashmanager/rest/crashes/" +def test_RESTCrashEntryInterface(live_server: LiveServer, fm_user: User) -> None: + url_split = urlsplit(live_server.url) + url = ( + f"{url_split.scheme}://{url_split.hostname}:{url_split.port}" + "/crashmanager/rest/crashes/" + ) # Must yield forbidden without authentication assert requests.get(url).status_code == requests.codes["unauthorized"] @@ -67,9 +75,12 @@ def test_RESTCrashEntryInterface(live_server, fm_user): assert json[lengthBeforePost]["product_version"] == "ba0bc4f26681" -def test_RESTSignatureInterface(live_server): - url = urlsplit(live_server.url) - url = f"{url.scheme}://{url.hostname}:{url.port}/crashmanager/rest/signatures/" +def test_RESTSignatureInterface(live_server: LiveServer) -> None: + url_split = urlsplit(live_server.url) + url = ( + f"{url_split.scheme}://{url_split.hostname}:{url_split.port}" + "/crashmanager/rest/signatures/" + ) # Must yield forbidden without authentication assert requests.get(url).status_code == requests.codes["not_found"] diff --git a/server/crashmanager/tests/test_signatures.py b/server/crashmanager/tests/test_signatures.py index 1b531b3f8..3aa50d8f2 100644 --- a/server/crashmanager/tests/test_signatures.py +++ b/server/crashmanager/tests/test_signatures.py @@ -9,16 +9,23 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import json import logging +import typing import pytest import requests +from django.http.response import HttpResponse +from django.test.client import Client from django.urls import reverse from crashmanager.models import Bucket, BucketWatch, CrashEntry from . 
import assert_contains +from .conftest import _cm_result LOG = logging.getLogger("fm.crashmanager.tests.signatures") pytestmark = pytest.mark.usefixtures( @@ -37,24 +44,29 @@ ("crashmanager:sigedit", {"sigid": 1}), ], ) -def test_signatures_no_login(client, name, kwargs): +def test_signatures_no_login(client: Client, name: str, kwargs: dict[str, int]) -> None: """Request without login hits the login redirect""" path = reverse(name, kwargs=kwargs) resp = client.get(path) assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=" + path + ) -def test_signatures_view(client): # pylint: disable=invalid-name +def test_signatures_view(client: Client) -> None: # pylint: disable=invalid-name """Check that the Vue component is called""" client.login(username="test", password="test") response = client.get(reverse("crashmanager:signatures")) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "signatureslist") + assert_contains(typing.cast(HttpResponse, response), "signatureslist") -def test_del_signature_simple_get(client, cm): # pylint: disable=invalid-name +def test_del_signature_simple_get( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """No errors are thrown in template""" client.login(username="test", password="test") @@ -67,8 +79,13 @@ def test_del_signature_simple_get(client, cm): # pylint: disable=invalid-name response = client.get(reverse("crashmanager:sigdel", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Are you sure that you want to delete this signature?") - assert_contains(response, "Bucket contains no crash entries.") + assert_contains( + typing.cast(HttpResponse, response), + "Are you sure that you want to delete this signature?", + ) + assert_contains( + typing.cast(HttpResponse, response), "Bucket contains no crash entries." 
+ ) # 1 crash not in toolfilter cm.create_toolfilter(crash1.tool) @@ -77,7 +94,10 @@ def test_del_signature_simple_get(client, cm): # pylint: disable=invalid-name response = client.get(reverse("crashmanager:sigdel", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Are you sure that you want to delete this signature?") + assert_contains( + typing.cast(HttpResponse, response), + "Are you sure that you want to delete this signature?", + ) assert_contains( response, "Also delete all crash entries with this bucket: 0 in tool filter, " @@ -92,7 +112,10 @@ def test_del_signature_simple_get(client, cm): # pylint: disable=invalid-name response = client.get(reverse("crashmanager:sigdel", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Are you sure that you want to delete this signature?") + assert_contains( + typing.cast(HttpResponse, response), + "Are you sure that you want to delete this signature?", + ) assert_contains( response, "Also delete all crash entries with this bucket: 1 in tool filter " @@ -105,7 +128,10 @@ def test_del_signature_simple_get(client, cm): # pylint: disable=invalid-name response = client.get(reverse("crashmanager:sigdel", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Are you sure that you want to delete this signature?") + assert_contains( + typing.cast(HttpResponse, response), + "Are you sure that you want to delete this signature?", + ) assert_contains( response, "Also delete all crash entries with this bucket: 1 in tool filter, " @@ -118,7 +144,10 @@ def test_del_signature_simple_get(client, cm): # pylint: disable=invalid-name response = client.get(reverse("crashmanager:sigdel", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Are you sure that you want to delete this signature?") + assert_contains( + typing.cast(HttpResponse, response), + "Are you sure that you want to delete this signature?", + ) assert_contains( response, "Also delete all crash entries with this bucket: 1 in tool filter, " @@ -126,7 +155,9 @@ def test_del_signature_simple_get(client, cm): # pylint: disable=invalid-name ) -def test_find_signature_simple_get(client, cm): # pylint: disable=invalid-name +def test_find_signature_simple_get( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """No errors are thrown in template""" client.login(username="test", password="test") crash = cm.create_crash() @@ -143,7 +174,9 @@ def test_find_signature_simple_get(client, cm): # pylint: disable=invalid-name assert response.status_code == requests.codes["ok"] -def test_opt_signature_simple_get(client, cm): # pylint: disable=invalid-name +def test_opt_signature_simple_get( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """No errors are thrown in template""" client.login(username="test", password="test") bucket = cm.create_bucket( @@ -156,7 +189,9 @@ def test_opt_signature_simple_get(client, cm): # pylint: disable=invalid-name assert response.status_code == requests.codes["ok"] -def test_try_signature_simple_get(client, cm): # pylint: disable=invalid-name +def test_try_signature_simple_get( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """No errors are thrown in template""" client.login(username="test", 
password="test") bucket = cm.create_bucket( @@ -172,16 +207,18 @@ def test_try_signature_simple_get(client, cm): # pylint: disable=invalid-name assert response.status_code == requests.codes["ok"] -def test_new_signature_view(client): +def test_new_signature_view(client: Client) -> None: """Check that the Vue component is called""" client.login(username="test", password="test") response = client.get(reverse("crashmanager:signew")) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "createoredit") + assert_contains(typing.cast(HttpResponse, response), "createoredit") -def test_edit_signature_view(client, cm): # pylint: disable=invalid-name +def test_edit_signature_view( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Check that the Vue component is called""" client.login(username="test", password="test") sig = json.dumps( @@ -191,22 +228,28 @@ def test_edit_signature_view(client, cm): # pylint: disable=invalid-name response = client.get(reverse("crashmanager:sigedit", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "createoredit") + assert_contains(typing.cast(HttpResponse, response), "createoredit") assert response.context["bucketId"] == bucket.pk -def test_del_signature_empty(client, cm): # pylint: disable=invalid-name +def test_del_signature_empty( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Test deleting a signature with no crashes""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") response = client.post(reverse("crashmanager:sigdel", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == reverse("crashmanager:signatures") + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("crashmanager:signatures") assert not Bucket.objects.count() -def test_del_signature_leave_entries(client, cm): # pylint: disable=invalid-name +def test_del_signature_leave_entries( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Test deleting a signature with crashes and leave entries""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -214,13 +257,17 @@ def test_del_signature_leave_entries(client, cm): # pylint: disable=invalid-nam response = client.post(reverse("crashmanager:sigdel", kwargs={"sigid": bucket.pk})) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == reverse("crashmanager:signatures") + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("crashmanager:signatures") assert not Bucket.objects.count() crash = CrashEntry.objects.get(pk=crash.pk) # re-read assert crash.bucket is None -def test_del_signature_del_entries(client, cm): # pylint: disable=invalid-name +def test_del_signature_del_entries( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Test deleting a signature with crashes and removing entries""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -230,12 +277,14 @@ def test_del_signature_del_entries(client, cm): # pylint: disable=invalid-name ) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == 
reverse("crashmanager:signatures") + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("crashmanager:signatures") assert not Bucket.objects.count() assert not CrashEntry.objects.count() -def test_watch_signature_empty(client): +def test_watch_signature_empty(client: Client) -> None: """If no watched signatures, that should be shown""" client.login(username="test", password="test") response = client.get(reverse("crashmanager:sigwatch")) @@ -245,7 +294,9 @@ def test_watch_signature_empty(client): ) -def test_watch_signature_buckets(client, cm): # pylint: disable=invalid-name +def test_watch_signature_buckets( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Watched signatures should be listed""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -261,8 +312,8 @@ def test_watch_signature_buckets(client, cm): # pylint: disable=invalid-name def test_watch_signature_buckets_new_crashes( - client, cm -): # pylint: disable=invalid-name + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Watched signatures should show new crashes""" client.login(username="test", password="test") buckets = ( @@ -289,7 +340,9 @@ def test_watch_signature_buckets_new_crashes( assert not siglist[1].newCrashes -def test_watch_signature_del(client, cm): # pylint: disable=invalid-name +def test_watch_signature_del( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Deleting a signature watch""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -303,7 +356,7 @@ def test_watch_signature_del(client, cm): # pylint: disable=invalid-name "Are you sure that you want to stop watching this signature for new crash " "entries?", ) - assert_contains(response, bucket.shortDescription) + assert_contains(typing.cast(HttpResponse, response), bucket.shortDescription) response = client.post( reverse("crashmanager:sigwatchdel", kwargs={"sigid": bucket.pk}) ) @@ -311,10 +364,14 @@ def test_watch_signature_del(client, cm): # pylint: disable=invalid-name assert not BucketWatch.objects.count() assert Bucket.objects.get() == bucket assert response.status_code == requests.codes["found"] - assert response.url == reverse("crashmanager:sigwatch") + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("crashmanager:sigwatch") -def test_watch_signature_delsig(client, cm): # pylint: disable=invalid-name +def test_watch_signature_delsig( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Deleting a watched signature""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -323,7 +380,9 @@ def test_watch_signature_delsig(client, cm): # pylint: disable=invalid-name assert not BucketWatch.objects.count() -def test_watch_signature_update(client, cm): # pylint: disable=invalid-name +def test_watch_signature_update( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Updating a signature watch""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -336,13 +395,17 @@ def test_watch_signature_update(client, cm): # pylint: disable=invalid-name ) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == reverse("crashmanager:sigwatch") + assert typing.cast( + typing.Union[str, None], 
getattr(response, "url", None) + ) == reverse("crashmanager:sigwatch") watch = BucketWatch.objects.get(pk=watch.pk) assert watch.bucket == bucket assert watch.lastCrash == crash2.pk -def test_watch_signature_new(client, cm): # pylint: disable=invalid-name +def test_watch_signature_new( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Creating a signature watch""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -353,13 +416,17 @@ def test_watch_signature_new(client, cm): # pylint: disable=invalid-name ) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == reverse("crashmanager:sigwatch") + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("crashmanager:sigwatch") watch = BucketWatch.objects.get() assert watch.bucket == bucket assert watch.lastCrash == crash.pk -def test_watch_signature_crashes(client, cm): # pylint: disable=invalid-name +def test_watch_signature_crashes( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Crashes in a signature watch should be shown correctly.""" client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -371,4 +438,4 @@ def test_watch_signature_crashes(client, cm): # pylint: disable=invalid-name assert response.status_code == requests.codes["ok"] assert response.context["watchId"] == watch.id assert response.context["restricted"] is False - assert_contains(response, "crasheslist") + assert_contains(typing.cast(HttpResponse, response), "crasheslist") diff --git a/server/crashmanager/tests/test_signatures_rest.py b/server/crashmanager/tests/test_signatures_rest.py index 3ca7b1b3d..1c61d3dfe 100644 --- a/server/crashmanager/tests/test_signatures_rest.py +++ b/server/crashmanager/tests/test_signatures_rest.py @@ -8,17 +8,22 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import json import logging import pytest import requests +from django.contrib.auth.models import User as DjangoUser from django.urls import reverse from rest_framework import status +from rest_framework.test import APIClient -from crashmanager.models import Bucket, Bug, CrashEntry +from crashmanager.models import Bucket, Bug, CrashEntry, User -from .conftest import _create_user +from .conftest import _cm_result, _create_user # What should be allowed: # @@ -77,8 +82,15 @@ def _compare_rest_result_to_bucket( - result, bucket, size, quality, best_entry=None, latest=None, hist=[], vue=False -): + result, + bucket: Bucket, + size: int, + quality: int | None, + best_entry: int | None = None, + latest: int | None = None, + hist: list[dict[str, object]] = [], + vue: bool = False, +) -> None: attributes = { "best_entry", "best_quality", @@ -139,7 +151,9 @@ def _compare_rest_result_to_bucket( @pytest.mark.parametrize( "url", ["/crashmanager/rest/buckets/", "/crashmanager/rest/buckets/1/"] ) -def test_rest_signatures_no_auth(db, api_client, method, url): +def test_rest_signatures_no_auth( + db: None, api_client: APIClient, method: str, url: str +) -> None: """must yield unauthorized without authentication""" assert ( getattr(api_client, method)(url, {}).status_code @@ -151,7 +165,9 @@ def test_rest_signatures_no_auth(db, api_client, method, url): @pytest.mark.parametrize( "url", ["/crashmanager/rest/buckets/", "/crashmanager/rest/buckets/1/"] ) -def test_rest_signatures_no_perm(user_noperm, api_client, method, url): +def test_rest_signatures_no_perm( + user_noperm: User, api_client: APIClient, method: str, url: str +) -> None: """must yield forbidden without permission""" assert ( getattr(api_client, method)(url, {}).status_code == requests.codes["forbidden"] @@ -181,7 +197,9 @@ def test_rest_signatures_no_perm(user_noperm, api_client, method, url): ], indirect=["user"], ) -def test_rest_signatures_methods(api_client, user, method, url): +def test_rest_signatures_methods( + api_client: APIClient, user: User, method: str, url: str +) -> None: """must yield method-not-allowed for unsupported methods""" assert ( getattr(api_client, method)(url, {}).status_code @@ -192,7 +210,13 @@ def test_rest_signatures_methods(api_client, user, method, url): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) @pytest.mark.parametrize("ignore_toolfilter", [True, False]) @pytest.mark.parametrize("vue", [True, False]) -def test_rest_signatures_list(api_client, cm, user, ignore_toolfilter, vue): +def test_rest_signatures_list( + api_client: APIClient, + cm: _cm_result, + user: DjangoUser, + ignore_toolfilter: bool, + vue: bool, +) -> None: """test that buckets can be listed""" bucket1 = cm.create_bucket(shortDescription="bucket #1") bucket2 = cm.create_bucket(shortDescription="bucket #2") @@ -249,7 +273,9 @@ def test_rest_signatures_list(api_client, cm, user, ignore_toolfilter, vue): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) @pytest.mark.parametrize("ignore_toolfilter", [True, False]) -def test_rest_signatures_retrieve(api_client, cm, user, ignore_toolfilter): +def test_rest_signatures_retrieve( + api_client: APIClient, cm: _cm_result, user: DjangoUser, ignore_toolfilter: bool +) -> None: """test that individual Signature can be fetched""" bucket1 = cm.create_bucket(shortDescription="bucket #1") bucket2 = cm.create_bucket(shortDescription="bucket #2") @@ -291,6 +317,7 @@ def test_rest_signatures_retrieve(api_client, cm, 
user, ignore_toolfilter): status_code = resp.status_code resp = resp.json() assert status_code == requests.codes["ok"], resp["detail"] + quality: int | None if user.username == "test": if ignore_toolfilter: size, quality, best, latest = [ @@ -310,8 +337,8 @@ def test_rest_signatures_retrieve(api_client, cm, user, ignore_toolfilter): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) @pytest.mark.parametrize("from_crash", [False, True]) def test_new_signature_create( - api_client, cm, user, from_crash -): # pylint: disable=invalid-name + api_client: APIClient, cm: _cm_result, user: DjangoUser, from_crash: bool +) -> None: # pylint: disable=invalid-name if from_crash: if user.username == "test-restricted": _create_user("test") @@ -369,8 +396,8 @@ def test_new_signature_create( @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) @pytest.mark.parametrize("many", [False, True]) def test_new_signature_create_w_reassign( - api_client, cm, user, many -): # pylint: disable=invalid-name + api_client: APIClient, cm: _cm_result, user: User, many: bool +) -> None: # pylint: disable=invalid-name if many: crashes = [ cm.create_crash(shortSignature="crash #1", stderr="blah") @@ -417,8 +444,8 @@ def test_new_signature_create_w_reassign( @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) @pytest.mark.parametrize("many", [False, True]) def test_new_signature_preview( - api_client, cm, user, many -): # pylint: disable=invalid-name + api_client: APIClient, cm: _cm_result, user: User, many: bool +) -> None: # pylint: disable=invalid-name if many: crashes = [ cm.create_crash(shortSignature="crash #1", stderr="blah") @@ -473,13 +500,13 @@ def test_new_signature_preview( @pytest.mark.parametrize("frequent", [True, False]) @pytest.mark.parametrize("permanent", [True, False]) def test_edit_signature_edit( - api_client, - cm, # pylint: disable=invalid-name - user, - do_not_reduce, - frequent, - permanent, -): + api_client: APIClient, + cm: _cm_result, # pylint: disable=invalid-name + user: User, + do_not_reduce: bool, + frequent: bool, + permanent: bool, +) -> None: # pylint: disable=invalid-name bucket = cm.create_bucket() crash = cm.create_crash(shortSignature="crash #1", stderr="blah") sig = json.dumps( @@ -516,8 +543,8 @@ def test_edit_signature_edit( @pytest.mark.parametrize("user", ["normal"], indirect=True) @pytest.mark.parametrize("many", [False, True]) def test_edit_signature_edit_w_reassign( - api_client, cm, user, many -): # pylint: disable=invalid-name + api_client: APIClient, cm: _cm_result, user: User, many: bool +) -> None: # pylint: disable=invalid-name bucket = cm.create_bucket() if many: crashes = [ @@ -565,8 +592,8 @@ def test_edit_signature_edit_w_reassign( @pytest.mark.parametrize("user", ["normal"], indirect=True) @pytest.mark.parametrize("many", [False, True]) def test_edit_signature_edit_preview( - api_client, cm, user, many -): # pylint: disable=invalid-name + api_client: APIClient, cm: _cm_result, user: User, many: bool +) -> None: # pylint: disable=invalid-name bucket = cm.create_bucket() if many: crashes1 = [ @@ -631,7 +658,9 @@ def test_edit_signature_edit_preview( assert in_list[0]["id"] == crash2.pk -def test_edit_signature_set_frequent(api_client, cm, user_normal): +def test_edit_signature_set_frequent( + api_client: APIClient, cm: _cm_result, user_normal: User +) -> None: """test that partial_update action marks a signature frequent without touching anything else""" bug = cm.create_bug("123") @@ -657,7 +686,9 @@ def 
test_edit_signature_set_frequent(api_client, cm, user_normal): assert bucket.bug == bug -def test_edit_signature_unassign_external_bug(api_client, cm, user_normal): +def test_edit_signature_unassign_external_bug( + api_client: APIClient, cm: _cm_result, user_normal: User +) -> None: """test that partial_update action marks a signature frequent without touching anything else""" bug = cm.create_bug("123") @@ -681,7 +712,9 @@ def test_edit_signature_unassign_external_bug(api_client, cm, user_normal): assert bucket.bug is None -def test_edit_signature_assign_external_bug(api_client, cm, user_normal): +def test_edit_signature_assign_external_bug( + api_client: APIClient, cm: _cm_result, user_normal: User +) -> None: """test that partial_update action create a new Bug and assign it to this Bucket""" provider = cm.create_bugprovider( hostname="test-provider.com", urlTemplate="test-provider.com/template" diff --git a/server/crashmanager/tests/test_stats.py b/server/crashmanager/tests/test_stats.py index 509c37af1..08939f0e4 100644 --- a/server/crashmanager/tests/test_stats.py +++ b/server/crashmanager/tests/test_stats.py @@ -8,14 +8,21 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import datetime import logging +import typing import pytest import requests +from django.http.response import HttpResponse +from django.test.client import Client from django.urls import reverse from . import assert_contains +from .conftest import _cm_result LOG = logging.getLogger("fm.crashmanager.tests.stats") VIEW_NAME = "crashmanager:stats" @@ -25,36 +32,43 @@ ) # pylint: disable=invalid-name -def test_stats_view_no_login(client): +def test_stats_view_no_login(client: Client) -> None: """Request without login hits the login redirect""" path = reverse(VIEW_NAME) resp = client.get(path) assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=" + path + ) -def test_stats_view_no_crashes(client): +def test_stats_view_no_crashes(client: Client) -> None: """If no crashes in db, an appropriate message is shown.""" client.login(username="test", password="test") response = client.get(reverse(VIEW_NAME)) assert response.status_code == requests.codes["ok"] assert response.context["total_reports_per_hour"] == 0 - assert_contains(response, VIEW_ENTRIES_FMT % 0) + assert_contains(typing.cast(HttpResponse, response), VIEW_ENTRIES_FMT % 0) assert not response.context["frequentBuckets"] -def test_stats_view_with_crash(client, cm): # pylint: disable=invalid-name +def test_stats_view_with_crash( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Insert one crash and check that it is shown ok.""" client.login(username="test", password="test") cm.create_crash(shortSignature="crash #1") response = client.get(reverse(VIEW_NAME)) assert response.status_code == requests.codes["ok"] assert response.context["total_reports_per_hour"] == 1 - assert_contains(response, VIEW_ENTRIES_FMT % 1) + assert_contains(typing.cast(HttpResponse, response), VIEW_ENTRIES_FMT % 1) assert not response.context["frequentBuckets"] -def test_stats_view_with_crashes(client, cm): # pylint: disable=invalid-name +def test_stats_view_with_crashes( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Insert crashes and check that they are shown ok.""" 
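Aside (illustrative, not from the patch): the status-code names compared throughout these tests come from the LookupDict in the requests library, so every assertion ultimately compares against a plain integer:

    import requests

    assert requests.codes["ok"] == 200
    assert requests.codes["found"] == 302               # login/redirect checks
    assert requests.codes["unauthorized"] == 401
    assert requests.codes["forbidden"] == 403
    assert requests.codes["not_found"] == 404
    assert requests.codes["method_not_allowed"] == 405
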
client.login(username="test", password="test") bucket = cm.create_bucket(shortDescription="bucket #1") @@ -66,26 +80,28 @@ def test_stats_view_with_crashes(client, cm): # pylint: disable=invalid-name response = client.get(reverse(VIEW_NAME)) assert response.status_code == requests.codes["ok"] assert response.context["total_reports_per_hour"] == 4 - assert_contains(response, VIEW_ENTRIES_FMT % 4) + assert_contains(typing.cast(HttpResponse, response), VIEW_ENTRIES_FMT % 4) response_buckets = response.context["frequentBuckets"] assert len(response_buckets) == 1 assert response_buckets[0] == bucket assert response_buckets[0].rph == 2 -def test_stats_view_old(client, cm): # pylint: disable=invalid-name +def test_stats_view_old( + client: Client, cm: _cm_result +) -> None: # pylint: disable=invalid-name """Insert one crash in the past and check that it is not shown.""" client.login(username="test", password="test") crash = cm.create_crash(shortSignature="crash #1") response = client.get(reverse(VIEW_NAME)) assert response.status_code == requests.codes["ok"] assert response.context["total_reports_per_hour"] == 1 - assert_contains(response, VIEW_ENTRIES_FMT % 1) + assert_contains(typing.cast(HttpResponse, response), VIEW_ENTRIES_FMT % 1) assert not response.context["frequentBuckets"] crash.created -= datetime.timedelta(hours=1, seconds=1) crash.save() response = client.get(reverse(VIEW_NAME)) assert response.status_code == requests.codes["ok"] assert response.context["total_reports_per_hour"] == 0 - assert_contains(response, VIEW_ENTRIES_FMT % 0) + assert_contains(typing.cast(HttpResponse, response), VIEW_ENTRIES_FMT % 0) assert not response.context["frequentBuckets"] diff --git a/server/crashmanager/tests/test_templates_rest.py b/server/crashmanager/tests/test_templates_rest.py index 99ad5fa62..c3f7ee8cb 100644 --- a/server/crashmanager/tests/test_templates_rest.py +++ b/server/crashmanager/tests/test_templates_rest.py @@ -1,7 +1,15 @@ +from __future__ import annotations + import logging import pytest import requests +from django.contrib.auth.models import User +from rest_framework.test import APIClient + +from crashmanager.models import BugzillaTemplate + +from .conftest import _cm_result LOG = logging.getLogger("fm.crashmanager.tests.templates.rest") @@ -14,7 +22,9 @@ "/crashmanager/rest/bugzilla/templates/1/", ], ) -def test_rest_templates_no_auth(db, api_client, method, url): +def test_rest_templates_no_auth( + db: None, api_client: APIClient, method: str, url: str +) -> None: """must yield unauthorized without authentication""" assert ( getattr(api_client, method)(url, {}).status_code @@ -30,7 +40,9 @@ def test_rest_templates_no_auth(db, api_client, method, url): "/crashmanager/rest/bugzilla/templates/1/", ], ) -def test_rest_templates_no_perm(user_noperm, api_client, method, url): +def test_rest_templates_no_perm( + user_noperm: User, api_client: APIClient, method: str, url: str +) -> None: """must yield forbidden without permission""" assert ( getattr(api_client, method)(url, {}).status_code == requests.codes["forbidden"] @@ -59,7 +71,9 @@ def test_rest_templates_no_perm(user_noperm, api_client, method, url): ], indirect=["user"], ) -def test_rest_templates_methods(api_client, user, method, url): +def test_rest_templates_methods( + api_client: APIClient, user: str, method: str, url: str +) -> None: """must yield method-not-allowed for unsupported methods""" assert ( getattr(api_client, method)(url, {}).status_code @@ -67,7 +81,9 @@ def test_rest_templates_methods(api_client, user, method, 
url): ) -def _compare_rest_result_to_template(result, template): +def _compare_rest_result_to_template( + result: dict[str, str], template: BugzillaTemplate +) -> None: expected_fields = { "id", "mode", @@ -105,7 +121,7 @@ def _compare_rest_result_to_template(result, template): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) -def test_rest_templates_list(api_client, user, cm): +def test_rest_templates_list(api_client: APIClient, user: str, cm: _cm_result) -> None: """test that list returns the right templates""" expected = 4 templates = [ @@ -132,7 +148,9 @@ def test_rest_templates_list(api_client, user, cm): @pytest.mark.parametrize("user", ["normal", "restricted"], indirect=True) -def test_rest_templates_retrieve(api_client, user, cm): +def test_rest_templates_retrieve( + api_client: APIClient, user: str, cm: _cm_result +) -> None: """test that retrieve returns the right template""" expected = 4 templates = [ diff --git a/server/crashmanager/tests/test_user_settings.py b/server/crashmanager/tests/test_user_settings.py index 994e632b7..cf1c9023c 100644 --- a/server/crashmanager/tests/test_user_settings.py +++ b/server/crashmanager/tests/test_user_settings.py @@ -8,27 +8,37 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse from crashmanager.models import Tool, User +from .conftest import _cm_result + LOG = logging.getLogger("fm.crashmanager.tests.usersettings") pytestmark = pytest.mark.usefixtures("crashmanager_test") -def test_user_settings_no_login(client): +def test_user_settings_no_login(client: Client) -> None: """Request without login hits the login redirect""" path = reverse("crashmanager:usersettings") resp = client.get(path) assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=" + path + ) -def test_user_settings_simple_get(client): +def test_user_settings_simple_get(client: Client) -> None: """No errors are thrown in template""" client.login(username="test", password="test") path = reverse("crashmanager:usersettings") @@ -38,7 +48,7 @@ def test_user_settings_simple_get(client): assert response.context["user"] == User.objects.get(user__username="test").user -def test_user_settings_edit(client, cm): +def test_user_settings_edit(client: Client, cm: _cm_result) -> None: """No errors are thrown in template""" tools = [Tool.objects.create(name="Tool #%d" % (i + 1)) for i in range(2)] providers = [ @@ -69,7 +79,10 @@ def test_user_settings_edit(client, cm): LOG.debug(response) # Redirecting to user settings when the action is successful assert response.status_code == requests.codes["found"] - assert response.url == "/crashmanager/usersettings/" + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/crashmanager/usersettings/" + ) user.refresh_from_db() assert Tool.objects.count() == 2 assert list(user.defaultToolsFilter.all()) == list(Tool.objects.all()) diff --git a/server/crashmanager/urls.py b/server/crashmanager/urls.py index ffc34a61f..4505977ac 100644 --- a/server/crashmanager/urls.py +++ b/server/crashmanager/urls.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.conf.urls import include from django.urls import 
re_path from rest_framework import routers diff --git a/server/crashmanager/views.py b/server/crashmanager/views.py index 62b7db955..830eda620 100644 --- a/server/crashmanager/views.py +++ b/server/crashmanager/views.py @@ -1,14 +1,20 @@ +from __future__ import annotations + import json import os from collections import OrderedDict from datetime import datetime, timedelta +from typing import Any, TypeVar, cast from wsgiref.util import FileWrapper from django.conf import settings as django_settings from django.core.exceptions import FieldError, PermissionDenied, SuspiciousOperation -from django.db.models import F, Q +from django.db.models import F, Model, Q from django.db.models.aggregates import Count, Min +from django.db.models.query import QuerySet from django.http import Http404, HttpResponse +from django.http.request import HttpRequest +from django.http.response import HttpResponsePermanentRedirect, HttpResponseRedirect from django.shortcuts import get_object_or_404, redirect, render from django.urls import reverse, reverse_lazy from django.utils import timezone @@ -21,6 +27,7 @@ from rest_framework.decorators import action from rest_framework.exceptions import MethodNotAllowed, ValidationError from rest_framework.filters import BaseFilterBackend, OrderingFilter +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView @@ -56,6 +63,8 @@ NotificationSerializer, ) +MT = TypeVar("MT", bound=Model) + class JSONDateEncoder(json.JSONEncoder): def default(self, obj): @@ -64,8 +73,8 @@ def default(self, obj): return super().default(obj) -def check_authorized_for_crash_entry(request, entry): - user = User.get_or_create_restricted(request.user)[0] +def check_authorized_for_crash_entry(request: HttpRequest, entry: CrashEntry) -> None: + user = cast(User, User.get_or_create_restricted(request.user)[0]) if user.restricted: defaultToolsFilter = user.defaultToolsFilter.all() if not defaultToolsFilter or entry.tool not in defaultToolsFilter: @@ -73,11 +82,11 @@ def check_authorized_for_crash_entry(request, entry): {"message": "You don't have permission to access this crash entry."} ) - return + return None -def check_authorized_for_signature(request, signature): - user = User.get_or_create_restricted(request.user)[0] +def check_authorized_for_signature(request: HttpRequest, signature: Bucket) -> None: + user = cast(User, User.get_or_create_restricted(request.user)[0]) if user.restricted: defaultToolsFilter = user.defaultToolsFilter.all() if not defaultToolsFilter: @@ -93,17 +102,19 @@ def check_authorized_for_signature(request, signature): {"message": "You don't have permission to access this signature."} ) - return + return None -def deny_restricted_users(request): - user = User.get_or_create_restricted(request.user)[0] +def deny_restricted_users(request: HttpRequest) -> None: + user = cast(User, User.get_or_create_restricted(request.user)[0]) if user.restricted: raise PermissionDenied({"message": "Restricted users cannot use this feature."}) -def filter_crash_entries_by_toolfilter(request, entries, restricted_only=False): - user = User.get_or_create_restricted(request.user)[0] +def filter_crash_entries_by_toolfilter( + request: HttpRequest, entries: QuerySet[MT], restricted_only: bool = False +) -> QuerySet[MT]: + user = cast(User, User.get_or_create_restricted(request.user)[0]) if restricted_only and not user.restricted: return entries @@ -118,9 +129,12 @@ def filter_crash_entries_by_toolfilter(request, entries, 
restricted_only=False): def filter_signatures_by_toolfilter( - request, signatures, restricted_only=False, legacy_filters=True -): - user = User.get_or_create_restricted(request.user)[0] + request: HttpRequest, + signatures: QuerySet[MT], + restricted_only: bool = False, + legacy_filters: bool = True, +) -> QuerySet[MT]: + user = cast(User, User.get_or_create_restricted(request.user)[0]) if restricted_only and not user.restricted: return signatures @@ -147,8 +161,10 @@ def filter_signatures_by_toolfilter( return signatures -def filter_bucket_hits_by_toolfilter(request, hits, restricted_only=False): - user = User.get_or_create_restricted(request.user)[0] +def filter_bucket_hits_by_toolfilter( + request: HttpRequest, hits: QuerySet[MT], restricted_only: bool = False +) -> QuerySet[MT]: + user = cast(User, User.get_or_create_restricted(request.user)[0]) if restricted_only and not user.restricted: return hits @@ -162,11 +178,11 @@ def filter_bucket_hits_by_toolfilter(request, hits, restricted_only=False): return hits -def renderError(request, err): +def renderError(request: HttpRequest, err: str) -> HttpResponse: return render(request, "error.html", {"error_message": err}) -def stats(request): +def stats(request: HttpRequest) -> HttpResponse: lastHourDelta = timezone.now() - timedelta(hours=1) print(lastHourDelta) entries = CrashEntry.objects.filter(created__gt=lastHourDelta).select_related( @@ -186,10 +202,10 @@ def stats(request): frequentBuckets = [] if bucketFrequencyMap: - bucketFrequencyMap = sorted( + bucketFrequencyMap_ = sorted( bucketFrequencyMap.items(), key=lambda t: t[1], reverse=True )[:10] - for pk, freq in bucketFrequencyMap: + for pk, freq in bucketFrequencyMap_: obj = Bucket.objects.get(pk=pk) obj.rph = freq frequentBuckets.append(obj) @@ -207,18 +223,18 @@ def stats(request): ) -def settings(request): +def settings(request: HttpRequest) -> HttpResponse: return render(request, "settings.html") -def watchedSignatures(request): +def watchedSignatures(request: HttpRequest) -> HttpResponse: # for this user, list watches # buckets sig new crashes remove # ======================================== # 1 crash 10 tr # 2 assert 0 tr # 3 blah 0 tr - user = User.get_or_create_restricted(request.user)[0] + user = cast(User, User.get_or_create_restricted(request.user)[0]) filters = { "user": user, @@ -243,22 +259,25 @@ def watchedSignatures(request): bucketsAll = bucketsAll.extra( select={"lastCrash": "crashmanager_bucketwatch.lastCrash"} ) - buckets = list(bucketsAll) - for idx, bucket in enumerate(buckets): + buckets_list = list(bucketsAll) + for idx, bucket in enumerate(buckets_list): for newIdx, newBucket in enumerate(newBuckets): if newBucket == bucket: # replace with this one - buckets[idx] = newBucket + buckets_list[idx] = newBucket newBuckets.pop(newIdx) break else: bucket.newCrashes = 0 - return render(request, "signatures/watch.html", {"siglist": buckets}) + return render(request, "signatures/watch.html", {"siglist": buckets_list}) -def deleteBucketWatch(request, sigid): - user = User.get_or_create_restricted(request.user)[0] +def deleteBucketWatch( + request: HttpRequest, sigid: int +) -> HttpResponseRedirect | HttpResponsePermanentRedirect | HttpResponse: + user = cast(User, User.get_or_create_restricted(request.user)[0]) + entry: BucketWatch | Bucket if request.method == "POST": entry = get_object_or_404(BucketWatch, user=user, bucket=sigid) entry.delete() @@ -270,9 +289,11 @@ def deleteBucketWatch(request, sigid): raise SuspiciousOperation() -def newBucketWatch(request): +def 
newBucketWatch( + request: HttpRequest, +) -> HttpResponseRedirect | HttpResponsePermanentRedirect: if request.method == "POST": - user = User.get_or_create_restricted(request.user)[0] + user = cast(User, User.get_or_create_restricted(request.user)[0]) bucket = get_object_or_404(Bucket, pk=int(request.POST["bucket"])) for watch in BucketWatch.objects.filter(user=user, bucket=bucket): watch.lastCrash = int(request.POST["crash"]) @@ -286,8 +307,8 @@ def newBucketWatch(request): raise SuspiciousOperation() -def bucketWatchCrashes(request, sigid): - user = User.get_or_create_restricted(request.user)[0] +def bucketWatchCrashes(request: HttpRequest, sigid: int) -> HttpResponse: + user = cast(User, User.get_or_create_restricted(request.user)[0]) bucket = get_object_or_404(Bucket, pk=sigid) watch = get_object_or_404(BucketWatch, user=user, bucket=bucket) return render( @@ -297,7 +318,7 @@ def bucketWatchCrashes(request, sigid): ) -def signatures(request): +def signatures(request: HttpRequest) -> HttpResponse: providers = BugProviderSerializer(BugProvider.objects.all(), many=True).data return render( request, @@ -311,16 +332,16 @@ def signatures(request): ) -def index(request): +def index(request: HttpRequest) -> HttpResponse: return redirect("crashmanager:crashes") -def crashes(request): - user = User.get_or_create_restricted(request.user)[0] +def crashes(request: HttpRequest) -> HttpResponse: + user = cast(User, User.get_or_create_restricted(request.user)[0]) return render(request, "crashes/index.html", {"restricted": user.restricted}) -def viewCrashEntry(request, crashid): +def viewCrashEntry(request: HttpRequest, crashid: int) -> HttpResponse: entry = get_object_or_404(CrashEntry, pk=crashid) check_authorized_for_crash_entry(request, entry) entry.deserializeFields() @@ -337,7 +358,9 @@ def viewCrashEntry(request, crashid): ) -def editCrashEntry(request, crashid): +def editCrashEntry( + request: HttpRequest, crashid: int +) -> HttpResponseRedirect | HttpResponsePermanentRedirect | HttpResponse: entry = get_object_or_404(CrashEntry, pk=crashid) check_authorized_for_crash_entry(request, entry) entry.deserializeFields() @@ -361,6 +384,7 @@ def editCrashEntry(request, crashid): if entry.testcase: if entry.testcase.isBinary: if request.POST["testcase"] != "(binary)": + assert entry.testcase.content is not None entry.testcase.content = request.POST["testcase"] entry.testcase.isBinary = False # TODO: The file extension stored on the server remains and is @@ -368,6 +392,7 @@ def editCrashEntry(request, crashid): entry.testcase.storeTestAndSave() else: if request.POST["testcase"] != entry.testcase.content: + assert entry.testcase.content is not None entry.testcase.content = request.POST["testcase"] entry.testcase.storeTestAndSave() @@ -376,7 +401,9 @@ def editCrashEntry(request, crashid): return render(request, "crashes/edit.html", {"entry": entry}) -def deleteCrashEntry(request, crashid): +def deleteCrashEntry( + request: HttpRequest, crashid: int +) -> HttpResponseRedirect | HttpResponsePermanentRedirect | HttpResponse: entry = get_object_or_404(CrashEntry, pk=crashid) check_authorized_for_crash_entry(request, entry) @@ -389,7 +416,7 @@ def deleteCrashEntry(request, crashid): raise SuspiciousOperation -def newSignature(request): +def newSignature(request: HttpRequest) -> HttpResponse: if request.method != "GET": raise SuspiciousOperation @@ -442,11 +469,11 @@ def newSignature(request): errorMsg = crashInfo.failureReason proposedSignature = crashInfo.createCrashSignature(maxFrames=maxStackFrames) - 
proposedSignature = str(proposedSignature) + proposedSignature_str = str(proposedSignature) proposedShortDesc = crashInfo.createShortSignature() data = { - "proposedSig": json.loads(proposedSignature), + "proposedSig": json.loads(proposedSignature_str), "proposedDesc": proposedShortDesc, "warningMessage": errorMsg, } @@ -454,7 +481,9 @@ def newSignature(request): return render(request, "signatures/edit.html", data) -def deleteSignature(request, sigid): +def deleteSignature( + request: HttpRequest, sigid: int +) -> HttpResponseRedirect | HttpResponsePermanentRedirect | HttpResponse: user = User.get_or_create_restricted(request.user)[0] if user.restricted: raise PermissionDenied( @@ -504,7 +533,7 @@ def deleteSignature(request, sigid): raise SuspiciousOperation -def viewSignature(request, sigid): +def viewSignature(request: HttpRequest, sigid: int) -> HttpResponse: response = BucketVueViewSet.as_view({"get": "retrieve"})(request, pk=sigid) if response.status_code == 404: raise Http404 @@ -535,7 +564,7 @@ def viewSignature(request, sigid): ) -def editSignature(request, sigid): +def editSignature(request: HttpRequest, sigid: int) -> HttpResponse: if request.method != "GET" or sigid is None: raise SuspiciousOperation @@ -554,7 +583,7 @@ def editSignature(request, sigid): ) -def trySignature(request, sigid, crashid): +def trySignature(request: HttpRequest, sigid: int, crashid: int) -> HttpResponse: bucket = get_object_or_404(Bucket, pk=sigid) check_authorized_for_signature(request, bucket) @@ -572,7 +601,7 @@ def trySignature(request, sigid, crashid): ) -def optimizeSignature(request, sigid): +def optimizeSignature(request: HttpRequest, sigid: int) -> HttpResponse: bucket = get_object_or_404(Bucket, pk=sigid) check_authorized_for_signature(request, bucket) @@ -603,7 +632,7 @@ def optimizeSignature(request, sigid): ) -def optimizeSignaturePrecomputed(request, sigid): +def optimizeSignaturePrecomputed(request: HttpRequest, sigid: int) -> HttpResponse: bucket = get_object_or_404(Bucket, pk=sigid) check_authorized_for_signature(request, bucket) @@ -651,7 +680,7 @@ def optimizeSignaturePrecomputed(request, sigid): ) -def findSignatures(request, crashid): +def findSignatures(request: HttpRequest, crashid: int) -> HttpResponse: entry = get_object_or_404(CrashEntry, pk=crashid) check_authorized_for_crash_entry(request, entry) @@ -765,7 +794,7 @@ def findSignatures(request, crashid): ) -def createExternalBug(request, crashid): +def createExternalBug(request: HttpRequest, crashid: int) -> HttpResponse: entry = get_object_or_404(CrashEntry, pk=crashid) check_authorized_for_crash_entry(request, entry) @@ -782,7 +811,7 @@ def createExternalBug(request, crashid): if "provider" in request.GET: provider = get_object_or_404(BugProvider, pk=request.GET["provider"]) else: - user = User.get_or_create_restricted(request.user)[0] + user = cast(User, User.get_or_create_restricted(request.user)[0]) provider = get_object_or_404(BugProvider, pk=user.defaultProviderId) template = provider.getInstance().getTemplateForUser(request, entry) @@ -797,7 +826,7 @@ def createExternalBug(request, crashid): raise SuspiciousOperation -def createExternalBugComment(request, crashid): +def createExternalBugComment(request: HttpRequest, crashid: int) -> HttpResponse: entry = get_object_or_404(CrashEntry, pk=crashid) check_authorized_for_crash_entry(request, entry) @@ -805,7 +834,7 @@ def createExternalBugComment(request, crashid): if "provider" in request.GET: provider = get_object_or_404(BugProvider, pk=request.GET["provider"]) else: - 
user = User.get_or_create_restricted(request.user)[0] + user = cast(User, User.get_or_create_restricted(request.user)[0]) provider = get_object_or_404(BugProvider, pk=user.defaultProviderId) template = provider.getInstance().getTemplateForUser(request, entry) @@ -820,12 +849,14 @@ def createExternalBugComment(request, crashid): raise SuspiciousOperation -def viewBugProviders(request): +def viewBugProviders(request: HttpRequest) -> HttpResponse: providers = BugProvider.objects.annotate(size=Count("bug")) return render(request, "providers/index.html", {"providers": providers}) -def deleteBugProvider(request, providerId): +def deleteBugProvider( + request: HttpRequest, providerId: int +) -> HttpResponseRedirect | HttpResponsePermanentRedirect | HttpResponse: deny_restricted_users(request) provider = get_object_or_404(BugProvider, pk=providerId) @@ -845,7 +876,7 @@ def deleteBugProvider(request, providerId): raise SuspiciousOperation -def viewBugProvider(request, providerId): +def viewBugProvider(request: HttpRequest, providerId: int) -> HttpResponse: provider = BugProvider.objects.filter(pk=providerId).annotate(size=Count("bug")) if not provider: @@ -856,7 +887,9 @@ def viewBugProvider(request, providerId): return render(request, "providers/view.html", {"provider": provider}) -def editBugProvider(request, providerId): +def editBugProvider( + request: HttpRequest, providerId: int +) -> HttpResponseRedirect | HttpResponsePermanentRedirect | HttpResponse: deny_restricted_users(request) provider = get_object_or_404(BugProvider, pk=providerId) @@ -882,7 +915,9 @@ def editBugProvider(request, providerId): raise SuspiciousOperation -def createBugProvider(request): +def createBugProvider( + request: HttpRequest, +) -> HttpResponseRedirect | HttpResponsePermanentRedirect | HttpResponse: deny_restricted_users(request) if request.method == "POST": @@ -909,7 +944,9 @@ def createBugProvider(request): raise SuspiciousOperation -def duplicateBugzillaTemplate(request, templateId): +def duplicateBugzillaTemplate( + request: HttpRequest, templateId: int +) -> HttpResponseRedirect | HttpResponsePermanentRedirect: clone = get_object_or_404(BugzillaTemplate, pk=templateId) clone.pk = None # to autogen a new pk on save() clone.name = "Clone of " + clone.name @@ -923,7 +960,9 @@ class JsonQueryFilterBackend(BaseFilterBackend): (see json_to_query) """ - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """ Return a filtered queryset. """ @@ -946,11 +985,13 @@ class ToolFilterCrashesBackend(BaseFilterBackend): given. Only unrestricted users can use ignore_toolfilter. 
""" - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """Return a filtered queryset""" - ignore_toolfilter = request.query_params.get("ignore_toolfilter", "0") + ignore_toolfilter_requested = request.query_params.get("ignore_toolfilter", "0") try: - ignore_toolfilter = int(ignore_toolfilter) + ignore_toolfilter = int(ignore_toolfilter_requested) assert ignore_toolfilter in {0, 1} except (AssertionError, ValueError): raise InvalidArgumentException({"ignore_toolfilter": ["Expecting 0 or 1."]}) @@ -965,7 +1006,9 @@ def filter_queryset(self, request, queryset, view): class WatchFilterCrashesBackend(BaseFilterBackend): """Filters the queryset to retrieve watched entries if '?watch='""" - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: watch_id = request.query_params.get("watch", "false").lower() if watch_id == "false": return queryset @@ -979,11 +1022,13 @@ class ToolFilterSignaturesBackend(BaseFilterBackend): given. Only unrestricted users can use ignore_toolfilter. """ - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """Return a filtered queryset""" - ignore_toolfilter = request.query_params.get("ignore_toolfilter", "0") + ignore_toolfilter_requested = request.query_params.get("ignore_toolfilter", "0") try: - ignore_toolfilter = int(ignore_toolfilter) + ignore_toolfilter = int(ignore_toolfilter_requested) assert ignore_toolfilter in {0, 1} except (AssertionError, ValueError): raise InvalidArgumentException({"ignore_toolfilter": ["Expecting 0 or 1."]}) @@ -999,7 +1044,9 @@ def filter_queryset(self, request, queryset, view): class BucketAnnotateFilterBackend(BaseFilterBackend): """Annotates bucket queryset with size and best_quality""" - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: return queryset.annotate( size=Count("crashentry"), quality=Min("crashentry__testcase__quality") ) @@ -1008,10 +1055,12 @@ def filter_queryset(self, request, queryset, view): class DeferRawFilterBackend(BaseFilterBackend): """Optionally defer raw fields""" - def filter_queryset(self, request, queryset, view): - include_raw = request.query_params.get("include_raw", "1") + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: + include_raw_requested = request.query_params.get("include_raw", "1") try: - include_raw = int(include_raw) + include_raw = int(include_raw_requested) assert include_raw in {0, 1} except (AssertionError, ValueError): raise InvalidArgumentException({"include_raw": ["Expecting 0 or 1."]}) @@ -1044,7 +1093,7 @@ class CrashEntryViewSet( DeferRawFilterBackend, ] - def get_serializer(self, *args, **kwds): + def get_serializer(self, *args: Any, **kwds: Any): kwds["include_raw"] = getattr(self, "include_raw", True) vue = self.request.query_params.get("vue", "false").lower() not in ( "false", @@ -1055,10 +1104,11 @@ def get_serializer(self, *args, **kwds): else: return super().get_serializer(*args, **kwds) - def partial_update(self, request, pk=None): + def partial_update(self, request: Request, pk: int | None = None) -> Response: """Update individual crash fields.""" - user = 
User.get_or_create_restricted(request.user)[0] + user = cast(User, User.get_or_create_restricted(request.user)[0]) if user.restricted: + assert request.method is not None raise MethodNotAllowed(request.method) allowed_fields = {"testcase_quality"} @@ -1115,7 +1165,7 @@ class BucketViewSet( ] pagination_class = None - def get_serializer(self, *args, **kwds): + def get_serializer(self, *args: Any, **kwds: Any): self.vue = self.request.query_params.get("vue", "false").lower() not in ( "false", "0", @@ -1125,7 +1175,7 @@ def get_serializer(self, *args, **kwds): else: return super().get_serializer(*args, **kwds) - def list(self, request, *args, **kwargs): + def list(self, request: Request, *args: Any, **kwargs: Any) -> Response: response = super().list(request, *args, **kwargs) if self.vue and response.status_code == 200: @@ -1162,8 +1212,8 @@ def list(self, request, *args, **kwargs): return response - def retrieve(self, request, *args, **kwargs): - user = User.get_or_create_restricted(request.user)[0] + def retrieve(self, request: Request, *args: str, **kwargs: str) -> Response: + user = cast(User, User.get_or_create_restricted(request.user)[0]) instance = self.get_object() ignore_toolfilter = getattr(self, "ignore_toolfilter", False) @@ -1191,10 +1241,12 @@ def retrieve(self, request, *args, **kwargs): .order_by("testcase__size", "-id") .first() ) + assert best_crash is not None instance.best_entry = best_crash.id if instance.size: latest_crash = crashes_in_filter.order_by("id").last() + assert latest_crash is not None instance.latest_entry = latest_crash.id serializer = self.get_serializer(instance) @@ -1221,7 +1273,9 @@ def retrieve(self, request, *args, **kwargs): return response - def __validate(self, request, bucket, submitSave, reassign): + def __validate( + self, request: Request, bucket: Bucket, submitSave: bool, reassign: bool + ): try: bucket.getSignature() except RuntimeError as e: @@ -1241,8 +1295,8 @@ def __validate(self, request, bucket, submitSave, reassign): bucket.bug = bucket.bug bucket.save() - inList, outList = [], [] + inList: list[str] = [] + outList: list[str] = [] inListCount, outListCount = 0, 0 # If the reassign checkbox is checked if reassign: inList, outList, inListCount, outListCount = bucket.reassign(submitSave) @@ -1260,12 +1314,14 @@ def __validate(self, request, bucket, submitSave, reassign): "outListCount": outListCount, } - def update(self, request, *args, **kwargs): + def update(self, request: Request, *args: str, **kwargs: str) -> Response: + assert request.method is not None raise MethodNotAllowed(request.method) - def partial_update(self, request, *args, **kwargs): - user = User.get_or_create_restricted(request.user)[0] + def partial_update(self, request: Request, *args: str, **kwargs: str) -> Response: + user = cast(User, User.get_or_create_restricted(request.user)[0]) if user.restricted: + assert request.method is not None raise MethodNotAllowed(request.method) serializer = self.get_serializer(data=request.data, partial=True) @@ -1274,6 +1330,7 @@ def partial_update(self, request, *args, **kwargs): bucket = get_object_or_404(Bucket, id=self.kwargs["pk"]) check_authorized_for_signature(request, bucket) + bug: QuerySet[Bug] | Bug | None if "bug" in serializer.validated_data: if serializer.validated_data["bug"] is None: bug = None @@ -1314,7 +1371,7 @@ def partial_update(self, request, *args, **kwargs): data=data, ) - def create(self, request, *args, **kwargs): + def create(self, request: Request, *args: str, **kwargs: str) -> Response: serializer =
self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) @@ -1341,7 +1398,7 @@ def create(self, request, *args, **kwargs): class BucketVueViewSet(BucketViewSet): """API endpoint that allows viewing Buckets and always uses Vue serializer""" - def get_serializer(self, *args, **kwds): + def get_serializer(self, *args: Any, **kwds: Any) -> BucketVueSerializer: self.vue = True return BucketVueSerializer(*args, **kwds) @@ -1379,11 +1436,11 @@ class NotificationViewSet(mixins.ListModelMixin, viewsets.GenericViewSet): JsonQueryFilterBackend, ] - def get_queryset(self): + def get_queryset(self) -> QuerySet[Notification]: return Notification.objects.unread().filter(recipient=self.request.user) @action(detail=True, methods=["patch"]) - def mark_as_read(self, request, pk=None): + def mark_as_read(self, request: Request, pk: int | None = None) -> Response: notification = self.get_object() if notification.recipient != request.user: @@ -1393,13 +1450,13 @@ def mark_as_read(self, request, pk=None): return Response(status=status.HTTP_200_OK) @action(detail=False, methods=["patch"]) - def mark_all_as_read(self, request): + def mark_all_as_read(self, request: Request) -> Response: notifications = self.get_queryset() notifications.mark_all_as_read() return Response(status=status.HTTP_200_OK) -def json_to_query(json_str): +def json_to_query(json_str: str): """ This method converts JSON objects into trees of Django Q objects. It can be used to provide the user the ability to perform complex @@ -1428,7 +1485,7 @@ def json_to_query(json_str): except ValueError as e: raise RuntimeError(f"Invalid JSON: {e}") - def get_query_obj(obj, key=None): + def get_query_obj(obj, key: str | None = None) -> Q: if obj is None or isinstance(obj, (str, list, int)): kwargs = {key: obj} @@ -1470,7 +1527,12 @@ class AbstractDownloadView(APIView): authentication_classes = (TokenAuthentication, SessionAuthentication) permission_classes = (CheckAppPermission,) - def response(self, file_path, filename, content_type="application/octet-stream"): + def response( + self, + file_path: str, + filename: str, + content_type: str = "application/octet-stream", + ) -> HttpResponse: if not os.path.exists(file_path): return HttpResponse(status=404) @@ -1481,12 +1543,12 @@ def response(self, file_path, filename, content_type="application/octet-stream") response["Content-Disposition"] = f'attachment; filename="{filename}"' return response - def get(self): + def get(self, _request: HttpRequest, _crashid: int) -> HttpResponse: return HttpResponse(status=500) class TestDownloadView(AbstractDownloadView): - def get(self, request, crashid): + def get(self, request: HttpRequest, crashid: int) -> HttpResponse: storage_base = getattr(django_settings, "TEST_STORAGE", None) if not storage_base: # This is a misconfiguration @@ -1503,7 +1565,7 @@ def get(self, request, crashid): class SignaturesDownloadView(AbstractDownloadView): - def get(self, request, format=None): + def get(self, request: HttpRequest, format: int | None = None) -> HttpResponse: deny_restricted_users(request) storage_base = getattr(django_settings, "SIGNATURE_STORAGE", None) @@ -1517,7 +1579,7 @@ def get(self, request, format=None): return self.response(file_path, filename) -class BugzillaTemplateListView(ListView): +class BugzillaTemplateListView(ListView[BugzillaTemplate]): model = BugzillaTemplate template_name = "bugzilla/list.html" paginate_by = 100 @@ -1536,12 +1598,14 @@ class BugzillaTemplateEditView(UpdateView): success_url = 
reverse_lazy("crashmanager:templates") pk_url_kwarg = "templateId" - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any): context = super().get_context_data(**kwargs) context["title"] = "Edit template" return context - def get_form_class(self): + def get_form_class( + self, + ) -> type[BugzillaTemplateBugForm] | type[BugzillaTemplateCommentForm]: if self.object.mode == BugzillaTemplateMode.Bug: return BugzillaTemplateBugForm else: @@ -1554,23 +1618,23 @@ class BugzillaTemplateBugCreateView(CreateView): form_class = BugzillaTemplateBugForm success_url = reverse_lazy("crashmanager:templates") - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any): context = super().get_context_data(**kwargs) context["title"] = "Create a bug template" return context - def form_valid(self, form): + def form_valid(self, form) -> HttpResponse: form.instance.mode = BugzillaTemplateMode.Bug return super().form_valid(form) -class BugzillaTemplateCommentCreateView(CreateView): +class BugzillaTemplateCommentCreateView(CreateView[BugzillaTemplate]): model = BugzillaTemplate template_name = "bugzilla/create_edit.html" form_class = BugzillaTemplateCommentForm success_url = reverse_lazy("crashmanager:templates") - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any): context = super().get_context_data(**kwargs) context["title"] = "Create a comment template" return context @@ -1580,13 +1644,13 @@ def form_valid(self, form): return super().form_valid(form) -class UserSettingsEditView(UpdateView): +class UserSettingsEditView(UpdateView[User]): model = User template_name = "usersettings.html" form_class = UserSettingsForm success_url = reverse_lazy("crashmanager:usersettings") - def get_form_kwargs(self, **kwargs): + def get_form_kwargs(self, **kwargs: Any): kwargs = super().get_form_kwargs(**kwargs) kwargs["user"] = self.get_queryset().get(user=self.request.user) return kwargs @@ -1594,7 +1658,7 @@ def get_form_kwargs(self, **kwargs): def get_object(self): return self.get_queryset().get(user=self.request.user) - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any): context = super().get_context_data(**kwargs) context["bugzilla_providers"] = BugProvider.objects.filter( classname="BugzillaProvider" diff --git a/server/ec2spotmanager/CloudProvider/CloudProvider.py b/server/ec2spotmanager/CloudProvider/CloudProvider.py index 92e59028c..f9b827bdd 100644 --- a/server/ec2spotmanager/CloudProvider/CloudProvider.py +++ b/server/ec2spotmanager/CloudProvider/CloudProvider.py @@ -11,11 +11,18 @@ @contact: truber@mozilla.com """ + +from __future__ import annotations + import functools import logging import ssl import traceback from abc import ABCMeta, abstractmethod +from decimal import Decimal +from typing import Any, Callable, TypeVar + +from ec2spotmanager.models import PoolConfiguration INSTANCE_STATE_CODE = { -1: "requested", @@ -26,34 +33,35 @@ 64: "stopping", 80: "stopped", } -INSTANCE_STATE = {val: key for key, val in INSTANCE_STATE_CODE.items()} +INSTANCE_STATE: dict[str, int] = {val: key for key, val in INSTANCE_STATE_CODE.items()} # List of currently supported providers. 
This and what is returned by get_name() must # match PROVIDERS = ["EC2Spot", "GCE"] +RetType = TypeVar("RetType") class CloudProviderError(Exception): - TYPE = "unclassified" + TYPE: str = "unclassified" - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message - def __str__(self): + def __str__(self) -> str: return f"{type(self).__name__}: {self.message} ({self.TYPE})" class CloudProviderTemporaryFailure(CloudProviderError): - TYPE = "temporary-failure" + TYPE: str = "temporary-failure" class CloudProviderInstanceCountError(CloudProviderError): - TYPE = "max-spot-instance-count-exceeded" + TYPE: str = "max-spot-instance-count-exceeded" -def wrap_provider_errors(wrapped): +def wrap_provider_errors(wrapped: Callable[..., RetType]) -> Callable[..., RetType]: @functools.wraps(wrapped) - def wrapper(*args, **kwds): + def wrapper(*args: Any, **kwds: Any) -> RetType: try: return wrapped(*args, **kwds) except (ssl.SSLError, OSError) as exc: @@ -77,61 +85,49 @@ class CloudProvider(metaclass=ABCMeta): """ @abstractmethod - def terminate_instances(self, instances_ids_by_region): + def terminate_instances(self, instances_ids_by_region: dict[str, int]) -> None: """ Take a list of running instances and stop them in the cloud provider. - @ptype instances_ids_by_region: dictionary @param instances_ids_by_region: keys are regions and values are instances. - - @rtype: none - @return: none """ return @abstractmethod - def cancel_requests(self, requested_instances_by_region): + def cancel_requests(self, requested_instances_by_region: dict[str, int]) -> None: """ Cancel requests that have not become running instances. - @ptype requested_instances_region: dictionary @param requested_instances_region: keys are regions and values are request ids. """ return @abstractmethod def start_instances( - self, config, region, zone, userdata, image, instance_type, count, tags - ): + self, + config: PoolConfiguration, + region: str, + zone: str, + userdata, + image: str, + instance_type: str, + count: int, + tags: dict[str, str], + ) -> dict[str, Any]: """ Start instances using specified configuration. - @ptype config: FlatObject @param config: flattened config. We use this for any cloud provider specific fields needed to create an instance - - @ptype region: string @param region: region where instances are to be started - - @ptype zone: string @param zone: zone the instances will be started in @ptype userdata: UserData object @param userdata: userdata script for instances - - @ptype image: string @param image: image reference used to start instances - - @ptype instance_type: string @param instance_type: type of instance - - @ptype count: int @param count: number of instances to start - - @ptype tags: dictionary @param tags: instance tags. - - @rtype: list @return: Request IDs given to us by the cloud provider. This can be the instance ID if the provider does not use different IDs for instances and requests. @@ -139,7 +135,9 @@ def start_instances( return @abstractmethod - def check_instances_requests(self, region, instances, tags): + def check_instances_requests( + self, region: str, instances: list[str], tags: dict[str, str] + ) -> tuple[dict[str, str], dict[str, str]]: """ Take a list of requested instances and determines the state of each instance. Since this is the first point we see an actual running instance @@ -152,182 +150,142 @@ def check_instances_requests(self, region, instances, tags): Failed requests will have an action and instance type. 
Currently, we support actions of 'blacklist' and disable_pool. - @ptype region: string @param region: the region the instances are in - - @ptype instances: list - @param isntances: instance request IDs - - @ptype tags: dictionary + @param instances: instance request IDs @param tags: instance tags. - - @rtype: tuple @return: tuple containing 2 dicts: successful request IDs and failed request IDs """ return @abstractmethod - def check_instances_state(self, pool_id, region): + def check_instances_state(self, pool_id: int, region: str) -> None: """ Takes a pool ID, searches the cloud provider for instances in that pool (using the tag) and returns a dictionary of instances with their state as value. - @ptype pool_id: int @param list of pool instances are located in. We search for instances using the poolID tag - - @ptype region: string @param region: region where instances are located - - @rtype: dictionary @return: running instances and their states. State must comply with INSTANCE_STATE defined in CloudProvider """ return @abstractmethod - def get_image(self, region, config): + def get_image(self, region: str, config: PoolConfiguration) -> str | None: """ Takes a configuration and returns a provider specific image name. - @ptype region: string @param region: region - - @ptype config: FlatObject @param config: flattened config - - @rtype: string @return: cloud provider ID for image """ return @staticmethod @abstractmethod - def get_cores_per_instance(): + def get_cores_per_instance() -> dict[str, int]: """ returns dictionary of instance types and their number of cores - @rtype: dictionary @return: instance types and how many cores per instance type """ return @staticmethod @abstractmethod - def get_allowed_regions(config): + def get_allowed_regions(config: PoolConfiguration) -> list[str]: """ Takes a configuration and returns cloud provider specific regions. - @ptype config: FlatObject @param config: pulling regions from config - - @rtype: list @return: regions pulled from config """ return @staticmethod @abstractmethod - def get_image_name(config): + def get_image_name(config: PoolConfiguration) -> str | None: """ Takes a configuration and returns cloud provider specific image name. - @ptype config: FlatObject @param config: pulling image name from config - - @rtype: string @return: cloud specific image name from config """ return @staticmethod @abstractmethod - def get_instance_types(config): + def get_instance_types(config: PoolConfiguration) -> str: """ Takes a configuration and returns a list of cloud provider specific instance_types. - @ptype config: FlatObject @param config: pulling instance types from config - - @rtype: list @return: list of cloud specific instance_types from config """ return @staticmethod @abstractmethod - def get_max_price(config): + def get_max_price(config: PoolConfiguration) -> Decimal: """ Takes a configuration and returns the cloud provider specific max_price. - @ptype config: FlatObject @param config: pulling max_price from config - - @rtype: float @return: cloud specific max_price """ return @staticmethod @abstractmethod - def get_tags(config): + def get_tags(config: PoolConfiguration) -> str: """ Takes a configuration and returns a dictionary of cloud provider specific tags. 
- @ptype config: FlatObject @param config: pulling tags field - - @rtype: dictionary @return: cloud specific tags field """ return @staticmethod @abstractmethod - def get_name(): + def get_name() -> str: """ used to return name of cloud provider - @rtype: string @return: string representation of the cloud provider """ return @staticmethod @abstractmethod - def config_supported(config): + def config_supported(config: PoolConfiguration) -> bool: """Compares the fields provided in the config with those required by the cloud provider. If any field is missing, return False. - @ptype config: FlatObject @param config: Flattened config - - @rtype: bool @return: True if all required cloud specific fields in config """ return @abstractmethod - def get_prices_per_region(self, region_name, instance_types): + def get_prices_per_region( + self, region_name: str, instance_types: list[str] | None + ) -> dict[str, dict[str, dict[str, float]]]: """ takes region and instance_types and returns a dictionary of prices prices are stored with keys like 'provider:price:{instance-type}' and values {region: {zone (if used): [price value, ...]}} - @ptype region_name: string @param region_name: region to grab prices - - @ptype instance_types: list @param instance_types: list of instance_types - - @rtype: dictionary @return: dictionary of prices as specified above. """ return @staticmethod - def get_instance(provider): + def get_instance(provider: str): """ This is a method that is used to instantiate the provider class. """ diff --git a/server/ec2spotmanager/CloudProvider/EC2SpotCloudProvider.py b/server/ec2spotmanager/CloudProvider/EC2SpotCloudProvider.py index 7acb5e740..dfdbe549a 100644 --- a/server/ec2spotmanager/CloudProvider/EC2SpotCloudProvider.py +++ b/server/ec2spotmanager/CloudProvider/EC2SpotCloudProvider.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import datetime import logging import re +from typing import Any, cast import boto3 import boto.ec2 @@ -23,15 +26,15 @@ class EC2SpotCloudProvider(CloudProvider): - def __init__(self): + def __init__(self) -> None: self.logger = logging.getLogger("ec2spotmanager") - self.cluster = None - self.connected_region = None + self.cluster: EC2Manager | None = None + self.connected_region: str | None = None - def _connect(self, region): + def _connect(self, region: str) -> EC2Manager: if self.connected_region != region: - self.cluster = EC2Manager( - None + self.cluster = cast( + EC2Manager, EC2Manager(None) ) # create a new Manager to invalidate cached image names, etc. 
self.cluster.connect( region=region, @@ -42,7 +45,7 @@ def _connect(self, region): return self.cluster @wrap_provider_errors - def terminate_instances(self, instances_ids_by_region): + def terminate_instances(self, instances_ids_by_region: dict[str, Any]) -> None: for region, instance_ids in instances_ids_by_region.items(): cluster = self._connect(region) self.logger.info( @@ -72,7 +75,7 @@ def terminate_instances(self, instances_ids_by_region): cluster.terminate(boto_instances) @wrap_provider_errors - def cancel_requests(self, requested_instances_by_region): + def cancel_requests(self, requested_instances_by_region: dict[str, Any]) -> None: for region, instance_ids in requested_instances_by_region.items(): cluster = self._connect(region) cluster.cancel_spot_requests(instance_ids) @@ -83,9 +86,17 @@ def cancel_requests(self, requested_instances_by_region): @wrap_provider_errors def start_instances( - self, config, region, zone, userdata, image, instance_type, count, _tags - ): - images = self._create_laniakea_images(config) + self, + config, + region: str, + zone: str, + userdata, + image: str, + instance_type: str, + count: int, + _tags: dict[str, str], + ) -> dict[str, Any]: + images: dict[str, dict[str, str]] = self._create_laniakea_images(config) self.logger.info( "Using instance type %s in region %s with availability zone %s.", @@ -157,8 +168,8 @@ def start_instances( @wrap_provider_errors def check_instances_requests(self, region, instances, tags): - successful_requests = {} - failed_requests = {} + successful_requests: dict[str, dict[str, str]] = {} + failed_requests: dict[str, dict[str, str]] = {} cluster = self._connect(region) try: @@ -252,9 +263,9 @@ def check_instances_requests(self, region, instances, tags): return (successful_requests, failed_requests) @wrap_provider_errors - def check_instances_state(self, pool_id, region): + def check_instances_state(self, pool_id: int | None, region: str): - instance_states = {} + instance_states: dict[str, Any] = {} cluster = self._connect(region) try: @@ -284,13 +295,13 @@ def check_instances_state(self, pool_id, region): return instance_states @wrap_provider_errors - def get_image(self, region, config): + def get_image(self, region: str, config) -> str: cluster = self._connect(region) - ami = cluster.resolve_image_name(config.ec2_image_name) + ami = cast(str, cluster.resolve_image_name(config.ec2_image_name)) return ami @staticmethod - def get_cores_per_instance(): + def get_cores_per_instance() -> dict[str, int]: return CORES_PER_INSTANCE @staticmethod @@ -314,11 +325,11 @@ def get_tags(config): return config.instance_tags @staticmethod - def get_name(): + def get_name() -> str: return "EC2Spot" @staticmethod - def config_supported(config): + def config_supported(config) -> bool: fields = [ "ec2_allowed_regions", "max_price", @@ -330,9 +341,11 @@ def config_supported(config): return all(config.get(key) for key in fields) @wrap_provider_errors - def get_prices_per_region(self, region_name, instance_types=None): + def get_prices_per_region( + self, region_name: str, instance_types: list[str] | None = None + ) -> dict[str, Any]: """Gets spot prices of the specified region and instance type""" - prices = {} # {instance-type: region: {az: [prices]}}} + prices: dict[str, Any] = {} # {instance-type: region: {az: [prices]}}} zone_blacklist = ["us-east-1a", "us-east-1f"] now = timezone.now() @@ -369,8 +382,8 @@ def get_prices_per_region(self, region_name, instance_types=None): return prices @staticmethod - def _create_laniakea_images(config): - images =
{"default": {}} + def _create_laniakea_images(config) -> dict[str, dict[str, str]]: + images: dict[str, dict[str, str]] = {"default": {}} # These are the configuration keys we want to put into the target configuration # without further preprocessing, except for the adjustment of the key name diff --git a/server/ec2spotmanager/CloudProvider/GCECloudProvider.py b/server/ec2spotmanager/CloudProvider/GCECloudProvider.py index 3306a4525..ff0589400 100644 --- a/server/ec2spotmanager/CloudProvider/GCECloudProvider.py +++ b/server/ec2spotmanager/CloudProvider/GCECloudProvider.py @@ -11,8 +11,13 @@ @contact: truber@mozilla.com """ + +from __future__ import annotations + import logging import time +from decimal import Decimal +from typing import Any import requests import yaml @@ -23,6 +28,7 @@ ) from ..common.gce import CORES_PER_INSTANCE, RAM_PER_INSTANCE +from ..models import PoolConfiguration from ..tasks import SPOTMGR_TAG from .CloudProvider import ( INSTANCE_STATE, @@ -42,7 +48,7 @@ class _LowercaseDict(dict): - def __init__(self, *args, **kwds): + def __init__(self, *args: Any, **kwds: Any) -> None: super().__init__() if len(args) > 1: raise TypeError(f"dict expected at most 1 arguments, got {len(args)}") @@ -57,13 +63,13 @@ def __init__(self, *args, **kwds): for (k, v) in kwds.items(): self[k] = v - def __getitem__(self, key): + def __getitem__(self, key: str): return super().__getitem__(key.lower()) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: str): return super().__setitem__(key.lower(), value.lower()) - def __delitem__(self, key): + def __delitem__(self, key: str): return super().__delitem__(key.lower()) def __contains__(self, item): @@ -82,11 +88,11 @@ class GCECloudProvider(CloudProvider): "UNKNOWN": INSTANCE_STATE["pending"], } - def __init__(self): + def __init__(self) -> None: self.logger = logging.getLogger("ec2spotmanager.gce") self.cluster = None - def _connect(self): + def _connect(self) -> ComputeEngineManager: if self.cluster is None: self.cluster = ComputeEngineManager( settings.GCE_CLIENT_EMAIL, @@ -96,6 +102,7 @@ def _connect(self): retries = [1, 5, 10, 30, None] for retry in retries: try: + assert self.cluster is not None self.cluster.connect(credential_file=settings.GCE_AUTH_CACHE) break except ComputeEngineManagerException as error: @@ -109,7 +116,9 @@ def _connect(self): return self.cluster @wrap_provider_errors - def terminate_instances(self, instances_ids_by_region): + def terminate_instances( + self, instances_ids_by_region: dict[str, list[int]] + ) -> None: for region, instance_ids in instances_ids_by_region.items(): assert ( region == "global" @@ -129,7 +138,9 @@ def terminate_instances(self, instances_ids_by_region): cluster.terminate_nowait(nodes) @wrap_provider_errors - def cancel_requests(self, requested_instances_by_region): + def cancel_requests( + self, requested_instances_by_region: dict[str, list[int]] + ) -> None: # no difference in how pending nodes are terminated in GCE self.logger.info( "Canceling %d requests in GCE", len(requested_instances_by_region) @@ -143,12 +154,20 @@ def _node_to_instance(self, node): ".".join(reversed(ip_addr.split("."))) + ".bc.googleusercontent.com" ) instance["instance_id"] = node.name - instance["status_code"] = self.NODE_STATE_MAP[node.extra["status"]] + instance["status_code"] = str(self.NODE_STATE_MAP[node.extra["status"]]) return instance @wrap_provider_errors def start_instances( - self, config, region, zone, _userdata, image, instance_type, count, tags + self, + config, + region: str, 
+ zone: str, + _userdata: str, + image: str, + instance_type: str, + count: int, + tags: dict[str, str], ): assert ( region == "global" @@ -205,9 +224,9 @@ def start_instances( conf = cluster.build_container_vm( yaml.safe_dump(container_spec), disk, zone=zone, preemptible=True ) - tags = _LowercaseDict(tags) - tags[SPOTMGR_TAG + "-Updatable"] = "1" - conf["ex_labels"] = tags + tags_ = _LowercaseDict(tags) + tags_[SPOTMGR_TAG + "-Updatable"] = "1" + conf["ex_labels"] = tags_ self.logger.info( "Creating %dx %s instances... (%d cores total)", count, @@ -218,7 +237,9 @@ def start_instances( return {node.name: self._node_to_instance(node) for node in nodes} @wrap_provider_errors - def check_instances_requests(self, region, instances, tags): + def check_instances_requests( + self, region: str, instances: list[str], tags: dict[str, str] + ) -> tuple[dict[str, str], dict[str, str]]: # this isn't a spot provider, and tags were already set at instance creation assert ( region == "global" @@ -250,7 +271,7 @@ def check_instances_requests(self, region, instances, tags): return (requests, {}) @wrap_provider_errors - def check_instances_state(self, pool_id, region): + def check_instances_state(self, pool_id: int, region: str): # TODO: if we could return a hostname, `check_instances_requests` would be # unnecessary for this provider assert ( @@ -288,42 +309,42 @@ def check_instances_state(self, pool_id, region): return instance_states - def get_image(self, region, config): + def get_image(self, region: str, config: PoolConfiguration) -> str | None: assert ( region == "global" ), f"Invalid region name for GCE: {region} (only 'global' supported)" return config.gce_image_name @staticmethod - def get_cores_per_instance(): + def get_cores_per_instance() -> dict[str, float]: return CORES_PER_INSTANCE @staticmethod - def get_allowed_regions(_config): + def get_allowed_regions(_config: PoolConfiguration) -> list[str]: return ["global"] @staticmethod - def get_image_name(config): + def get_image_name(config: PoolConfiguration) -> str | None: return config.gce_image_name @staticmethod - def get_instance_types(config): + def get_instance_types(config: PoolConfiguration) -> str: return config.gce_machine_types @staticmethod - def get_max_price(config): + def get_max_price(config: PoolConfiguration) -> Decimal: return config.max_price @staticmethod - def get_tags(config): + def get_tags(config: PoolConfiguration) -> str: return config.instance_tags @staticmethod - def get_name(): + def get_name() -> str: return "GCE" @staticmethod - def config_supported(config): + def config_supported(config) -> bool: fields = [ "gce_machine_types", "max_price", @@ -333,7 +354,9 @@ def config_supported(config): ] return all(config.get(key) for key in fields) - def get_prices_per_region(self, region_name, instance_types=None): + def get_prices_per_region( + self, region_name: str, instance_types: list[str] | None = None + ) -> dict[str, dict[str, dict[str, float]]]: # Pricing information is not perfect. The API Zones don't map to the data # provided by Cloud Billing API. We usually get one price per zone (eg. # us-east1), so we just assume that -a, -b, and -c have the same price. In some @@ -342,7 +365,7 @@ def get_prices_per_region(self, region_name, instance_types=None): # worst price and apply it to the whole zone to be conservative. 
assert region_name == "global" - def get_price(sku): + def get_price(sku) -> tuple[float, str]: assert len(sku["pricingInfo"]) == 1 expr = sku["pricingInfo"][0]["pricingExpression"] assert len(expr["tieredRates"]) == 1, expr["tieredRates"] @@ -440,7 +463,9 @@ def _get_skus_paginated(): # now we have all the data, and just have to calculate for our instance types # and return - result = {} # {instance-type: {region: {az: [prices]}}} + result: dict[ + str, dict[str, dict[str, float]] + ] = {} # {instance-type: {region: {az: [prices]}}} for instance_type, cores in CORES_PER_INSTANCE.items(): mem = RAM_PER_INSTANCE[instance_type] @@ -474,7 +499,7 @@ def _get_skus_paginated(): # since we can't distinguish between zones using the billing API (see TODO # above) return the pricing data for all zones within the GCE region - all_zones_result = {} + all_zones_result: dict[str, dict[str, dict[str, float]]] = {} for instance_type, all_regions in result.items(): for region, region_data in all_regions.items(): all_zones_result[instance_type] = {} diff --git a/server/ec2spotmanager/__init__.py b/server/ec2spotmanager/__init__.py index ba00c6d1d..2e82c2ba6 100644 --- a/server/ec2spotmanager/__init__.py +++ b/server/ec2spotmanager/__init__.py @@ -1 +1,3 @@ +from __future__ import annotations + from . import tasks # noqa diff --git a/server/ec2spotmanager/common/ec2.py b/server/ec2spotmanager/common/ec2.py index 40bc16fe0..dde9c489d 100644 --- a/server/ec2spotmanager/common/ec2.py +++ b/server/ec2spotmanager/common/ec2.py @@ -1,9 +1,16 @@ # data exported from: https://ec2instances.info/ at 2018-08-16 17:54:16 UTC # .. see disclaimers -import collections +from __future__ import annotations -InstanceType = collections.namedtuple("InstanceType", ("api_name", "vCPUs")) +from typing import NamedTuple + + +class InstanceType(NamedTuple): + """InstanceType NamedTuple type information.""" + + api_name: str + vCPUs: int INSTANCE_TYPES = ( diff --git a/server/ec2spotmanager/common/gce.py b/server/ec2spotmanager/common/gce.py index 37bd9ffab..407b5569a 100644 --- a/server/ec2spotmanager/common/gce.py +++ b/server/ec2spotmanager/common/gce.py @@ -1,5 +1,7 @@ # extracted from https://cloud.google.com/compute/docs/machine-types +from __future__ import annotations + import collections InstanceType = collections.namedtuple("InstanceType", ("api_name", "vCPUs", "RAM")) diff --git a/server/ec2spotmanager/common/prices.py b/server/ec2spotmanager/common/prices.py index f5e82eb55..ef8a8cfc6 100644 --- a/server/ec2spotmanager/common/prices.py +++ b/server/ec2spotmanager/common/prices.py @@ -12,8 +12,15 @@ @contact: choller@mozilla.com """ +from __future__ import annotations -def get_prices(regions, cloud_provider, instance_types=None, use_multiprocess=False): + +def get_prices( + regions: list[str], + cloud_provider, + instance_types: list[str] | None = None, + use_multiprocess: bool = False, +) -> dict[str, str]: if use_multiprocess: from multiprocessing import Pool, cpu_count @@ -45,7 +52,7 @@ def get_prices(regions, cloud_provider, instance_types=None, use_multiprocess=Fa return prices -def get_price_median(data): +def get_price_median(data: list[float]) -> float: sdata = sorted(data) n = len(sdata) if not n % 2: diff --git a/server/ec2spotmanager/cron.py b/server/ec2spotmanager/cron.py index f2efb0222..910af796d 100644 --- a/server/ec2spotmanager/cron.py +++ b/server/ec2spotmanager/cron.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import json import logging @@ -26,7 +28,7 @@ 
@app.task(ignore_result=True) -def update_stats(): +def update_stats() -> None: from .models import ( Instance, InstancePool, @@ -77,7 +79,7 @@ def update_stats(): # Now check if we need to aggregate some of the detail entries we have entries = PoolUptimeDetailedEntry.objects.filter(pool=pool).order_by("created") - n = entries.count() - (STATS_TOTAL_DETAILED * 60 * 60) / STATS_DELTA_SECS + n = int(entries.count() - (STATS_TOTAL_DETAILED * 60 * 60) / STATS_DELTA_SECS) if n > 0: # We need to aggregate some entries entriesAggr = entries[:n] @@ -132,7 +134,7 @@ def update_stats(): @app.task -def _release_lock(lock_key): +def _release_lock(lock_key: str) -> None: cache = redis.StrictRedis.from_url(settings.REDIS_URL) lock = RedisLock(cache, "ec2spotmanager:check_instance_pools", unique_id=lock_key) if not lock.release(): @@ -143,7 +145,7 @@ def _release_lock(lock_key): @app.task(ignore_result=True) -def check_instance_pools(): +def check_instance_pools() -> None: """EC2SpotManager daemon. - checks all instance pools @@ -281,7 +283,7 @@ def check_instance_pools(): @app.task(ignore_result=True) -def update_prices(): +def update_prices() -> None: """Periodically refresh spot price history and store it in redis to be consumed when spot instances are created. @@ -304,7 +306,7 @@ def update_prices(): if not regions: continue - prices = {} + prices: dict[str, dict[str, str]] = {} for region in regions: for instance_type, price_data in cloud_provider.get_prices_per_region( region diff --git a/server/ec2spotmanager/migrations/0001_squashed_0013_add_gce_fields.py b/server/ec2spotmanager/migrations/0001_squashed_0013_add_gce_fields.py index 1bf871a0d..d3707bcb7 100644 --- a/server/ec2spotmanager/migrations/0001_squashed_0013_add_gce_fields.py +++ b/server/ec2spotmanager/migrations/0001_squashed_0013_add_gce_fields.py @@ -30,7 +30,7 @@ class Migration(migrations.Migration): initial = True - dependencies = [] + dependencies: list[tuple[str, str]] = [] operations = [ migrations.CreateModel( diff --git a/server/ec2spotmanager/migrations/0002_auto_20210429_0908.py b/server/ec2spotmanager/migrations/0002_auto_20210429_0908.py index 7241a6e05..048a3d6b1 100644 --- a/server/ec2spotmanager/migrations/0002_auto_20210429_0908.py +++ b/server/ec2spotmanager/migrations/0002_auto_20210429_0908.py @@ -1,5 +1,7 @@ # Generated by Django 3.0.14 on 2021-04-29 09:08 +from __future__ import annotations + from django.conf import settings from django.db import migrations, models diff --git a/server/ec2spotmanager/models.py b/server/ec2spotmanager/models.py index a9b210e20..e5cd0e125 100644 --- a/server/ec2spotmanager/models.py +++ b/server/ec2spotmanager/models.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import json import os +from datetime import datetime +from decimal import Decimal +from typing import Any from django.conf import settings from django.core.files.base import ContentFile from django.core.files.storage import FileSystemStorage +from django.core.files.uploadedfile import UploadedFile from django.db import models from django.dispatch.dispatcher import receiver from django.utils import timezone -def get_storage_path(self, name): +def get_storage_path(self: models.Model, name: str) -> str: return os.path.join(f"poolconfig-{self.pk}-files", name) @@ -29,28 +35,34 @@ class FlatObject(dict): class OverwritingStorage(FileSystemStorage): - def get_available_name(self, name, max_length=None): + def get_available_name(self, name: str, max_length: int | None = None) -> str: if self.exists(name): 
os.remove(os.path.join(getattr(settings, "USERDATA_STORAGE", None), name)) return name class PoolConfiguration(models.Model): - parent = models.ForeignKey( + parent: PoolConfiguration | None = models.ForeignKey( "self", blank=True, null=True, on_delete=models.deletion.CASCADE ) name = models.CharField(max_length=255, blank=False) - size = models.IntegerField(default=1, blank=True, null=True) - cycle_interval = models.IntegerField(default=86400, blank=True, null=True) - max_price = models.DecimalField( + size: int | None = models.IntegerField(default=1, blank=True, null=True) + cycle_interval: int | None = models.IntegerField( + default=86400, blank=True, null=True + ) + max_price: Decimal | float | None = models.DecimalField( max_digits=12, decimal_places=6, blank=True, null=True ) instance_tags = models.CharField(max_length=1023, blank=True, null=True) - ec2_key_name = models.CharField(max_length=255, blank=True, null=True) + ec2_key_name: str | None = str( + models.CharField(max_length=255, blank=True, null=True) + ) ec2_security_groups = models.CharField(max_length=255, blank=True, null=True) ec2_instance_types = models.TextField(blank=True, null=True) - ec2_image_name = models.CharField(max_length=255, blank=True, null=True) - ec2_userdata_file = models.FileField( + ec2_image_name: str | None = str( + models.CharField(max_length=255, blank=True, null=True) + ) + ec2_userdata_file: UploadedFile | None = models.FileField( storage=OverwritingStorage( location=getattr(settings, "USERDATA_STORAGE", None) ), @@ -62,10 +74,14 @@ class PoolConfiguration(models.Model): ec2_allowed_regions = models.CharField(max_length=1023, blank=True, null=True) ec2_raw_config = models.TextField(blank=True, null=True) gce_machine_types = models.TextField(blank=True, null=True) - gce_image_name = models.CharField(max_length=255, blank=True, null=True) - gce_container_name = models.CharField(max_length=512, blank=True, null=True) + gce_image_name: str | None = str( + models.CharField(max_length=255, blank=True, null=True) + ) + gce_container_name: str | None = str( + models.CharField(max_length=512, blank=True, null=True) + ) gce_docker_privileged = models.BooleanField(default=False) - gce_disk_size = models.IntegerField(blank=True, null=True) + gce_disk_size: int | None = models.IntegerField(blank=True, null=True) gce_cmd = models.TextField(blank=True, null=True) gce_args = models.TextField(blank=True, null=True) gce_env = models.TextField(blank=True, null=True) @@ -75,31 +91,31 @@ class PoolConfiguration(models.Model): gce_env_include_macros = models.BooleanField(default=False) gce_raw_config = models.TextField(blank=True, null=True) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: # These variables can hold temporarily deserialized data - self.instance_tags_dict = None - self.instance_tags_override = None - self.ec2_raw_config_dict = None - self.ec2_raw_config_override = None - self.ec2_userdata_macros_dict = None - self.ec2_userdata_macros_override = None - self.ec2_userdata = None - self.ec2_security_groups_list = None - self.ec2_security_groups_override = None - self.ec2_allowed_regions_list = None - self.ec2_allowed_regions_override = None - self.ec2_instance_types_list = None - self.ec2_instance_types_override = None - self.gce_machine_types_list = None - self.gce_machine_types_override = None - self.gce_cmd_list = None - self.gce_cmd_override = None - self.gce_args_list = None - self.gce_args_override = None - self.gce_env_dict = None - 
self.gce_env_override = None - self.gce_raw_config_dict = None - self.gce_raw_config_override = None + self.instance_tags_dict: dict[str, str] | str | None = None + self.instance_tags_override: bool | None = None + self.ec2_raw_config_dict: dict[str, str] | str | None = None + self.ec2_raw_config_override: bool | None = None + self.ec2_userdata_macros_dict: dict[str, str] | str | None = None + self.ec2_userdata_macros_override: bool | None = None + self.ec2_userdata: bytes | str | None = None + self.ec2_security_groups_list: list[str] | str | None = None + self.ec2_security_groups_override: bool | None = None + self.ec2_allowed_regions_list: list[str] | str | None = None + self.ec2_allowed_regions_override: bool | None = None + self.ec2_instance_types_list: list[str] | str | None = None + self.ec2_instance_types_override: bool | None = None + self.gce_machine_types_list: list[str] | None = None + self.gce_machine_types_override: bool | None = None + self.gce_cmd_list: list[str] | None = None + self.gce_cmd_override: bool | None = None + self.gce_args_list: list[str] | None = None + self.gce_args_override: bool | None = None + self.gce_env_dict: dict[str, str] | None = None + self.gce_env_override: bool | None = None + self.gce_raw_config_dict: dict[str, str] | None = None + self.gce_raw_config_override: bool | None = None # This list is used to update the parent configuration with our own # values and to check for missing fields in our flat config. @@ -226,7 +242,7 @@ def flatten(self, cache=None): return flat_parent_config - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: # Reserialize data, then call regular save method for field in self.dict_config_fields: obj = getattr(self, field + "_dict") @@ -252,7 +268,7 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) - def deserializeFields(self): + def deserializeFields(self) -> None: for field in self.dict_config_fields: sobj = getattr(self, field) or "" setattr(self, field + "_override", sobj.startswith("!")) @@ -272,11 +288,12 @@ def deserializeFields(self): self.ec2_userdata = self.ec2_userdata_file.read() self.ec2_userdata_file.close() - def storeTestAndSave(self): + def storeTestAndSave(self) -> None: if self.ec2_userdata: # Save the file using save() to avoid problems when initially # creating the directory. We use os.path.split to keep the # original filename assigned when saving the file. 
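# Illustration (not part of this change): os.path.split keeps only the final
# path component, so the originally assigned filename is preserved, e.g.
#   os.path.split("poolconfig-3-files/default.sh")[-1]  # -> "default.sh"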
+ assert self.ec2_userdata_file is not None self.ec2_userdata_file.save( os.path.split(self.ec2_userdata_file.name)[-1], ContentFile(self.ec2_userdata), @@ -295,7 +312,7 @@ def _cache_parent(self, cache): return self.parent return cache.get(self.parent_id) - def isCyclic(self, cache=None): + def isCyclic(self, cache=None) -> bool: # cache is optionally a prefetched {config_id: config} dictionary used for # parent lookups if self._cache_parent(cache) is None: @@ -310,9 +327,9 @@ def isCyclic(self, cache=None): break tortoise = tortoise._cache_parent(cache) hare = hare._cache_parent(cache)._cache_parent(cache) - return tortoise == hare + return bool(tortoise == hare) - def getMissingParameters(self): + def getMissingParameters(self) -> list[str]: flat_config = self.flatten() ec2_missing_fields = [] gce_missing_fields = [] @@ -348,7 +365,9 @@ def getMissingParameters(self): @receiver(models.signals.post_delete, sender=PoolConfiguration) -def deletePoolConfigurationFiles(sender, instance, **kwargs): +def deletePoolConfigurationFiles( + sender: PoolConfiguration, instance: PoolConfiguration, **kwargs: Any +) -> None: if instance.ec2_userdata: filename = instance.file.path filedir = os.path.dirname(filename) @@ -362,7 +381,7 @@ def deletePoolConfigurationFiles(sender, instance, **kwargs): class InstancePool(models.Model): config = models.ForeignKey(PoolConfiguration, on_delete=models.deletion.CASCADE) isEnabled = models.BooleanField(default=False) - last_cycled = models.DateTimeField(blank=True, null=True) + last_cycled: datetime | None = models.DateTimeField(blank=True, null=True) class Instance(models.Model): @@ -372,7 +391,7 @@ class Instance(models.Model): ) hostname = models.CharField(max_length=255, blank=True, null=True) status_code = models.IntegerField() - status_data = models.TextField(blank=True, null=True) + status_data: str | None = models.TextField(blank=True, null=True) instance_id = models.CharField(max_length=255, blank=True, null=True) region = models.CharField(max_length=255) zone = models.CharField(max_length=255) diff --git a/server/ec2spotmanager/serializers.py b/server/ec2spotmanager/serializers.py index 1de6ed5ab..fc0c7b21b 100644 --- a/server/ec2spotmanager/serializers.py +++ b/server/ec2spotmanager/serializers.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import itertools +from typing import Any from django.http.response import Http404 # noqa from rest_framework import serializers -from ec2spotmanager.models import Instance +from ec2spotmanager.models import Instance, PoolConfiguration -class PoolConfigurationSerializer(serializers.BaseSerializer): +class PoolConfigurationSerializer(serializers.BaseSerializer[PoolConfiguration]): id = serializers.IntegerField(read_only=True) parent = serializers.IntegerField(min_value=0, allow_null=True) name = serializers.CharField(max_length=255) @@ -69,7 +72,7 @@ class PoolConfigurationSerializer(serializers.BaseSerializer): ) gce_raw_config_override = serializers.BooleanField(default=False) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._flatten = kwargs.pop("flatten", False) super().__init__(*args, **kwargs) @@ -104,14 +107,14 @@ def to_representation(self, obj): return result -class MachineStatusSerializer(serializers.ModelSerializer): +class MachineStatusSerializer(serializers.ModelSerializer[Instance]): status_data = serializers.CharField(max_length=4095) class Meta: model = Instance fields = ["status_data"] - def update(self, instance, attrs): + def update(self, 
instance: Instance, attrs: dict[str, str]) -> Instance: """ Update the status_data field of a given instance """ diff --git a/server/ec2spotmanager/tasks.py b/server/ec2spotmanager/tasks.py index 105340854..3e2e0d338 100644 --- a/server/ec2spotmanager/tasks.py +++ b/server/ec2spotmanager/tasks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import json import logging @@ -18,6 +20,7 @@ CloudProviderError, ) from .common.prices import get_price_median +from .models import InstancePool logger = logging.getLogger("ec2spotmanager") @@ -26,7 +29,7 @@ @app.task -def _terminate_instance_ids(provider, region, instance_ids): +def _terminate_instance_ids(provider: str, region: str, instance_ids: str) -> None: cloud_provider = CloudProvider.get_instance(provider) try: cloud_provider.terminate_instances({region: instance_ids}) @@ -37,7 +40,9 @@ def _terminate_instance_ids(provider, region, instance_ids): @app.task -def _terminate_instance_request_ids(provider, region, request_ids): +def _terminate_instance_request_ids( + provider: str, region: str, request_ids: str +) -> None: cloud_provider = CloudProvider.get_instance(provider) try: cloud_provider.cancel_requests({region: request_ids}) @@ -47,19 +52,21 @@ def _terminate_instance_request_ids(provider, region, request_ids): _update_provider_status(provider, "unclassified", str(msg)) -def _determine_best_location(config, count, cache=None): +def _determine_best_location( + config, count: int, cache=None +) -> tuple[str | None, str | None, str | None, str | None, dict[str, int]]: from .models import Instance, ProviderStatusEntry if cache is None: cache = redis.StrictRedis.from_url(settings.REDIS_URL) - best_provider = None - best_zone = None - best_region = None - best_type = None - best_median = None - best_instances = None - rejected_prices = {} + best_provider: str | None = None + best_zone: str | None = None + best_region: str | None = None + best_type: str | None = None + best_median: float | None = None + best_instances: int | None = None + rejected_prices: dict[str, int] = {} for provider in PROVIDERS: cloud_provider = CloudProvider.get_instance(provider) @@ -76,7 +83,7 @@ def _determine_best_location(config, count, cache=None): # Filter machine sizes that would put us over the number of cores required. If # all do, then choose the smallest. 
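# Worked example (illustrative only, numbers made up): if the pool still needs
# 16 cores and the config lists machine types with 8, 32 and 64 cores, only the
# 8-core type stays acceptable; if every listed type overshoots (e.g. only 32
# and 64 cores for a 16-core need), the smallest one (32 cores) is chosen.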
- smallest = [] + smallest: list[str] = [] smallest_size = None acceptable_types = [] for instance_type in cloud_provider.get_instance_types(config): @@ -110,8 +117,9 @@ def _determine_best_location(config, count, cache=None): # zone+type is blacklisted because a previous spot request timed-out if ( cache.get( - "%s:blacklist:%s:%s:%s" - % (cloud_provider.get_name(), region, zone, instance_type) + "{}:blacklist:{}:{}:{}".format( + cloud_provider.get_name(), region, zone, instance_type + ) ) is not None ): @@ -148,6 +156,7 @@ def _determine_best_location(config, count, cache=None): provider=provider, region=region, zone=zone ).count() ) + assert best_instances is not None if median == best_median and instances >= best_instances: continue best_provider = provider @@ -169,7 +178,7 @@ def _determine_best_location(config, count, cache=None): return (best_provider, best_region, best_zone, best_type, rejected_prices) -def _start_pool_instances(pool, config, count=1): +def _start_pool_instances(pool, config, count: int = 1) -> None: """Start an instance with the given configuration""" from .models import POOL_STATUS_ENTRY_TYPE, Instance, PoolStatusEntry @@ -200,11 +209,12 @@ def _start_pool_instances(pool, config, count=1): for zone in rejected_prices: msg += f"\n{zone} at {rejected_prices[zone]}" _update_pool_status(pool, "price-too-low", msg) - return + return None elif priceLowEntries: priceLowEntries.delete() + assert provider is not None cloud_provider = CloudProvider.get_instance(provider) image_name = cloud_provider.get_image_name(config) cores_per_instance = cloud_provider.get_cores_per_instance() @@ -273,6 +283,7 @@ def _start_pool_instances(pool, config, count=1): instance.instance_id = instance_name instance.hostname = requested_instance["hostname"] instance.region = region + assert zone is not None instance.zone = zone instance.status_code = requested_instance["status_code"] instance.pool = pool @@ -286,7 +297,7 @@ def _start_pool_instances(pool, config, count=1): _update_pool_status(pool, "unclassified", str(msg)) -def _update_provider_status(provider, type_, message): +def _update_provider_status(provider: str, type_: str, message: str) -> None: from .models import POOL_STATUS_ENTRY_TYPE, ProviderStatusEntry is_critical = type_ not in { @@ -319,7 +330,7 @@ def _update_provider_status(provider, type_, message): logger.warning("Ignoring provider error: already exists.") -def _update_pool_status(pool, type_, message): +def _update_pool_status(pool: InstancePool, type_: str, message: str) -> None: from .models import POOL_STATUS_ENTRY_TYPE, PoolStatusEntry is_critical = type_ not in { @@ -353,16 +364,11 @@ def _update_pool_status(pool, type_, message): @app.task -def update_requests(provider, region, pool_id): +def update_requests(provider: str, region: str, pool_id: int) -> None: """Update all requests in a given provider/region/pool. - @ptype provider: str @param provider: CloudProvider name - - @ptype region: str @param region: Region name within the given provider - - @ptype pool_id: int @param pool_id: InstancePool pk """ from .models import ( @@ -466,13 +472,10 @@ def update_requests(provider, region, pool_id): @app.task -def update_instances(provider, region): +def update_instances(provider: str, region: str) -> None: """Reconcile database instances with cloud provider for a given provider/region. 
- @ptype provider: str @param provider: CloudProvider name - - @ptype region: str @param region: Region name within the given provider """ from .models import Instance @@ -620,13 +623,10 @@ def update_instances(provider, region): @app.task -def cycle_and_terminate_disabled(provider, region): +def cycle_and_terminate_disabled(provider: str, region: str) -> None: """Kill off instances if pools need to be cycled or disabled. - @ptype provider: str @param provider: CloudProvider name - - @ptype region: str @param region: Region name within the given provider """ from .models import Instance, PoolStatusEntry, ProviderStatusEntry @@ -641,7 +641,7 @@ def cycle_and_terminate_disabled(provider, region): # check if the pool has any instances to be terminated requests_to_terminate = [] instances_to_terminate = [] - instances_by_pool = {} + instances_by_pool: dict[str, int] = {} pool_disable = {} # pool_id -> reason (or blank for enabled) for instance in Instance.objects.filter(provider=provider, region=region): if instance.pool_id not in pool_disable: @@ -704,11 +704,10 @@ def cycle_and_terminate_disabled(provider, region): @app.task -def check_and_resize_pool(pool_id): +def check_and_resize_pool(pool_id: int) -> list[int]: """Check pool size and either request more instances from cheapest provider/region, or terminate unneeded instances. - @ptype pool_id: int @param pool_id: InstancePool pk """ from .models import Instance, InstancePool, PoolStatusEntry @@ -748,7 +747,7 @@ def check_and_resize_pool(pool_id): instance_cores_missing = config.size running_instances = [] - instances = Instance.objects.filter(pool=pool) + instances = list(Instance.objects.filter(pool=pool)) for instance in instances: if instance.status_code in [ @@ -821,10 +820,9 @@ def check_and_resize_pool(pool_id): @app.task -def terminate_instances(pool_instances): +def terminate_instances(pool_instances: list[list[int]]) -> None: """Terminate a given list of instances. - @ptype pool_instances: list of lists of instance ids @param pool_instances: Takes the results from multiple calls to check_and_resize_pool(), and aggregates the results into one call to terminate instances/requests per provider/region. 
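# Sketch (illustrative only, local names are assumptions): the aggregation the
# docstring above describes amounts to grouping instance ids by
# (provider, region) so that a single bulk terminate/cancel call is issued per
# cloud region instead of one per pool.
#
#   from collections import defaultdict
#   batches: dict[tuple[str, str], list[str]] = defaultdict(list)
#   for inst in instances_to_terminate:      # hypothetical Instance rows
#       batches[(inst.provider, inst.region)].append(inst.instance_id)
#   # ...then one provider API call per (provider, region) batch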
diff --git a/server/ec2spotmanager/templatetags/datetags.py b/server/ec2spotmanager/templatetags/datetags.py index c316e2f2c..ca91840f5 100644 --- a/server/ec2spotmanager/templatetags/datetags.py +++ b/server/ec2spotmanager/templatetags/datetags.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from datetime import datetime + from django import template from django.utils import timezone @@ -5,7 +9,7 @@ @register.filter -def date_ago(d): +def date_ago(d: datetime) -> str: delta = timezone.now() - d days = delta.days diff --git a/server/ec2spotmanager/templatetags/recursetags.py b/server/ec2spotmanager/templatetags/recursetags.py index 039f95607..26731b17e 100644 --- a/server/ec2spotmanager/templatetags/recursetags.py +++ b/server/ec2spotmanager/templatetags/recursetags.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django import template from django.utils.safestring import mark_safe @@ -5,11 +7,15 @@ class RecurseConfigTree(template.Node): - def __init__(self, template_nodes, config_var): + def __init__( + self, template_nodes: template.NodeList, config_var: template.Variable + ) -> None: self.template_nodes = template_nodes self.config_var = config_var - def _render_node(self, context, node): + def _render_node( + self, context: template.context.Context, node: template.Node + ) -> str: context.push() context["node"] = node children = [self._render_node(context, x) for x in node.children] @@ -19,12 +25,14 @@ def _render_node(self, context, node): context.pop() return rendered - def render(self, context): + def render(self, context: template.context.Context) -> str: return self._render_node(context, self.config_var.resolve(context)) @register.tag -def recurseconfig(parser, token): +def recurseconfig( + parser: template.base.Parser, token: template.base.Token +) -> RecurseConfigTree: bits = token.contents.split() if len(bits) != 2: raise template.TemplateSyntaxError( diff --git a/server/ec2spotmanager/tests/__init__.py b/server/ec2spotmanager/tests/__init__.py index 5db94192f..5eff17fda 100644 --- a/server/ec2spotmanager/tests/__init__.py +++ b/server/ec2spotmanager/tests/__init__.py @@ -8,9 +8,16 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + +from datetime import datetime import logging +from typing import cast from django.core.files.base import ContentFile +from django.core.files.uploadedfile import UploadedFile +from django.http.response import HttpResponse from django.test import SimpleTestCase as DjangoTestCase from django.utils import timezone @@ -29,46 +36,49 @@ class UncatchableException(BaseException): exception handling.""" -def assert_contains(response, text): +def assert_contains(response: HttpResponse, text: str) -> None: """Assert that the response was successful, and contains the given text.""" class _(DjangoTestCase): - def runTest(self): + def runTest(self) -> None: pass _().assertContains(response, text) def create_config( - name, - parent=None, - size=None, - cycle_interval=None, - ec2_key_name=None, - ec2_security_groups=None, - ec2_instance_types=None, - ec2_image_name=None, - ec2_userdata_macros=None, - ec2_allowed_regions=None, - max_price=None, - instance_tags=None, - ec2_raw_config=None, - ec2_userdata=None, - gce_image_name=None, - gce_container_name=None, - gce_disk_size=None, -): - result = PoolConfiguration.objects.create( - name=name, - parent=parent, - size=size, - cycle_interval=cycle_interval, - ec2_key_name=ec2_key_name, - ec2_image_name=ec2_image_name, - max_price=max_price, - gce_image_name=gce_image_name, - gce_disk_size=gce_disk_size, - gce_container_name=gce_container_name, + name: str, + parent: PoolConfiguration | None = None, + size: int | None = None, + cycle_interval: int | None = None, + ec2_key_name: str | None = None, + ec2_security_groups: list[str] | str | None = None, + ec2_instance_types: list[str] | None = None, + ec2_image_name: str | None = None, + ec2_userdata_macros: dict[str, str] | None = None, + ec2_allowed_regions: list[str] | None = None, + max_price: float | str | None = None, + instance_tags: dict[str, str] | None = None, + ec2_raw_config: dict[str, str] | None = None, + ec2_userdata: UploadedFile | str | bytes | None = None, + gce_image_name: str | None = None, + gce_container_name: str | None = None, + gce_disk_size: int | None = None, +) -> PoolConfiguration: + result = cast( + PoolConfiguration, + PoolConfiguration.objects.create( + name=name, + parent=parent, + size=size, + cycle_interval=cycle_interval, + ec2_key_name=ec2_key_name, + ec2_image_name=ec2_image_name, + max_price=max_price, + gce_image_name=gce_image_name, + gce_disk_size=gce_disk_size, + gce_container_name=gce_container_name, + ), ) if ec2_security_groups is not None: result.ec2_security_groups_list = ec2_security_groups @@ -83,6 +93,7 @@ def create_config( if ec2_raw_config is not None: result.ec2_raw_config_dict = ec2_raw_config if ec2_userdata is not None: + assert result.ec2_userdata_file is not None if not result.ec2_userdata_file.name: result.ec2_userdata_file.save("default.sh", ContentFile("")) result.ec2_userdata = ec2_userdata @@ -91,45 +102,55 @@ def create_config( return result -def create_pool(config, enabled=False, last_cycled=None): - result = InstancePool.objects.create( - config=config, isEnabled=enabled, last_cycled=last_cycled +def create_pool( + config: PoolConfiguration, + enabled: bool = False, + last_cycled: datetime | None = None, +) -> InstancePool: + result = cast( + InstancePool, + InstancePool.objects.create( + config=config, isEnabled=enabled, last_cycled=last_cycled + ), ) LOG.debug("Created InstancePool pk=%d", result.pk) return result -def create_poolmsg(pool): - result = PoolStatusEntry.objects.create(pool=pool, type=0) 
+def create_poolmsg(pool: InstancePool) -> PoolStatusEntry: + result = cast(PoolStatusEntry, PoolStatusEntry.objects.create(pool=pool, type=0)) LOG.debug("Created PoolStatusEntry pk=%d", result.pk) return result def create_instance( - hostname, - pool=None, - status_code=0, - status_data=None, - ec2_instance_id=None, - ec2_region="", - ec2_zone="", - size=1, - created=None, - provider="EC2Spot", -): + hostname: str | None, + pool: InstancePool | None = None, + status_code: int = 0, + status_data: str | None = None, + ec2_instance_id: str | int | None = None, + ec2_region: str = "", + ec2_zone: str = "", + size: int = 1, + created: datetime | None = None, + provider: str = "EC2Spot", +) -> Instance: if created is None: created = timezone.now() - result = Instance.objects.create( - pool=pool, - hostname=hostname, - status_code=status_code, - status_data=status_data, - instance_id=ec2_instance_id, - region=ec2_region, - zone=ec2_zone, - size=size, - created=created, - provider=provider, + result = cast( + Instance, + Instance.objects.create( + pool=pool, + hostname=hostname, + status_code=status_code, + status_data=status_data, + instance_id=ec2_instance_id, + region=ec2_region, + zone=ec2_zone, + size=size, + created=created, + provider=provider, + ), ) LOG.debug("Created Instance pk=%d", result.pk) return result diff --git a/server/ec2spotmanager/tests/conftest.py b/server/ec2spotmanager/tests/conftest.py index 83ed76de1..ab17fc04c 100644 --- a/server/ec2spotmanager/tests/conftest.py +++ b/server/ec2spotmanager/tests/conftest.py @@ -9,6 +9,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import functools import sys from unittest.mock import Mock @@ -16,16 +19,21 @@ import pytest from django.contrib.auth.models import Permission, User from django.contrib.contenttypes.models import ContentType +from pytest_mock import MockerFixture from crashmanager.models import User as cmUser from ec2spotmanager.CloudProvider.CloudProvider import CloudProvider +from ec2spotmanager.models import InstancePool, PoolConfiguration from . 
import UncatchableException def _create_user( - username, email="test@mozilla.com", password="test", has_permission=True -): + username: str, + email: str = "test@mozilla.com", + password: str = "test", + has_permission: bool = True, +) -> User: user = User.objects.create_user(username, email, password) user.user_permissions.clear() if has_permission: @@ -40,7 +48,9 @@ def _create_user( @pytest.fixture -def ec2spotmanager_test(db): # pylint: disable=invalid-name,unused-argument +def ec2spotmanager_test( + db: None, +) -> None: # pylint: disable=invalid-name,unused-argument """Common testcase class for all ec2spotmanager unittests""" # Create one unrestricted and one restricted test user _create_user("test") @@ -48,11 +58,11 @@ def ec2spotmanager_test(db): # pylint: disable=invalid-name,unused-argument @pytest.fixture -def mock_provider(mocker): +def mock_provider(mocker: MockerFixture) -> Mock: prv_t = Mock(spec=CloudProvider) - def allowed_regions(cls, cfg): - result = [] + def allowed_regions(cls, cfg: PoolConfiguration) -> list[str]: + result: list[str] = [] if cls.provider == "prov1": result.extend(set(cfg.ec2_allowed_regions) & set("abcd")) if cls.provider == "prov2": @@ -60,7 +70,7 @@ def allowed_regions(cls, cfg): result.sort() return result - def get_instance(cls, provider): + def get_instance(cls, provider: str): cls.provider = provider return cls @@ -77,8 +87,8 @@ def get_instance(cls, provider): @pytest.fixture -def raise_on_status(mocker): - def _mock_pool_status(_pool, type_, message): +def raise_on_status(mocker: MockerFixture) -> None: + def _mock_pool_status(_pool: InstancePool, type_: str, message: str) -> None: if sys.exc_info() != (None, None, None): raise # pylint: disable=misplaced-bare-raise raise UncatchableException(f"{type_}: {message}") diff --git a/server/ec2spotmanager/tests/test_configs.py b/server/ec2spotmanager/tests/test_configs.py index d44256cf0..2262fd489 100644 --- a/server/ec2spotmanager/tests/test_configs.py +++ b/server/ec2spotmanager/tests/test_configs.py @@ -9,12 +9,18 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import decimal import json import logging +import typing import pytest import requests +from django.http.response import HttpResponse +from django.test.client import Client from django.urls import reverse from ec2spotmanager.models import PoolConfiguration @@ -35,16 +41,19 @@ ("ec2spotmanager:configdel", {"configid": 0}), ], ) -def test_configs_no_login(client, name, kwargs): +def test_configs_no_login(client: Client, name: str, kwargs: dict[str, object]) -> None: """Request without login hits the login redirect""" path = reverse(name, kwargs=kwargs) response = client.get(path) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_configs_view_no_configs(client): +def test_configs_view_no_configs(client: Client) -> None: """If no configs in db, an appropriate message is shown.""" client.login(username="test", password="test") response = client.get(reverse("ec2spotmanager:configs")) @@ -54,7 +63,7 @@ def test_configs_view_no_configs(client): assert len(configtree) == 0 # 0 configs -def test_configs_view_config(client): +def test_configs_view_config(client: Client) -> None: """Create config and see that it is shown.""" client.login(username="test", password="test") config = create_config("config #1") @@ -65,10 +74,10 @@ def test_configs_view_config(client): assert len(configtree) == 1 # 1 config assert set(configtree) == {config} # same config assert len(configtree[0].children) == 0 - assert_contains(response, "config #1") + assert_contains(typing.cast(HttpResponse, response), "config #1") -def test_configs_view_configs(client): +def test_configs_view_configs(client: Client) -> None: """Create configs and see that they are shown.""" client.login(username="test", password="test") configs = (create_config("config #1"), create_config("config #2")) @@ -80,11 +89,11 @@ def test_configs_view_configs(client): assert set(configtree) == set(configs) # same configs assert len(configtree[0].children) == 0 assert len(configtree[1].children) == 0 - assert_contains(response, "config #1") - assert_contains(response, "config #2") + assert_contains(typing.cast(HttpResponse, response), "config #1") + assert_contains(typing.cast(HttpResponse, response), "config #2") -def test_configs_view_config_tree(client): +def test_configs_view_config_tree(client: Client) -> None: """Create nested configs and see that they are shown.""" client.login(username="test", password="test") config1 = create_config("config #1") @@ -111,24 +120,24 @@ def test_configs_view_config_tree(client): raise Exception(f"unexpected configuration: {cfg.name}") assert seen1 assert seen3 - assert_contains(response, "config #1") - assert_contains(response, "config #2") - assert_contains(response, "config #3") + assert_contains(typing.cast(HttpResponse, response), "config #1") + assert_contains(typing.cast(HttpResponse, response), "config #2") + assert_contains(typing.cast(HttpResponse, response), "config #3") -def test_create_config_view_create_form(client): +def test_create_config_view_create_form(client: Client) -> None: """Config creation form should be shown""" client.login(username="test", password="test") response = client.get(reverse("ec2spotmanager:configcreate")) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Create Configuration") - 
assert_contains(response, 'name="name"') - assert_contains(response, 'name="size"') - assert_contains(response, 'name="cycle_interval"') + assert_contains(typing.cast(HttpResponse, response), "Create Configuration") + assert_contains(typing.cast(HttpResponse, response), 'name="name"') + assert_contains(typing.cast(HttpResponse, response), 'name="size"') + assert_contains(typing.cast(HttpResponse, response), 'name="cycle_interval"') -def test_create_config_view_create(client): +def test_create_config_view_create(client: Client) -> None: """Config created via form should be added to db""" client.login(username="test", password="test") response = client.post( @@ -180,15 +189,15 @@ def test_create_config_view_create(client): assert cfg.gce_disk_size == 12 assert json.loads(cfg.gce_env) == {"tag1": "value1", "tag2": "value2"} assert response.status_code == requests.codes["found"] - assert response.url == reverse( - "ec2spotmanager:configview", kwargs={"configid": cfg.pk} - ) + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("ec2spotmanager:configview", kwargs={"configid": cfg.pk}) assert json.loads(cfg.gce_cmd) == ["cat"] assert json.loads(cfg.gce_args) == ["foo", "bar"] assert json.loads(cfg.gce_raw_config) == {"tag3": "value3", "tag4": "value4"} -def test_create_config_view_clone(client): +def test_create_config_view_clone(client: Client) -> None: """Creation form should contain source data""" client.login(username="test", password="test") cfg = create_config( @@ -208,24 +217,24 @@ def test_create_config_view_clone(client): response = client.get(reverse("ec2spotmanager:configcreate"), {"clone": cfg.pk}) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Clone Configuration") - assert_contains(response, "config #1 (Cloned)") - assert_contains(response, "1234567") - assert_contains(response, "7654321") - assert_contains(response, "key #1") - assert_contains(response, "group #1") - assert_contains(response, "machine #1") - assert_contains(response, "ami #1") - assert_contains(response, "yup=123") - assert_contains(response, "nope=456") - assert_contains(response, "nowhere") - assert_contains(response, "0.01") - assert_contains(response, "bad=false") - assert_contains(response, "good=true") - assert_contains(response, "hello=world") - - -def test_view_config_view(client): + assert_contains(typing.cast(HttpResponse, response), "Clone Configuration") + assert_contains(typing.cast(HttpResponse, response), "config #1 (Cloned)") + assert_contains(typing.cast(HttpResponse, response), "1234567") + assert_contains(typing.cast(HttpResponse, response), "7654321") + assert_contains(typing.cast(HttpResponse, response), "key #1") + assert_contains(typing.cast(HttpResponse, response), "group #1") + assert_contains(typing.cast(HttpResponse, response), "machine #1") + assert_contains(typing.cast(HttpResponse, response), "ami #1") + assert_contains(typing.cast(HttpResponse, response), "yup=123") + assert_contains(typing.cast(HttpResponse, response), "nope=456") + assert_contains(typing.cast(HttpResponse, response), "nowhere") + assert_contains(typing.cast(HttpResponse, response), "0.01") + assert_contains(typing.cast(HttpResponse, response), "bad=false") + assert_contains(typing.cast(HttpResponse, response), "good=true") + assert_contains(typing.cast(HttpResponse, response), "hello=world") + + +def test_view_config_view(client: Client) -> None: """Create a config and view it""" cfg = create_config( name="config #1", @@ 
-247,23 +256,23 @@ def test_view_config_view(client): ) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "config #1") - assert_contains(response, "1234567") - assert_contains(response, "7654321") - assert_contains(response, "key #1") - assert_contains(response, "group #1") - assert_contains(response, "machine #1") - assert_contains(response, "ami #1") - assert_contains(response, "yup=123") - assert_contains(response, "nope=456") - assert_contains(response, "nowhere") - assert_contains(response, "0.01") - assert_contains(response, "bad=false") - assert_contains(response, "good=true") - assert_contains(response, "hello=world") - - -def test_edit_config_view(client): + assert_contains(typing.cast(HttpResponse, response), "config #1") + assert_contains(typing.cast(HttpResponse, response), "1234567") + assert_contains(typing.cast(HttpResponse, response), "7654321") + assert_contains(typing.cast(HttpResponse, response), "key #1") + assert_contains(typing.cast(HttpResponse, response), "group #1") + assert_contains(typing.cast(HttpResponse, response), "machine #1") + assert_contains(typing.cast(HttpResponse, response), "ami #1") + assert_contains(typing.cast(HttpResponse, response), "yup=123") + assert_contains(typing.cast(HttpResponse, response), "nope=456") + assert_contains(typing.cast(HttpResponse, response), "nowhere") + assert_contains(typing.cast(HttpResponse, response), "0.01") + assert_contains(typing.cast(HttpResponse, response), "bad=false") + assert_contains(typing.cast(HttpResponse, response), "good=true") + assert_contains(typing.cast(HttpResponse, response), "hello=world") + + +def test_edit_config_view(client: Client) -> None: """Edit an existing config""" cfg = create_config( name="config #1", @@ -285,24 +294,24 @@ def test_edit_config_view(client): ) LOG.debug(response) assert response.status_code == requests.codes["ok"] - assert_contains(response, "Edit Configuration") - assert_contains(response, "config #1") - assert_contains(response, "1234567") - assert_contains(response, "7654321") - assert_contains(response, "key #1") - assert_contains(response, "group #1") - assert_contains(response, "machine #1") - assert_contains(response, "ami #1") - assert_contains(response, "yup=123") - assert_contains(response, "nope=456") - assert_contains(response, "nowhere") - assert_contains(response, "0.01") - assert_contains(response, "bad=false") - assert_contains(response, "good=true") - assert_contains(response, "hello=world") - - -def test_del_config_view_delete(client): + assert_contains(typing.cast(HttpResponse, response), "Edit Configuration") + assert_contains(typing.cast(HttpResponse, response), "config #1") + assert_contains(typing.cast(HttpResponse, response), "1234567") + assert_contains(typing.cast(HttpResponse, response), "7654321") + assert_contains(typing.cast(HttpResponse, response), "key #1") + assert_contains(typing.cast(HttpResponse, response), "group #1") + assert_contains(typing.cast(HttpResponse, response), "machine #1") + assert_contains(typing.cast(HttpResponse, response), "ami #1") + assert_contains(typing.cast(HttpResponse, response), "yup=123") + assert_contains(typing.cast(HttpResponse, response), "nope=456") + assert_contains(typing.cast(HttpResponse, response), "nowhere") + assert_contains(typing.cast(HttpResponse, response), "0.01") + assert_contains(typing.cast(HttpResponse, response), "bad=false") + assert_contains(typing.cast(HttpResponse, response), "good=true") + assert_contains(typing.cast(HttpResponse, 
response), "hello=world") + + +def test_del_config_view_delete(client: Client) -> None: """Delete an existing config""" cfg = create_config(name="config #1") client.login(username="test", password="test") @@ -311,11 +320,13 @@ def test_del_config_view_delete(client): ) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == reverse("ec2spotmanager:configs") + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("ec2spotmanager:configs") assert PoolConfiguration.objects.count() == 0 -def test_del_config_view_simple_get(client): +def test_del_config_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = create_config(name="config #1") client.login(username="test", password="test") diff --git a/server/ec2spotmanager/tests/test_configs_rest.py b/server/ec2spotmanager/tests/test_configs_rest.py index fb6b136ee..e66e4ecd8 100644 --- a/server/ec2spotmanager/tests/test_configs_rest.py +++ b/server/ec2spotmanager/tests/test_configs_rest.py @@ -9,12 +9,16 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import json import logging import pytest import requests from django.contrib.auth.models import User +from rest_framework.test import APIClient from . import create_config @@ -22,7 +26,7 @@ pytestmark = pytest.mark.usefixtures("ec2spotmanager_test") -def test_rest_pool_configs_no_auth(api_client): +def test_rest_pool_configs_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/ec2spotmanager/rest/configurations/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -32,7 +36,7 @@ def test_rest_pool_configs_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_pool_configs_no_perm(api_client): +def test_rest_pool_configs_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -44,7 +48,7 @@ def test_rest_pool_configs_no_perm(api_client): assert api_client.delete(url).status_code == requests.codes["forbidden"] -def test_rest_pool_configs_auth(api_client): +def test_rest_pool_configs_auth(api_client: APIClient) -> None: """test that authenticated requests work""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -52,7 +56,7 @@ def test_rest_pool_configs_auth(api_client): assert resp.status_code == requests.codes["ok"] -def test_rest_pool_configs_patch(api_client): +def test_rest_pool_configs_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -60,7 +64,7 @@ def test_rest_pool_configs_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_configs_put(api_client): +def test_rest_pool_configs_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -68,7 +72,7 @@ def test_rest_pool_configs_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_configs_post(api_client): +def test_rest_pool_configs_post(api_client: APIClient) -> None: """post should not be allowed""" user = User.objects.get(username="test") 
api_client.force_authenticate(user=user) @@ -76,7 +80,7 @@ def test_rest_pool_configs_post(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_configs_delete(api_client): +def test_rest_pool_configs_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -84,7 +88,7 @@ def test_rest_pool_configs_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_configs_list_no_configs(api_client): +def test_rest_pool_configs_list_no_configs(api_client: APIClient) -> None: """test empty response to config list""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -99,7 +103,7 @@ def test_rest_pool_configs_list_no_configs(api_client): assert len(resp) == 0 -def test_rest_pool_configs_list_configs(api_client): +def test_rest_pool_configs_list_configs(api_client: APIClient) -> None: """test that configs can be listed""" cfg1 = create_config( name="config #1", @@ -185,7 +189,7 @@ def test_rest_pool_configs_list_configs(api_client): assert not resp[0]["ec2_userdata_macros_override"] -def test_rest_pool_config_no_auth(api_client): +def test_rest_pool_config_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/ec2spotmanager/rest/configurations/1/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -195,7 +199,7 @@ def test_rest_pool_config_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_pool_config_no_perm(api_client): +def test_rest_pool_config_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -207,7 +211,7 @@ def test_rest_pool_config_no_perm(api_client): assert api_client.delete(url).status_code == requests.codes["forbidden"] -def test_rest_pool_config_auth(api_client): +def test_rest_pool_config_auth(api_client: APIClient) -> None: """test that authenticated requests work""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -215,7 +219,7 @@ def test_rest_pool_config_auth(api_client): assert resp.status_code == requests.codes["ok"] -def test_rest_pool_config_delete(api_client): +def test_rest_pool_config_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -223,7 +227,7 @@ def test_rest_pool_config_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_config_patch(api_client): +def test_rest_pool_config_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -231,7 +235,7 @@ def test_rest_pool_config_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_config_put(api_client): +def test_rest_pool_config_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -239,7 +243,7 @@ def test_rest_pool_config_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_config_post(api_client): +def test_rest_pool_config_post(api_client: APIClient) -> None: 
"""post should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -247,7 +251,7 @@ def test_rest_pool_config_post(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_config_get_0(api_client): +def test_rest_pool_config_get_0(api_client: APIClient) -> None: """test that non-existent PoolConfiguration is error""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -255,7 +259,7 @@ def test_rest_pool_config_get_0(api_client): assert resp.status_code == requests.codes["not_found"] -def test_rest_pool_config_get_1(api_client): +def test_rest_pool_config_get_1(api_client: APIClient) -> None: """test that individual PoolConfiguration can be fetched""" cfg1 = create_config( name="config #1", @@ -335,7 +339,7 @@ def test_rest_pool_config_get_1(api_client): assert not resp["ec2_userdata_macros_override"] -def test_rest_pool_config_get_sub(api_client): +def test_rest_pool_config_get_sub(api_client: APIClient) -> None: """test that inherited Signature can be fetched unflattened""" cfg1 = create_config( name="config #1", @@ -416,7 +420,7 @@ def test_rest_pool_config_get_sub(api_client): assert not resp["ec2_userdata_macros_override"] -def test_rest_pool_config_get_sub_flat(api_client): +def test_rest_pool_config_get_sub_flat(api_client: APIClient) -> None: """test that inherited Signature can be fetched flattened""" cfg1 = create_config( name="config #1", diff --git a/server/ec2spotmanager/tests/test_ec2spotmanager.py b/server/ec2spotmanager/tests/test_ec2spotmanager.py index 48f61b061..2a470a692 100644 --- a/server/ec2spotmanager/tests/test_ec2spotmanager.py +++ b/server/ec2spotmanager/tests/test_ec2spotmanager.py @@ -9,30 +9,33 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse -LOG = logging.getLogger( - "fm.ec2spotmanager.tests.ec2spotmanager" -) # pylint: disable=invalid-name -pytestmark = pytest.mark.usefixtures( - "ec2spotmanager_test" -) # pylint: disable=invalid-name +LOG = logging.getLogger("fm.ec2spotmanager.tests.ec2spotmanager") +pytestmark = pytest.mark.usefixtures("ec2spotmanager_test") -def test_ec2spotmanager_index(client): +def test_ec2spotmanager_index(client: Client) -> None: """Request of root url redirects to pools view""" client.login(username="test", password="test") response = client.get(reverse("ec2spotmanager:index")) LOG.debug(response) assert response.status_code == requests.codes["found"] - assert response.url == reverse("ec2spotmanager:pools") + assert typing.cast( + typing.Union[str, None], getattr(response, "url", None) + ) == reverse("ec2spotmanager:pools") -def test_ec2spotmanager_logout(client): +def test_ec2spotmanager_logout(client: Client) -> None: """Logout url actually logs us out""" client.login(username="test", password="test") index = reverse("ec2spotmanager:pools") @@ -41,10 +44,13 @@ def test_ec2spotmanager_logout(client): LOG.debug(response) response = client.get(index) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + index + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + index + ) -def test_ec2spotmanager_noperm(client): +def test_ec2spotmanager_noperm(client: Client) -> None: """Request without permission results in 403""" client.login(username="test-noperm", password="test") resp = client.get(reverse("ec2spotmanager:index")) diff --git a/server/ec2spotmanager/tests/test_pools.py b/server/ec2spotmanager/tests/test_pools.py index cc6873e5c..80deb4fe3 100644 --- a/server/ec2spotmanager/tests/test_pools.py +++ b/server/ec2spotmanager/tests/test_pools.py @@ -9,10 +9,16 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.http.response import HttpResponse +from django.test.client import Client from django.urls import reverse from . 
import assert_contains, create_config, create_pool, create_poolmsg @@ -38,15 +44,18 @@ ("ec2spotmanager:poolmsgdel", {"msgid": 0}), ], ) -def test_pools_no_login(client, name, kwargs): +def test_pools_no_login(client: Client, name: str, kwargs: dict[str, object]) -> None: """Request without login hits the login redirect""" path = reverse(name, kwargs=kwargs) resp = client.get(path) assert resp.status_code == requests.codes["found"] - assert resp.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(resp, "url", None)) + == "/login/?next=" + path + ) -def test_pools_view_no_pools(client): +def test_pools_view_no_pools(client: Client) -> None: """If no pools in db, an appropriate message is shown.""" client.login(username="test", password="test") response = client.get(reverse("ec2spotmanager:pools")) @@ -54,10 +63,10 @@ def test_pools_view_no_pools(client): assert response.status_code == requests.codes["ok"] poollist = response.context["poollist"] assert len(poollist) == 0 # 0 pools - assert_contains(response, POOLS_ENTRIES_FMT % 0) + assert_contains(typing.cast(HttpResponse, response), POOLS_ENTRIES_FMT % 0) -def test_pools_view_pool(client): +def test_pools_view_pool(client: Client) -> None: """Create pool and see that it is shown.""" config = create_config(name="config #1") pool = create_pool(config=config) @@ -67,11 +76,11 @@ def test_pools_view_pool(client): assert response.status_code == requests.codes["ok"] poollist = response.context["poollist"] assert len(poollist) == 1 # 1 pools - assert_contains(response, POOLS_ENTRIES_FMT % 1) + assert_contains(typing.cast(HttpResponse, response), POOLS_ENTRIES_FMT % 1) assert set(poollist) == {pool} -def test_pools_view_pools(client): +def test_pools_view_pools(client: Client) -> None: """Create pool and see that it is shown.""" configs = (create_config(name="config #1"), create_config(name="config #2")) pools = [create_pool(config=cfg) for cfg in configs] @@ -81,11 +90,11 @@ def test_pools_view_pools(client): assert response.status_code == requests.codes["ok"] poollist = response.context["poollist"] assert len(poollist) == 2 # 2 pools - assert_contains(response, POOLS_ENTRIES_FMT % 2) + assert_contains(typing.cast(HttpResponse, response), POOLS_ENTRIES_FMT % 2) assert set(poollist) == set(pools) -def test_create_pool_view_simple_get(client): +def test_create_pool_view_simple_get(client: Client) -> None: """No errors are thrown in template""" client.login(username="test", password="test") response = client.get(reverse("ec2spotmanager:poolcreate")) @@ -93,7 +102,7 @@ def test_create_pool_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_view_pool_view_simple_get(client): +def test_view_pool_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = create_config(name="config #1") pool = create_pool(config=cfg) @@ -105,7 +114,7 @@ def test_view_pool_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_pool_prices_view_simple_get(client): +def test_pool_prices_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = create_config(name="config #1", ec2_instance_types=["c4.2xlarge"]) pool = create_pool(config=cfg) @@ -117,7 +126,7 @@ def test_pool_prices_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_delete_pool_view_simple_get(client): +def test_delete_pool_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = 
create_config(name="config #1") pool = create_pool(config=cfg) @@ -127,7 +136,7 @@ def test_delete_pool_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_enable_pool_view_simple_get(client): +def test_enable_pool_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = create_config(name="config #1") pool = create_pool(config=cfg) @@ -139,7 +148,7 @@ def test_enable_pool_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_disable_pool_view_simple_get(client): +def test_disable_pool_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = create_config(name="config #1") pool = create_pool(config=cfg) @@ -151,7 +160,7 @@ def test_disable_pool_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_cycle_pool_view_simple_get(client): +def test_cycle_pool_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = create_config(name="config #1") pool = create_pool(config=cfg) @@ -163,7 +172,7 @@ def test_cycle_pool_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_delete_pool_message_view_simple_get(client): +def test_delete_pool_message_view_simple_get(client: Client) -> None: """No errors are thrown in template""" cfg = create_config(name="config #1") pool = create_pool(config=cfg) diff --git a/server/ec2spotmanager/tests/test_pools_rest.py b/server/ec2spotmanager/tests/test_pools_rest.py index f109524c4..5f428fedf 100644 --- a/server/ec2spotmanager/tests/test_pools_rest.py +++ b/server/ec2spotmanager/tests/test_pools_rest.py @@ -9,6 +9,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import json import logging @@ -17,6 +20,7 @@ from django.contrib.auth.models import User from django.urls import reverse from django.utils import timezone +from rest_framework.test import APIClient from ec2spotmanager.models import InstancePool @@ -26,7 +30,7 @@ pytestmark = pytest.mark.usefixtures("ec2spotmanager_test") -def test_rest_pool_cycle_no_auth(api_client): +def test_rest_pool_cycle_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/ec2spotmanager/rest/pool/1/cycle/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -36,7 +40,7 @@ def test_rest_pool_cycle_no_auth(api_client): assert api_client.delete(url, {}).status_code == requests.codes["unauthorized"] -def test_rest_pool_cycle_no_perm(api_client): +def test_rest_pool_cycle_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -48,7 +52,7 @@ def test_rest_pool_cycle_no_perm(api_client): assert api_client.delete(url, {}).status_code == requests.codes["forbidden"] -def test_rest_pool_cycle_patch(api_client): +def test_rest_pool_cycle_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -57,7 +61,7 @@ def test_rest_pool_cycle_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_cycle_post(api_client): +def test_rest_pool_cycle_post(api_client: APIClient) -> None: """post should reset last_cycled""" config = create_config("testconfig") pool = create_pool(config, last_cycled=timezone.now()) @@ -76,7 +80,7 @@ def test_rest_pool_cycle_post(api_client): assert pool.last_cycled is None -def test_rest_pool_cycle_put(api_client): +def test_rest_pool_cycle_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -85,7 +89,7 @@ def test_rest_pool_cycle_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_cycle_delete(api_client): +def test_rest_pool_cycle_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -94,7 +98,7 @@ def test_rest_pool_cycle_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_cycle_get(api_client): +def test_rest_pool_cycle_get(api_client: APIClient) -> None: """get should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -103,7 +107,7 @@ def test_rest_pool_cycle_get(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_enable_no_auth(api_client): +def test_rest_pool_enable_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/ec2spotmanager/rest/pool/1/enable/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -113,7 +117,7 @@ def test_rest_pool_enable_no_auth(api_client): assert api_client.delete(url, {}).status_code == requests.codes["unauthorized"] -def test_rest_pool_enable_no_perm(api_client): +def test_rest_pool_enable_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") 
api_client.force_authenticate(user=user) @@ -125,7 +129,7 @@ def test_rest_pool_enable_no_perm(api_client): assert api_client.delete(url, {}).status_code == requests.codes["forbidden"] -def test_rest_pool_enable_patch(api_client): +def test_rest_pool_enable_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -134,7 +138,7 @@ def test_rest_pool_enable_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_enable_post(api_client): +def test_rest_pool_enable_post(api_client: APIClient) -> None: """post should flip isEnabled""" config = create_config("testconfig") pool = create_pool(config) @@ -152,7 +156,7 @@ def test_rest_pool_enable_post(api_client): assert pool.isEnabled -def test_rest_pool_enable_put(api_client): +def test_rest_pool_enable_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -161,7 +165,7 @@ def test_rest_pool_enable_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_enable_delete(api_client): +def test_rest_pool_enable_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -170,7 +174,7 @@ def test_rest_pool_enable_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_enable_get(api_client): +def test_rest_pool_enable_get(api_client: APIClient) -> None: """get should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -179,7 +183,7 @@ def test_rest_pool_enable_get(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_disable_no_auth(api_client): +def test_rest_pool_disable_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/ec2spotmanager/rest/pool/1/disable/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -189,7 +193,7 @@ def test_rest_pool_disable_no_auth(api_client): assert api_client.delete(url, {}).status_code == requests.codes["unauthorized"] -def test_rest_pool_disable_no_perm(api_client): +def test_rest_pool_disable_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -201,7 +205,7 @@ def test_rest_pool_disable_no_perm(api_client): assert api_client.delete(url, {}).status_code == requests.codes["forbidden"] -def test_rest_pool_disable_patch(api_client): +def test_rest_pool_disable_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -210,7 +214,7 @@ def test_rest_pool_disable_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_disable_post(api_client): +def test_rest_pool_disable_post(api_client: APIClient) -> None: """post should flip isEnabled""" config = create_config("testconfig") pool = create_pool(config, enabled=True) @@ -228,7 +232,7 @@ def test_rest_pool_disable_post(api_client): assert not pool.isEnabled -def test_rest_pool_disable_put(api_client): +def test_rest_pool_disable_put(api_client: APIClient) -> None: """put should not be allowed""" user = 
User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -237,7 +241,7 @@ def test_rest_pool_disable_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_disable_delete(api_client): +def test_rest_pool_disable_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -246,7 +250,7 @@ def test_rest_pool_disable_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_pool_disable_get(api_client): +def test_rest_pool_disable_get(api_client: APIClient) -> None: """get should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -258,7 +262,7 @@ def test_rest_pool_disable_get(api_client): @pytest.mark.xfail class TestRestPoolChartDetailed: @staticmethod - def test_rest_pool_chart_detailed_no_auth(api_client): + def test_rest_pool_chart_detailed_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = reverse("ec2spotmanager:line_chart_json_detailed", kwargs={"poolid": 1}) assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -268,7 +272,7 @@ def test_rest_pool_chart_detailed_no_auth(api_client): assert api_client.delete(url, {}).status_code == requests.codes["unauthorized"] @staticmethod - def test_rest_pool_chart_detailed_no_perm(api_client): + def test_rest_pool_chart_detailed_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -280,7 +284,7 @@ def test_rest_pool_chart_detailed_no_perm(api_client): assert api_client.delete(url, {}).status_code == requests.codes["forbidden"] @staticmethod - def test_rest_pool_chart_detailed_patch(api_client): + def test_rest_pool_chart_detailed_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -291,7 +295,7 @@ def test_rest_pool_chart_detailed_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_detailed_post(api_client): + def test_rest_pool_chart_detailed_post(api_client: APIClient) -> None: """post should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -302,7 +306,7 @@ def test_rest_pool_chart_detailed_post(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_detailed_put(api_client): + def test_rest_pool_chart_detailed_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -313,7 +317,7 @@ def test_rest_pool_chart_detailed_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_detailed_delete(api_client): + def test_rest_pool_chart_detailed_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -324,7 +328,7 @@ def test_rest_pool_chart_detailed_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_detailed_get(api_client): + def test_rest_pool_chart_detailed_get(api_client: APIClient) -> 
None: """get should not be allowed""" pool = create_pool(create_config("testconfig", size=1)) user = User.objects.get(username="test") @@ -343,7 +347,7 @@ def test_rest_pool_chart_detailed_get(api_client): @pytest.mark.xfail class TestRestPoolChartAccumulated: @staticmethod - def test_rest_pool_chart_accumulated_no_auth(api_client): + def test_rest_pool_chart_accumulated_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = reverse( "ec2spotmanager:line_chart_json_accumulated", kwargs={"poolid": 1} @@ -355,7 +359,7 @@ def test_rest_pool_chart_accumulated_no_auth(api_client): assert api_client.delete(url, {}).status_code == requests.codes["unauthorized"] @staticmethod - def test_rest_pool_chart_accumulated_no_perm(api_client): + def test_rest_pool_chart_accumulated_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -369,7 +373,7 @@ def test_rest_pool_chart_accumulated_no_perm(api_client): assert api_client.delete(url, {}).status_code == requests.codes["forbidden"] @staticmethod - def test_rest_pool_chart_accumulated_patch(api_client): + def test_rest_pool_chart_accumulated_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -380,7 +384,7 @@ def test_rest_pool_chart_accumulated_patch(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_accumulated_post(api_client): + def test_rest_pool_chart_accumulated_post(api_client: APIClient) -> None: """post should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -391,7 +395,7 @@ def test_rest_pool_chart_accumulated_post(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_accumulated_put(api_client): + def test_rest_pool_chart_accumulated_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -402,7 +406,7 @@ def test_rest_pool_chart_accumulated_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_accumulated_delete(api_client): + def test_rest_pool_chart_accumulated_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -413,7 +417,7 @@ def test_rest_pool_chart_accumulated_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] @staticmethod - def test_rest_pool_chart_accumulated_get(api_client): + def test_rest_pool_chart_accumulated_get(api_client: APIClient) -> None: """get should be allowed""" pool = create_pool(create_config("testconfig")) user = User.objects.get(username="test") diff --git a/server/ec2spotmanager/tests/test_status_rest.py b/server/ec2spotmanager/tests/test_status_rest.py index cb72e9df6..076848b7e 100644 --- a/server/ec2spotmanager/tests/test_status_rest.py +++ b/server/ec2spotmanager/tests/test_status_rest.py @@ -9,12 +9,16 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import json import logging import pytest import requests from django.contrib.auth.models import User +from rest_framework.test import APIClient from ec2spotmanager.models import Instance @@ -24,7 +28,7 @@ pytestmark = pytest.mark.usefixtures("ec2spotmanager_test") -def test_rest_status_no_auth(api_client): +def test_rest_status_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/ec2spotmanager/rest/report/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -34,7 +38,7 @@ def test_rest_status_no_auth(api_client): assert api_client.delete(url, {}).status_code == requests.codes["unauthorized"] -def test_rest_status_no_perm(api_client): +def test_rest_status_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -46,7 +50,7 @@ def test_rest_status_no_perm(api_client): assert api_client.delete(url, {}).status_code == requests.codes["forbidden"] -def test_rest_status_get(api_client): +def test_rest_status_get(api_client: APIClient) -> None: """get always returns an empty object""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -56,16 +60,16 @@ def test_rest_status_get(api_client): assert resp == {} -def test_rest_status_report(api_client): +def test_rest_status_report(api_client: APIClient) -> None: """post should update the status field on the instance""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) host = create_instance("host1") - resp = api_client.post( + resp_response = api_client.post( "/ec2spotmanager/rest/report/", {"client": "host1", "status_data": "data"} ) - assert resp.status_code == requests.codes["created"] - resp = json.loads(resp.content.decode("utf-8")) + assert resp_response.status_code == requests.codes["created"] + resp = json.loads(resp_response.content.decode("utf-8")) assert resp == {"status_data": "data"} host = Instance.objects.get(pk=host.pk) # re-read assert host.status_data == "data" @@ -79,17 +83,17 @@ def test_rest_status_report(api_client): assert resp.status_code == requests.codes["not_found"] -def test_rest_status_report2(api_client): +def test_rest_status_report2(api_client: APIClient) -> None: """post should update the status field on the instance""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) host1 = create_instance("host1") host2 = create_instance("host2") - resp = api_client.post( + resp_response = api_client.post( "/ec2spotmanager/rest/report/", {"client": "host1", "status_data": "data"} ) - assert resp.status_code == requests.codes["created"] - resp = json.loads(resp.content.decode("utf-8")) + assert resp_response.status_code == requests.codes["created"] + resp = json.loads(resp_response.content.decode("utf-8")) assert resp == {"status_data": "data"} host1 = Instance.objects.get(pk=host1.pk) # re-read assert host1.status_data == "data" @@ -105,7 +109,7 @@ def test_rest_status_report2(api_client): assert host1.status_data == "data" -def test_rest_status_put(api_client): +def test_rest_status_put(api_client: APIClient) -> None: """put should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -113,7 +117,7 @@ def test_rest_status_put(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_status_delete(api_client): +def 
test_rest_status_delete(api_client: APIClient) -> None: """delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -121,7 +125,7 @@ def test_rest_status_delete(api_client): assert resp.status_code == requests.codes["method_not_allowed"] -def test_rest_status_patch(api_client): +def test_rest_status_patch(api_client: APIClient) -> None: """patch should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) diff --git a/server/ec2spotmanager/tests/test_task_graph.py b/server/ec2spotmanager/tests/test_task_graph.py index 245cc696a..06d90936e 100644 --- a/server/ec2spotmanager/tests/test_task_graph.py +++ b/server/ec2spotmanager/tests/test_task_graph.py @@ -9,10 +9,14 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import logging from unittest.mock import call import pytest +from pytest_mock import MockerFixture from ec2spotmanager.CloudProvider.CloudProvider import INSTANCE_STATE from ec2spotmanager.cron import check_instance_pools @@ -26,7 +30,7 @@ ) # pylint: disable=invalid-name -def test_update_pool_graph(mocker): +def test_update_pool_graph(mocker: MockerFixture) -> None: mock_group = mocker.patch("celery.group") mock_chain = mocker.patch("celery.chain") mock_chord = mocker.patch("celery.chord") @@ -195,7 +199,7 @@ def test_update_pool_graph(mocker): assert mock_chain.return_value.on_error.return_value.call_args == call() -def test_update_pool_graph_unsupported_running(mocker): +def test_update_pool_graph_unsupported_running(mocker: MockerFixture) -> None: """check that unsupported but running instances are still updated eg. if a config is edited to exclude a provider, but there are already instances. we should still update them. @@ -308,7 +312,7 @@ def test_update_pool_graph_unsupported_running(mocker): assert mock_chain.return_value.on_error.return_value.call_args == call() -def test_terminate_instances(mocker): +def test_terminate_instances(mocker: MockerFixture) -> None: """test that terminate instances triggers the appropriate subtasks""" mock_group = mocker.patch("celery.group") mock_term_instance = mocker.patch("ec2spotmanager.tasks._terminate_instance_ids") diff --git a/server/ec2spotmanager/tests/test_task_status.py b/server/ec2spotmanager/tests/test_task_status.py index e5869a251..bfb66b7a5 100644 --- a/server/ec2spotmanager/tests/test_task_status.py +++ b/server/ec2spotmanager/tests/test_task_status.py @@ -9,6 +9,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import pytest from ec2spotmanager.models import PoolStatusEntry, ProviderStatusEntry @@ -19,7 +22,7 @@ pytestmark = pytest.mark.usefixtures("ec2spotmanager_test") -def test_update_pool_status(): +def test_update_pool_status() -> None: """test that update_pool_status utility function works""" config = create_config( name="config #1", @@ -41,7 +44,7 @@ def test_update_pool_status(): assert not entry.isCritical -def test_update_provider_status(): +def test_update_provider_status() -> None: """test that update_provider_status utility function works""" _update_provider_status("EC2Spot", "price-too-low", "testing") entry = ProviderStatusEntry.objects.get() diff --git a/server/ec2spotmanager/tests/test_tasks.py b/server/ec2spotmanager/tests/test_tasks.py index 77c90e92e..f2cea31e8 100644 --- a/server/ec2spotmanager/tests/test_tasks.py +++ b/server/ec2spotmanager/tests/test_tasks.py @@ -9,12 +9,18 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import datetime import logging +from collections.abc import Callable +from typing import Iterable import boto.ec2 import pytest from django.utils import timezone +from pytest_mock import MockerFixture from ec2spotmanager.CloudProvider.CloudProvider import ( INSTANCE_STATE, @@ -42,7 +48,7 @@ @pytest.mark.usefixtures("mock_provider") -def test_nothing_to_do(): +def test_nothing_to_do() -> None: """nothing is done if no pools are enabled""" config = create_config( @@ -62,7 +68,7 @@ def test_nothing_to_do(): assert not Instance.objects.exists() -def test_bad_config(): +def test_bad_config() -> None: """invalid configs create a pool status entry""" config = create_config(name="config #1") pool = create_pool(config=config) @@ -78,10 +84,10 @@ def test_bad_config(): assert not Instance.objects.exists() -def test_create_instance(mocker): +def test_create_instance(mocker: MockerFixture) -> None: """spot instance requests are created when required""" # set-up redis mock to return price data and image name - def _mock_redis_get(key): + def _mock_redis_get(key: str) -> bool | str | None: if ":blacklist:redmond:mshq:" in key: return True if ":blacklist:" in key: @@ -140,19 +146,21 @@ def _mock_redis_get(key): assert instance.instance_id == "req123" -def test_fulfilled_spot_instance(mocker): +def test_fulfilled_spot_instance(mocker: MockerFixture) -> None: """spot instance requests are turned into instances when fulfilled""" # ensure EC2Manager returns a request ID class _MockInstance(boto.ec2.instance.Instance): - def __init__(self, *args, **kwds): + def __init__(self, *args: str, **kwds: str) -> None: super().__init__(*args, **kwds) - self._test_tags = {} + self._test_tags: dict[str, str] = {} @property - def state_code(self): - return INSTANCE_STATE["running"] + def state_code(self) -> int: + return int(INSTANCE_STATE["running"]) - def add_tags(self, tags, dry_run=False): + def add_tags( + self, tags: Iterable[tuple[str, str]], dry_run: bool = False + ) -> None: self._test_tags.update(tags) boto_instance = _MockInstance() @@ -207,21 +215,23 @@ def add_tags(self, tags, dry_run=False): } # pylint: disable=protected-access -def test_instance_shutting_down(mocker): +def test_instance_shutting_down(mocker: MockerFixture) -> None: """instances are replaced when shut down or terminated""" # ensure EC2Manager returns a request ID class _MockInstance(boto.ec2.instance.Instance): @property - def state_code(self): 
- return INSTANCE_STATE["shutting-down"] + def state_code(self) -> int: + return int(INSTANCE_STATE["shutting-down"]) - def add_tags(self, _tags, _dry_run=False): + def add_tags( + self, _tags: Iterable[tuple[str, str]], _dry_run: bool = False + ) -> None: pass class _MockInstance2(_MockInstance): @property - def state_code(self): - return INSTANCE_STATE["terminated"] + def state_code(self) -> int: + return int(INSTANCE_STATE["terminated"]) boto_instance1 = _MockInstance() boto_instance1.id = "i-123" @@ -241,7 +251,7 @@ def state_code(self): mock_ec2mgr.return_value.find.return_value = (boto_instance1, boto_instance2) # set-up redis mock to return price data and image name - def _mock_redis_get(key): + def _mock_redis_get(key: str) -> str | None: if ":blacklist:" in key: return None if ":price:" in key: @@ -320,13 +330,13 @@ def _mock_redis_get(key): assert not remaining -def test_instance_not_updatable(mocker): +def test_instance_not_updatable(mocker: MockerFixture) -> None: """instances are not touched while they are not tagged Updatable""" # ensure EC2Manager returns a request ID class _MockInstance(boto.ec2.instance.Instance): @property - def state_code(self): - return INSTANCE_STATE["stopping"] + def state_code(self) -> int: + return int(INSTANCE_STATE["stopping"]) boto_instance = _MockInstance() boto_instance.id = "i-123" @@ -377,10 +387,10 @@ def state_code(self): assert count == 1 -def test_instance_price_high(mocker): +def test_instance_price_high(mocker: MockerFixture) -> None: """check that instances are not created if the price is too high""" # set-up redis mock to return price data and image name - def _mock_redis_get(key): + def _mock_redis_get(key: str) -> str | None: if ":blacklist:" in key: return None if ":price:" in key: @@ -430,14 +440,14 @@ def _mock_redis_get(key): assert not Instance.objects.exists() -def test_spot_instance_blacklist(mocker): +def test_spot_instance_blacklist(mocker: MockerFixture) -> None: """check that spot requests being cancelled will result in temporary blacklisting""" # ensure EC2Manager returns a request ID class _status_code: code = "instance-terminated-by-service" class _MockReq(boto.ec2.spotinstancerequest.SpotInstanceRequest): - def __init__(self, *args, **kwds): + def __init__(self, *args: str, **kwds: str) -> None: super().__init__(*args, **kwds) self.state = "cancelled" self.status = _status_code @@ -455,7 +465,7 @@ def __init__(self, *args, **kwds): mock_ec2mgr.return_value.check_spot_requests.return_value = (req,) # set-up redis mock to return price data and image name - def _mock_redis_get(key): + def _mock_redis_get(key: str) -> bool | str: if ":blacklist:" in key: return True if ":price:" in key: @@ -464,7 +474,7 @@ def _mock_redis_get(key): return "warp" raise UncatchableException(f"unhandle key in mock_get(): {key}") - def _mock_redis_set(key, value, ex=None): + def _mock_redis_set(key: str, value: str, ex: str | None = None) -> None: assert ":blacklist:redmond:mshq:" in key mock_redis = mocker.patch("redis.StrictRedis.from_url") @@ -522,7 +532,7 @@ def _mock_redis_set(key, value, ex=None): assert len(mock_redis.return_value.set.mock_calls) == 1 -def test_pool_disabled(mocker): +def test_pool_disabled(mocker: MockerFixture) -> None: """check that pool disabled results in running and pending instances being terminated""" # ensure EC2Manager returns a request ID @@ -579,7 +589,7 @@ def test_pool_disabled(mocker): mock_term_request.delay.assert_called_once_with("EC2Spot", "redmond", ["r-456"]) -def test_pool_trim(): +def 
test_pool_trim() -> None: """check that pool down-size trims older instances until we meet the requirement""" # create database state config = create_config( @@ -643,22 +653,24 @@ def test_pool_trim(): (_terminate_instance_request_ids, "cancel_requests"), ], ) -def test_terminate(mocker, term_task, provider_func): +def test_terminate( + mocker: MockerFixture, term_task: Callable[..., None], provider_func: str +) -> None: """check that terminate instances task works properly""" fake_provider_cls = mocker.patch("ec2spotmanager.tasks.CloudProvider") fake_provider = fake_provider_cls.get_instance.return_value = mocker.Mock() term_task("provider", "region", ["inst1", "inst2"]) fake_provider_cls.get_instance.assert_called_once_with("provider") - provider_func = getattr(fake_provider, provider_func) - provider_func.assert_called_once_with({"region": ["inst1", "inst2"]}) + provider_func_ = getattr(fake_provider, provider_func) + provider_func_.assert_called_once_with({"region": ["inst1", "inst2"]}) - provider_func.side_effect = CloudProviderTemporaryFailure("blah") + provider_func_.side_effect = CloudProviderTemporaryFailure("blah") with pytest.raises( CloudProviderTemporaryFailure, match=r"CloudProviderTemporaryFailure: blah \(temporary-failure\)", ): term_task("provider", "region", []) - provider_func.side_effect = Exception("blah") + provider_func_.side_effect = Exception("blah") with pytest.raises(Exception, match=r"blah"): term_task("provider", "region", []) diff --git a/server/ec2spotmanager/urls.py b/server/ec2spotmanager/urls.py index 7f0de353a..0f316c926 100644 --- a/server/ec2spotmanager/urls.py +++ b/server/ec2spotmanager/urls.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.conf.urls import include from django.urls import re_path from rest_framework import routers diff --git a/server/ec2spotmanager/views.py b/server/ec2spotmanager/views.py index 1c73a9e1c..1e12e6fd9 100644 --- a/server/ec2spotmanager/views.py +++ b/server/ec2spotmanager/views.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import json from operator import attrgetter +from typing import Any, cast import fasteners import redis @@ -10,11 +13,15 @@ from django.core.files.base import ContentFile from django.db.models import Q from django.db.models.aggregates import Count, Sum +from django.http import HttpResponse +from django.http.request import HttpRequest from django.http.response import Http404 # noqa +from django.http.response import HttpResponsePermanentRedirect, HttpResponseRedirect from django.shortcuts import get_object_or_404, redirect, render from django.utils.timezone import now, timedelta from rest_framework import mixins, serializers, status, viewsets from rest_framework.authentication import SessionAuthentication, TokenAuthentication +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView @@ -39,17 +46,17 @@ from .serializers import MachineStatusSerializer, PoolConfigurationSerializer -def renderError(request, err): +def renderError(request: HttpRequest, err: str) -> HttpResponse: return render(request, "error.html", {"error_message": err}) @deny_restricted_users -def index(request): +def index(request: HttpRequest) -> HttpResponsePermanentRedirect | HttpResponseRedirect: return redirect("ec2spotmanager:pools") @deny_restricted_users -def pools(request): +def pools(request: HttpRequest) -> HttpResponse: filters = {} isSearch = True @@ -71,7 +78,10 @@ def pools(request): .order_by("config__name") ) # fetch all 
pool configs since most will be used by flatten later - configs = {cfg.id: cfg for cfg in PoolConfiguration.objects.all()} + configs = { + cast(int, getattr(cfg, "id", None)): cfg + for cfg in PoolConfiguration.objects.all() + } # These are all keys that are allowed for exact filtering exactFilterKeys = [ @@ -98,11 +108,11 @@ def pools(request): entry.msgs.append(status_entry) break - provider_msgs = {} + provider_msgs: dict[str, list[ProviderStatusEntry]] = {} for msg in ProviderStatusEntry.objects.all().order_by("-created"): provider_msgs.setdefault(msg.provider, []).append(msg) - provider_pools = {} + provider_pools: dict[str, set[int]] = {} for pool in entries: flattened_config = pool.config.flatten(configs) for provider in provider_msgs: @@ -137,7 +147,7 @@ def pools(request): @deny_restricted_users -def viewPool(request, poolid): +def viewPool(request: HttpRequest, poolid: int) -> HttpResponse: pool = get_object_or_404(InstancePool, pk=poolid) instances = Instance.objects.filter(pool=poolid) @@ -160,7 +170,7 @@ def viewPool(request, poolid): pool.msgs = PoolStatusEntry.objects.filter(pool=pool).order_by("-created") - provider_msgs = {} + provider_msgs: dict[str, list[ProviderStatusEntry]] = {} relevant_providers = {} for msg in ProviderStatusEntry.objects.all().order_by("-created"): # a status provider is relevant to this pool if it is supported by the config, @@ -194,7 +204,7 @@ def viewPool(request, poolid): @deny_restricted_users -def viewPoolPrices(request, poolid): +def viewPoolPrices(request: HttpRequest, poolid: int) -> HttpResponse: cache = redis.StrictRedis.from_url(settings.REDIS_URL) pool = get_object_or_404(InstancePool, pk=poolid) @@ -208,13 +218,13 @@ def viewPoolPrices(request, poolid): cores_per_instance = cloud_provider.get_cores_per_instance() allowed_regions = set(cloud_provider.get_allowed_regions(config)) zones = set() - latest_price_by_zone = {} + latest_price_by_zone: dict[str, int] = {} for instance_type in cloud_provider.get_instance_types(config): - prices = cache.get(f"{cloud_provider.get_name()}:price:{instance_type}") - if prices is None: + prices_ = cache.get(f"{cloud_provider.get_name()}:price:{instance_type}") + if prices_ is None: continue - prices = json.loads(prices) + prices = json.loads(prices_) for region in prices: if region not in allowed_regions: continue @@ -232,7 +242,7 @@ def viewPoolPrices(request, poolid): @deny_restricted_users -def disablePool(request, poolid): +def disablePool(request: HttpRequest, poolid: int) -> HttpResponse: pool = get_object_or_404(InstancePool, pk=poolid) if not pool.isEnabled: @@ -256,7 +266,7 @@ def disablePool(request, poolid): @deny_restricted_users -def enablePool(request, poolid): +def enablePool(request: HttpRequest, poolid: int) -> HttpResponse: pool = get_object_or_404(InstancePool, pk=poolid) # Safety check: Figure out if any parameters are missing @@ -302,7 +312,9 @@ def enablePool(request, poolid): @deny_restricted_users -def forceCyclePool(request, poolid): +def forceCyclePool( + request: HttpRequest, poolid: int +) -> HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: pool = get_object_or_404(InstancePool, pk=poolid) if not pool.isEnabled: @@ -324,10 +336,14 @@ def forceCyclePool(request, poolid): @deny_restricted_users -def forceCyclePoolsByConfig(request, configid): +def forceCyclePoolsByConfig( + request: HttpRequest, configid: int +) -> list[ + PoolConfiguration +] | HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: config = 
get_object_or_404(PoolConfiguration, pk=configid) - def recurse_get_dependent_configurations(config): + def recurse_get_dependent_configurations(config: PoolConfiguration) -> list[int]: config_pks = [config.pk] configs = PoolConfiguration.objects.filter(parent=config) @@ -361,7 +377,9 @@ def recurse_get_dependent_configurations(config): @deny_restricted_users -def createPool(request): +def createPool( + request: HttpRequest, +) -> HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: if request.method == "POST": pool = InstancePool() config = get_object_or_404(PoolConfiguration, pk=int(request.POST["config"])) @@ -376,11 +394,11 @@ def createPool(request): @deny_restricted_users -def viewConfigs(request): +def viewConfigs(request: HttpRequest) -> HttpResponse: configs = PoolConfiguration.objects.all() roots = configs.filter(parent=None) - def add_children(node): + def add_children(node: PoolConfiguration) -> None: node.children = [] children = configs.filter(parent=node) for child in children: @@ -396,7 +414,7 @@ def add_children(node): @deny_restricted_users -def viewConfig(request, configid): +def viewConfig(request: HttpRequest, configid: int) -> HttpResponse: config = get_object_or_404(PoolConfiguration, pk=configid) config.deserializeFields() @@ -404,7 +422,9 @@ def viewConfig(request, configid): return render(request, "config/view.html", {"config": config}) -def __handleConfigPOST(request, config): +def __handleConfigPOST( + request: HttpRequest, config: PoolConfiguration +) -> HttpResponsePermanentRedirect | HttpResponseRedirect: if int(request.POST["parent"]) < 0: config.parent = None else: @@ -572,10 +592,12 @@ def __handleConfigPOST(request, config): config.save() if request.POST["ec2_userdata"]: + assert config.ec2_userdata_file is not None if not config.ec2_userdata_file.name: config.ec2_userdata_file.save("default.sh", ContentFile("")) config.ec2_userdata = request.POST["ec2_userdata"] if request.POST.get("ec2_userdata_ff", "unix") == "unix": + assert config.ec2_userdata is not None config.ec2_userdata = config.ec2_userdata.replace("\r\n", "\n") config.storeTestAndSave() else: @@ -588,7 +610,7 @@ def __handleConfigPOST(request, config): @deny_restricted_users -def createConfig(request): +def createConfig(request: HttpRequest) -> HttpResponse: if request.method == "POST": config = PoolConfiguration() return __handleConfigPOST(request, config) @@ -619,7 +641,9 @@ def createConfig(request): @deny_restricted_users -def editConfig(request, configid): +def editConfig( + request: HttpRequest, configid: int +) -> HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: config = get_object_or_404(PoolConfiguration, pk=configid) config.deserializeFields() @@ -641,7 +665,9 @@ def editConfig(request, configid): @deny_restricted_users -def deletePool(request, poolid): +def deletePool( + request: HttpRequest, poolid: int +) -> HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: pool = get_object_or_404(InstancePool, pk=poolid) if pool.isEnabled: @@ -692,7 +718,9 @@ def deletePool(request, poolid): @deny_restricted_users -def deletePoolMsg(request, msgid, from_pool="0"): +def deletePoolMsg( + request: HttpRequest, msgid: int, from_pool: str | int = "0" +) -> HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: entry = get_object_or_404(PoolStatusEntry, pk=msgid) if request.method == "POST": from_pool = int(request.POST["from_pool"]) @@ -713,7 +741,9 @@ def deletePoolMsg(request, msgid, from_pool="0"): @deny_restricted_users -def 
deleteProviderMsg(request, msgid): +def deleteProviderMsg( + request: HttpRequest, msgid: int +) -> HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: entry = get_object_or_404(ProviderStatusEntry, pk=msgid) if request.method == "POST": entry.delete() @@ -725,7 +755,9 @@ def deleteProviderMsg(request, msgid): @deny_restricted_users -def deleteConfig(request, configid): +def deleteConfig( + request: HttpRequest, configid: int +) -> HttpResponse | HttpResponsePermanentRedirect | HttpResponseRedirect: config = get_object_or_404(PoolConfiguration, pk=configid) pools = InstancePool.objects.filter(config=config) @@ -780,7 +812,7 @@ def deleteConfig(request, configid): class UptimeChartViewDetailed(JSONView): authentication_classes = (SessionAuthentication,) - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any): context = super().get_context_data(**kwargs) pool = InstancePool.objects.get(pk=int(kwargs["poolid"])) pool.flat_config = pool.config.flatten() @@ -802,7 +834,7 @@ def get_context_data(self, **kwargs): def get_colors(self): return next_color() - def get_data_colors(self, entries): + def get_data_colors(self, entries) -> list[str]: colors = [] red = (204, 0, 0) yellow = (255, 204, 0) @@ -818,7 +850,7 @@ def get_data_colors(self, entries): return colors - def get_options(self, pool, entries): + def get_options(self, pool, entries) -> dict[str, object]: if entries: scaleSteps = max(entries, key=attrgetter("target")).target + 1 else: @@ -833,7 +865,7 @@ def get_options(self, pool, entries): "barShowStroke": False, } - def get_datasets(self, pool, entries): + def get_datasets(self, pool: InstancePool, entries): datasets = [] color_generator = self.get_colors() color = tuple(next(color_generator)) @@ -851,14 +883,14 @@ def get_datasets(self, pool, entries): datasets.append(dataset) return datasets - def get_labels(self, pool, entries): + def get_labels(self, pool: InstancePool, entries) -> list[str]: return [x.created.strftime("%H:%M") for x in entries] class UptimeChartViewAccumulated(JSONView): authentication_classes = (SessionAuthentication,) - def get_context_data(self, **kwargs): + def get_context_data(self, **kwargs: Any): context = super().get_context_data(**kwargs) pool = InstancePool.objects.get(pk=int(kwargs["poolid"])) pool.flat_config = pool.config.flatten() @@ -880,7 +912,7 @@ def get_context_data(self, **kwargs): def get_colors(self): return next_color() - def get_data_colors(self, entries): + def get_data_colors(self, entries) -> list[str]: colors = [] red = (204, 0, 0) orange = (255, 128, 0) @@ -899,7 +931,7 @@ def get_data_colors(self, entries): return colors - def get_options(self, pool, entries): + def get_options(self, pool: InstancePool, entries) -> dict[str, object]: # Scale to 100% but use 110 so the red bar is actually visible scaleSteps = 11 return { @@ -913,7 +945,7 @@ def get_options(self, pool, entries): "barShowStroke": False, } - def get_datasets(self, pool, entries): + def get_datasets(self, pool: InstancePool, entries): datasets = [] color_generator = self.get_colors() color = tuple(next(color_generator)) @@ -931,19 +963,19 @@ def get_datasets(self, pool, entries): datasets.append(dataset) return datasets - def get_labels(self, pool, entries): + def get_labels(self, pool: InstancePool, entries) -> list[str]: return [x.created.strftime("%b %d") for x in entries] class MachineStatusViewSet(APIView): authentication_classes = (TokenAuthentication,) - def get(self, request, *args, **kwargs): - result = {} + def get(self, 
request: Request, *args: Any, **kwargs: Any) -> Response: + result: dict[str, object] = {} response = Response(result, status=status.HTTP_200_OK) return response - def post(self, request, *args, **kwargs): + def post(self, request: Request, *args: Any, **kwargs: Any) -> Response: if "client" not in request.data: return Response( {"error": '"client" is required.'}, status=status.HTTP_400_BAD_REQUEST @@ -971,8 +1003,8 @@ class PoolConfigurationViewSet( queryset = PoolConfiguration.objects.all() serializer_class = PoolConfigurationSerializer - def retrieve(self, request, *args, **kwds): - flatten = request.query_params.get("flatten", "0") + def retrieve(self, request: Request, *args: Any, **kwds: Any) -> Response: + flatten: str | int = request.query_params.get("flatten", "0") try: flatten = int(flatten) assert flatten in {0, 1} @@ -986,7 +1018,9 @@ class PoolCycleView(APIView): authentication_classes = (TokenAuthentication,) permission_classes = (CheckAppPermission,) - def post(self, request, poolid, format=None): + def post( + self, request: Request, poolid: int, format: str | None = None + ) -> Response: pool = get_object_or_404(InstancePool, pk=poolid) if not pool.isEnabled: @@ -1004,7 +1038,9 @@ class PoolEnableView(APIView): authentication_classes = (TokenAuthentication,) permission_classes = (CheckAppPermission,) - def post(self, request, poolid, format=None): + def post( + self, request: Request, poolid: int, format: str | None = None + ) -> Response: pool = get_object_or_404(InstancePool, pk=poolid) if pool.isEnabled: @@ -1024,7 +1060,9 @@ class PoolDisableView(APIView): authentication_classes = (TokenAuthentication,) permission_classes = (CheckAppPermission,) - def post(self, request, poolid, format=None): + def post( + self, request: Request, poolid: int, format: str | None = None + ) -> Response: pool = get_object_or_404(InstancePool, pk=poolid) if not pool.isEnabled: diff --git a/server/manage.py b/server/manage.py index 194423238..6d4f11750 100644 --- a/server/manage.py +++ b/server/manage.py @@ -1,4 +1,7 @@ #!/usr/bin/env python3 + +from __future__ import annotations + import os import sys diff --git a/server/server/auth.py b/server/server/auth.py index f611f2983..13083e170 100644 --- a/server/server/auth.py +++ b/server/server/auth.py @@ -1,22 +1,26 @@ # This is code for Mozilla's 2FA using OID. If you have your own OID provider, # you can probably use similar code to get 2FA for your FuzzManager instance. +from __future__ import annotations + import unicodedata from django.conf import settings from rest_framework import permissions +from rest_framework.request import Request +from rest_framework.views import APIView if getattr(settings, "USE_OIDC", False): from mozilla_django_oidc.auth import OIDCAuthenticationBackend - def generate_username(email): + def generate_username(email: str) -> str: # Using Python 3 and Django 1.11, usernames can contain alphanumeric # (ascii and unicode), _, @, +, . and - characters. So we normalize # it and slice at 150 characters. return unicodedata.normalize("NFKC", email)[:150] class FMOIDCAB(OIDCAuthenticationBackend): - def verify_claims(self, claims): + def verify_claims(self, claims: dict[str, str]) -> bool: verified = super().verify_claims(claims) if not verified: @@ -31,7 +35,7 @@ class CheckAppPermission(permissions.BasePermission): Check that user has permission to view this app, whether via REST or web UI. 
""" - def has_permission(self, request, view): + def has_permission(self, request: Request, view: APIView) -> bool: if request.user and request.user.is_authenticated: app = view.__module__.split(".", 1)[0] diff --git a/server/server/middleware.py b/server/server/middleware.py index 682a78c53..a6bbd9d0e 100644 --- a/server/server/middleware.py +++ b/server/server/middleware.py @@ -1,9 +1,16 @@ +from __future__ import annotations + import re import traceback +from types import TracebackType +from typing import Any, Callable from django.conf import settings from django.contrib.auth.decorators import login_required from django.http import HttpResponseForbidden +from django.http.request import HttpRequest +from rest_framework.request import Request +from rest_framework.views import APIView from crashmanager.models import User @@ -16,13 +23,19 @@ class ExceptionLoggingMiddleware: when running a Django instance with runserver.py """ - def __init__(self, get_response): + def __init__(self, get_response: Callable[..., Any]) -> None: self.get_response = get_response - def __call__(self, request): + def __call__(self, request: HttpRequest): return self.get_response(request) - def process_exception(self, request, exception): + def process_exception( + self, + request: HttpRequest, + exception: tuple[ + type[BaseException] | None, type[BaseException] | None, TracebackType | None + ], + ) -> None: print(traceback.format_exc()) return None @@ -44,16 +57,22 @@ class RequireLoginMiddleware: # Based on snippet from https://stackoverflow.com/a/46976284 # Docstring and original idea from https://stackoverflow.com/a/2164224 - def __init__(self, get_response): + def __init__(self, get_response: Callable[..., Any]) -> None: self.get_response = get_response self.exceptions = re.compile( "(" + "|".join(settings.LOGIN_REQUIRED_URLS_EXCEPTIONS) + ")" ) - def __call__(self, request): + def __call__(self, request: HttpRequest): return self.get_response(request) - def process_view(self, request, view_func, view_args, view_kwargs): + def process_view( + self, + request: HttpRequest, + view_func: Callable[..., Any], + view_args: Any, + view_kwargs: Any, + ) -> Any: # No need to process URLs if user already logged in if request.user.is_authenticated: return None @@ -67,16 +86,18 @@ def process_view(self, request, view_func, view_args, view_kwargs): class CheckAppPermissionsMiddleware: - def __init__(self, get_response): + def __init__(self, get_response: Callable[..., Any]) -> None: self.get_response = get_response self.exceptions = re.compile( "(" + "|".join(settings.LOGIN_REQUIRED_URLS_EXCEPTIONS) + ")" ) - def __call__(self, request): + def __call__(self, request: HttpRequest): return self.get_response(request) - def process_view(self, request, view_func, view_args, view_kwargs): + def process_view( + self, request: Request, view_func: APIView, view_args: Any, view_kwargs: Any + ) -> HttpResponseForbidden | None: # Get the app name app = view_func.__module__.split(".", 1)[0] diff --git a/server/server/settings.py b/server/server/settings.py index da420e9a4..060b4e74e 100644 --- a/server/server/settings.py +++ b/server/server/settings.py @@ -7,10 +7,14 @@ For the full list of settings and their values, see https://docs.djangoproject.com/en/1.6/ref/settings/ """ + +from __future__ import annotations + # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 
import os -from django.conf import global_settings # noqa +from django.http.request import HttpRequest +from typing_extensions import TypedDict BASE_DIR = os.path.dirname(os.path.dirname(__file__)) @@ -35,7 +39,7 @@ # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -ALLOWED_HOSTS = [] +ALLOWED_HOSTS: list[str] = [] # Application definition @@ -78,9 +82,15 @@ ) +class ResolverContextProcessorObj(TypedDict): + app_name: str + namespace: str + url_name: str | None + + # We add a custom context processor to make our application name # and certain other variables available in all our templates -def resolver_context_processor(request): +def resolver_context_processor(request: HttpRequest) -> ResolverContextProcessorObj: return { "app_name": request.resolver_match.app_name, "namespace": request.resolver_match.namespace, diff --git a/server/server/settings_docker.py b/server/server/settings_docker.py index b4c403027..02f212647 100644 --- a/server/server/settings_docker.py +++ b/server/server/settings_docker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .settings import * # noqa # Run in production mode diff --git a/server/server/settings_nondebug.py b/server/server/settings_nondebug.py index 9a416e701..97de788e3 100644 --- a/server/server/settings_nondebug.py +++ b/server/server/settings_nondebug.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .settings import * # noqa DEBUG = False diff --git a/server/server/settings_test.py b/server/server/settings_test.py index 380551a0a..706a199d3 100644 --- a/server/server/settings_test.py +++ b/server/server/settings_test.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .settings import * # noqa TC_ROOT_URL = "" # must be set for taskmanager tests diff --git a/server/server/tests/__init__.py b/server/server/tests/__init__.py index 1e2913168..260c2203b 100644 --- a/server/server/tests/__init__.py +++ b/server/server/tests/__init__.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import cast + import pytest from django.contrib.auth.models import Permission, User from django.contrib.contenttypes.models import ContentType @@ -9,7 +13,7 @@ @pytest.fixture -def fm_user(): +def fm_user() -> User: user = User.objects.create_user("fuzzmanager", "test@example.com", "test") content_type = ContentType.objects.get_for_model(CMUser) user.user_permissions.add( @@ -30,4 +34,4 @@ def fm_user(): token.save() user.token = token.key - return user + return cast(User, user) diff --git a/server/server/urls.py b/server/server/urls.py index 8ae5d1a0d..80eff2767 100644 --- a/server/server/urls.py +++ b/server/server/urls.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.conf import settings from django.conf.urls import include from django.conf.urls.static import static diff --git a/server/server/utils.py b/server/server/utils.py index f219e9af1..8711a9fe5 100644 --- a/server/server/utils.py +++ b/server/server/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import time import uuid @@ -18,7 +20,9 @@ class RedisLock: release the lock in an async chain. 
""" - def __init__(self, conn, name, unique_id=None): + def __init__( + self, conn: redis.Redis[bytes], name: str, unique_id: str | None = None + ) -> None: self.conn = conn self.name = name if unique_id is None: @@ -26,7 +30,9 @@ def __init__(self, conn, name, unique_id=None): else: self.unique_id = unique_id - def acquire(self, acquire_timeout=10, lock_expiry=None): + def acquire( + self, acquire_timeout: int = 10, lock_expiry: int | None = None + ) -> str | None: end = time.time() + acquire_timeout while time.time() < end: if self.conn.set(self.name, self.unique_id, ex=lock_expiry, nx=True): @@ -38,7 +44,7 @@ def acquire(self, acquire_timeout=10, lock_expiry=None): LOG.debug("Failed to acquire lock: %s(%s)", self.name, self.unique_id) return None - def release(self): + def release(self) -> bool: with self.conn.pipeline() as pipe: while True: try: diff --git a/server/server/views.py b/server/server/views.py index 33e99da53..4ea5d51e7 100644 --- a/server/server/views.py +++ b/server/server/views.py @@ -1,20 +1,35 @@ +from __future__ import annotations + import collections import functools import json +from typing import Any, TypeVar, cast from django.conf import settings from django.contrib.auth.views import LoginView from django.core.exceptions import PermissionDenied -from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator -from django.db.models import Q -from django.shortcuts import redirect +from django.core.paginator import EmptyPage, Page, PageNotAnInteger, Paginator +from django.db.models import Model, Q +from django.db.models.query import QuerySet +from django.http.request import HttpRequest +from django.http.response import ( + HttpResponse, + HttpResponsePermanentRedirect, + HttpResponseRedirect, +) +from django.shortcuts import redirect, render from django.urls import resolve, reverse from rest_framework import filters +from rest_framework.request import Request +from rest_framework.views import APIView from crashmanager.models import User +from server.covmanager.models import Collection + +MT = TypeVar("MT", bound=Model) -def index(request): +def index(request: HttpRequest) -> HttpResponseRedirect | HttpResponsePermanentRedirect: user = User.get_or_create_restricted(request.user)[0] # return crashmanager, covmanager, or ec2spotmanager, as allowed, in that order. 
# if no permission to view any apps, then use crashmanager and let that fail @@ -26,16 +41,18 @@ return redirect("crashmanager:index") -def login(request): +def login(request: HttpRequest): if settings.USE_OIDC: auth_view = resolve(reverse("oidc_authentication_init")).func return auth_view(request) return LoginView.as_view()(request) -def deny_restricted_users(wrapped): +def deny_restricted_users( + wrapped: collections.abc.Callable[..., Any] +) -> collections.abc.Callable[..., Any]: @functools.wraps(wrapped) - def decorator(request, *args, **kwargs): + def decorator(request: HttpRequest, *args: Any, **kwargs: Any) -> Any: user = User.get_or_create_restricted(request.user)[0] if user.restricted: raise PermissionDenied( @@ -46,23 +63,27 @@ def decorator(request, *args, **kwargs): return decorator -def renderError(request, err): - return render(request, "error.html", {"error_message": err}) # noqa +def renderError(request: HttpRequest, err: str) -> HttpResponse: + return render(request, "error.html", {"error_message": err}) -def paginate_requested_list(request, entries): +def paginate_requested_list(request: HttpRequest, entries: QuerySet[Model]) -> Page: """ This method generically paginates a given QuerySet and returns a list suitable for passing to a template. The set is paginated by request parameters 'page' and 'page_size'. """ page_size = request.GET.get("page_size") - if not page_size: - page_size = 100 - paginator = Paginator(entries, page_size) + if page_size: + page_size_int = int(page_size) + else: + page_size_int = 100 + paginator = Paginator(entries, page_size_int) + page = request.GET.get("page") try: + page = page if page is not None else "1" page_entries = paginator.page(page) except PageNotAnInteger: # If page is not an integer, deliver first page. @@ -83,7 +104,7 @@ return page_entries -def json_to_query(json_str): +def json_to_query(json_str: str): """ This method converts JSON objects into trees of Django Q objects. It can be used to provide the user the ability to perform complex @@ -112,7 +133,9 @@ except ValueError as e: raise RuntimeError(f"Invalid JSON: {e}") - def get_query_obj(obj, key=None): + def get_query_obj( + obj: str | list[str] | int | dict[str, str] | None, key: str | None = None + ) -> Q: if obj is None or isinstance(obj, (str, list, int)): kwargs = {key: obj} @@ -157,7 +180,9 @@ class JsonQueryFilterBackend(filters.BaseFilterBackend): (see json_to_query) """ - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """ Return a filtered queryset. """ @@ -177,7 +202,9 @@ class SimpleQueryFilterBackend(filters.BaseFilterBackend): "contains" searches """ - def filter_queryset(self, request, queryset, view): + def filter_queryset( + self, request: Request, queryset: QuerySet[MT], view: APIView + ) -> QuerySet[MT]: """ Return a filtered queryset.
""" @@ -188,7 +215,7 @@ def filter_queryset(self, request, queryset, view): querystr = request.query_params.get("squery", None) if querystr is not None: queryobj = None - for field in queryset[0].simple_query_fields: + for field in cast(Collection, queryset[0]).simple_query_fields: kwargs = {f"{field}__contains": querystr} if queryobj is None: queryobj = Q(**kwargs) diff --git a/server/server/wsgi.py b/server/server/wsgi.py index a52c801b1..e1573fa17 100644 --- a/server/server/wsgi.py +++ b/server/server/wsgi.py @@ -7,6 +7,8 @@ https://docs.djangoproject.com/en/1.6/howto/deployment/wsgi/ """ +from __future__ import annotations + import os os.environ.setdefault("DJANGO_SETTINGS_MODULE", "server.settings") diff --git a/server/taskmanager/__init__.py b/server/taskmanager/__init__.py index ba00c6d1d..2e82c2ba6 100644 --- a/server/taskmanager/__init__.py +++ b/server/taskmanager/__init__.py @@ -1 +1,3 @@ +from __future__ import annotations + from . import tasks # noqa diff --git a/server/taskmanager/cron.py b/server/taskmanager/cron.py index aca135a18..49f77d10a 100644 --- a/server/taskmanager/cron.py +++ b/server/taskmanager/cron.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from logging import getLogger @@ -10,7 +12,7 @@ @app.task(ignore_result=True) -def update_tasks(): +def update_tasks() -> None: import taskcluster from .models import Task @@ -25,9 +27,9 @@ def update_tasks(): # normal, try to update the task directly from taskcluster task_status = {} - done = set() + done: set[tuple[int, int]] = set() - def _update_task_run(task_id, run_id): + def _update_task_run(task_id: int, run_id: int) -> None: if (task_id, run_id) in done: return @@ -71,7 +73,7 @@ def _update_task_run(task_id, run_id): @app.task(ignore_result=True) -def delete_expired(): +def delete_expired() -> None: from .models import Task # if the tasks no longer exist, or are expired, remove them from our DB too diff --git a/server/taskmanager/management/commands/taskmanager_change_poolid.py b/server/taskmanager/management/commands/taskmanager_change_poolid.py index a7681daf6..190c82267 100644 --- a/server/taskmanager/management/commands/taskmanager_change_poolid.py +++ b/server/taskmanager/management/commands/taskmanager_change_poolid.py @@ -1,4 +1,8 @@ +from __future__ import annotations + +from argparse import ArgumentParser from logging import getLogger +from typing import Any from django.core.management import BaseCommand @@ -10,7 +14,7 @@ class Command(BaseCommand): help = "Change a Taskmanager pool ID" - def add_arguments(self, parser): + def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument( "old", type=int, @@ -22,7 +26,7 @@ def add_arguments(self, parser): help="Original pool ID", ) - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: old = options["old"] new = options["new"] diff --git a/server/taskmanager/management/commands/taskmanager_list_pools.py b/server/taskmanager/management/commands/taskmanager_list_pools.py index 602653056..1a4209925 100644 --- a/server/taskmanager/management/commands/taskmanager_list_pools.py +++ b/server/taskmanager/management/commands/taskmanager_list_pools.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + from django.core.management import BaseCommand from ...models import Pool @@ -6,7 +10,7 @@ class Command(BaseCommand): help = "List Taskmanager pools" - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: for 
pool in Pool.objects.all(): if pool.pool_id != f"pool{pool.id}": print(f"!=: {pool.id} ({pool.pool_id})") diff --git a/server/taskmanager/management/commands/taskmanager_pulse_listen.py b/server/taskmanager/management/commands/taskmanager_pulse_listen.py index f018cfc35..9ed258962 100644 --- a/server/taskmanager/management/commands/taskmanager_pulse_listen.py +++ b/server/taskmanager/management/commands/taskmanager_pulse_listen.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from logging import getLogger from pathlib import Path +from typing import Any from django.conf import settings from django.core.management import BaseCommand, CommandError # noqa @@ -11,7 +14,7 @@ class TaskClusterConsumer(GenericConsumer): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: repo_slug = Path(settings.TC_FUZZING_CFG_REPO.split(":", 1)[1]) org = repo_slug.parent repo = repo_slug.stem @@ -42,7 +45,7 @@ class Command(BaseCommand): "and schedule celery tasks to handle them" ) - def callback(self, body, msg): + def callback(self, body, msg) -> None: if msg.delivery_info["exchange"].startswith( "exchange/taskcluster-queue/v1/task-" ): @@ -54,7 +57,7 @@ def callback(self, body, msg): ) update_task.delay(body) msg.ack() - return + return None if msg.delivery_info["exchange"] == "exchange/taskcluster-github/v1/push": LOG.info( "%s on %s for %s", @@ -65,16 +68,15 @@ def callback(self, body, msg): if body["body"]["ref"] == "refs/heads/master": update_pool_defns.delay() msg.ack() - return + return None raise RuntimeError( - "Unhandled message: %s on %s" - % ( + "Unhandled message: {} on {}".format( msg.delivery_info["routing_key"], msg.delivery_info["exchange"], ) ) - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: LOG.info("pulse listener starting") try: TaskClusterConsumer( diff --git a/server/taskmanager/management/commands/taskmanager_scrape_task_group.py b/server/taskmanager/management/commands/taskmanager_scrape_task_group.py index 8b3e6e7ca..05cc97499 100644 --- a/server/taskmanager/management/commands/taskmanager_scrape_task_group.py +++ b/server/taskmanager/management/commands/taskmanager_scrape_task_group.py @@ -1,5 +1,10 @@ +from __future__ import annotations + import functools +from argparse import ArgumentParser +from collections.abc import Callable from logging import getLogger +from typing import Any, Generator import taskcluster from django.conf import settings @@ -11,7 +16,7 @@ LOG = getLogger("taskmanager.management.commands.scrape_group") -def paginated(func, result_key): +def paginated(func: Callable[..., Any], result_key: str) -> Callable[..., Any]: """Wraps a Taskcluster API that returns a result like: { continuationToken: "", @@ -22,7 +27,7 @@ def paginated(func, result_key): """ @functools.wraps(func) - def _wrapped(*args, **kwds): + def _wrapped(*args: Any, **kwds: Any) -> Generator[Any, Any, Any]: kwds = kwds.copy() result = func(*args, **kwds) while result.get("continuationToken"): @@ -38,7 +43,7 @@ def _wrapped(*args, **kwds): class Command(BaseCommand): help = "Scrape a task group and add created tasks to taskmanager" - def add_arguments(self, parser): + def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument( "task_group", help="Taskcluster task group to add tasks for", @@ -50,7 +55,7 @@ def add_arguments(self, parser): "(ie. 
include task with taskId == taskGroupId)", ) - def handle(self, *args, **options): + def handle(self, *args: Any, **options: Any) -> None: queue_svc = taskcluster.Queue({"rootUrl": settings.TC_ROOT_URL}) task_group_id = options["task_group"] @@ -67,7 +72,7 @@ def handle(self, *args, **options): task["status"]["taskId"], task["status"]["workerType"], ) - return + return None for run in task["status"]["runs"]: _, created = Task.objects.update_or_create( diff --git a/server/taskmanager/migrations/0001_initial.py b/server/taskmanager/migrations/0001_initial.py index 69747d144..52acc0053 100644 --- a/server/taskmanager/migrations/0001_initial.py +++ b/server/taskmanager/migrations/0001_initial.py @@ -1,5 +1,7 @@ # Generated by Django 2.2.12 on 2020-04-03 20:06 +from __future__ import annotations + import django.db.models.deletion from django.db import migrations, models @@ -8,7 +10,7 @@ class Migration(migrations.Migration): initial = True - dependencies = [] + dependencies: list[tuple[str, str]] = [] operations = [ migrations.CreateModel( diff --git a/server/taskmanager/models.py b/server/taskmanager/models.py index 332a6952c..e9d747810 100644 --- a/server/taskmanager/models.py +++ b/server/taskmanager/models.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from datetime import datetime, timedelta + from django.db import models diff --git a/server/taskmanager/serializers.py b/server/taskmanager/serializers.py index 5a04d23e7..216b0355f 100644 --- a/server/taskmanager/serializers.py +++ b/server/taskmanager/serializers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime from django.conf import settings @@ -15,12 +17,12 @@ RUN_RATIO_THRESHOLD = 0.03 -class PoolSerializer(serializers.ModelSerializer): +class PoolSerializer(serializers.ModelSerializer[Pool]): class Meta: model = Pool fields = "__all__" - def to_representation(self, instance): + def to_representation(self, instance: Pool): """Add dynamic fields""" ret = super().to_representation(instance) ret["cycle_time"] = None @@ -79,18 +81,18 @@ class Meta(PoolSerializer.Meta): "view_url", ) - def get_hook_url(self, pool): + def get_hook_url(self, pool: Pool) -> str: if pool.pool_id in settings.TC_EXTRA_POOLS: hook = pool.pool_id else: hook = f"{pool.platform}-{pool.pool_id}" return f"{settings.TC_ROOT_URL}hooks/project-{settings.TC_PROJECT}/{hook}" - def get_view_url(self, pool): + def get_view_url(self, pool: Pool) -> str: return reverse("taskmanager:pool-view-ui", kwargs={"pk": pool.id}) -class TaskSerializer(serializers.ModelSerializer): +class TaskSerializer(serializers.ModelSerializer[Task]): status_data = serializers.CharField(trim_whitespace=False) class Meta: @@ -116,5 +118,5 @@ class TaskVueSerializer(TaskSerializer): class Meta(TaskSerializer.Meta): read_only_fields = TaskSerializer.Meta.read_only_fields + ("task_url",) - def get_task_url(self, task): + def get_task_url(self, task: Task) -> str: return f"{settings.TC_ROOT_URL}tasks/{task.task_id}/runs/{task.run_id}" diff --git a/server/taskmanager/tasks.py b/server/taskmanager/tasks.py index 76ca2a8f4..e1219e829 100644 --- a/server/taskmanager/tasks.py +++ b/server/taskmanager/tasks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import timedelta from logging import getLogger from pathlib import Path @@ -14,7 +16,7 @@ LOG = getLogger("taskmanager.tasks") -def get_or_create_pool(worker_type): +def get_or_create_pool(worker_type: str): from .models import Pool params = {} @@ -50,7 +52,7 @@ def get_or_create_pool(worker_type): 
@app.task(ignore_result=True) -def update_pool_defns(): +def update_pool_defns() -> None: from fuzzing_decision.common.pool import PoolConfigLoader from .models import Pool, Task @@ -115,7 +117,7 @@ def update_pool_defns(): @app.task(ignore_result=True) -def task_failed(task_pk): +def task_failed(task_pk) -> None: from django.contrib.auth.models import User as DjangoUser from django.contrib.contenttypes.models import ContentType from notifications.models import Notification @@ -163,7 +165,7 @@ def task_failed(task_pk): @app.task(ignore_result=True) -def update_task(pulse_data): +def update_task(pulse_data) -> None: import taskcluster from .models import Task diff --git a/server/taskmanager/tests/__init__.py b/server/taskmanager/tests/__init__.py index fd5a3f67e..49e3b0a8f 100644 --- a/server/taskmanager/tests/__init__.py +++ b/server/taskmanager/tests/__init__.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import datetime import logging +from typing import cast from django.utils import timezone @@ -8,32 +11,40 @@ LOG = logging.getLogger("fm.taskmanager.tests") -def create_pool(): - pool = Pool.objects.create( - pool_id="pool1", - pool_name="Test Pool", - platform="linux", - size=1, - cpu="x64", - cycle_time=datetime.timedelta(days=31), +def create_pool() -> Pool: + pool = cast( + Pool, + Pool.objects.create( + pool_id="pool1", + pool_name="Test Pool", + platform="linux", + size=1, + cpu="x64", + cycle_time=datetime.timedelta(days=31), + ), ) LOG.debug("Create Pool pk=%d", pool.pk) return pool -def create_task(pool=None, task_id="TASK123", run_id=1): +def create_task( + pool: Pool | None = None, task_id: str = "TASK123", run_id: int = 1 +) -> Task: task_time = timezone.now() - task = Task.objects.create( - pool=pool, - task_id=task_id, - decision_id="DECISION123", - run_id=run_id, - state="running", - created=task_time, - status_data="Status text", - started=task_time + datetime.timedelta(minutes=5), - resolved=task_time + datetime.timedelta(minutes=10), - expires=task_time + datetime.timedelta(minutes=15), + task = cast( + Task, + Task.objects.create( + pool=pool, + task_id=task_id, + decision_id="DECISION123", + run_id=run_id, + state="running", + created=task_time, + status_data="Status text", + started=task_time + datetime.timedelta(minutes=5), + resolved=task_time + datetime.timedelta(minutes=10), + expires=task_time + datetime.timedelta(minutes=15), + ), ) LOG.debug("Create Task pk=%d", task.pk) return task diff --git a/server/taskmanager/tests/conftest.py b/server/taskmanager/tests/conftest.py index 9ce3f5ed5..1649a810c 100644 --- a/server/taskmanager/tests/conftest.py +++ b/server/taskmanager/tests/conftest.py @@ -9,6 +9,11 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + +from typing import cast + import pytest from django.contrib.auth.models import Permission, User from django.contrib.contenttypes.models import ContentType @@ -17,13 +22,13 @@ def _create_user( - username, - email="test@mozilla.com", - password="test", - has_permission=True, - subscribed=True, -): - user = User.objects.create_user(username, email, password) + username: str, + email: str = "test@mozilla.com", + password: str = "test", + has_permission: bool = True, + subscribed: bool = True, +) -> User: + user = cast(User, User.objects.create_user(username, email, password)) user.user_permissions.clear() if has_permission: content_type = ContentType.objects.get_for_model(cmUser) @@ -39,7 +44,7 @@ def _create_user( @pytest.fixture -def taskmanager_test(db): # pylint: disable=invalid-name,unused-argument +def taskmanager_test(db: None) -> None: # pylint: disable=invalid-name,unused-argument """Common testcase class for all taskmanager unittests""" # Create one unrestricted and one restricted test user _create_user("test", subscribed=False) diff --git a/server/taskmanager/tests/test_pools_rest.py b/server/taskmanager/tests/test_pools_rest.py index 822c48397..ececa8b86 100644 --- a/server/taskmanager/tests/test_pools_rest.py +++ b/server/taskmanager/tests/test_pools_rest.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import datetime import itertools import json @@ -16,6 +19,7 @@ import pytest import requests from django.contrib.auth.models import User +from rest_framework.test import APIClient from . import create_pool, create_task @@ -23,7 +27,7 @@ pytestmark = pytest.mark.usefixtures("taskmanager_test") # pylint: disable=invalid-name -def test_rest_pools_no_auth(api_client): +def test_rest_pools_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/taskmanager/rest/pools/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -33,7 +37,7 @@ def test_rest_pools_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_pools_no_perm(api_client): +def test_rest_pools_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -49,7 +53,7 @@ def test_rest_pools_no_perm(api_client): ("method", "item"), itertools.product(["post", "patch", "put", "delete"], [True, False]), ) -def test_rest_pool_methods(api_client, method, item): +def test_rest_pool_methods(api_client: APIClient, method: str, item: bool) -> None: """post/put/patch/delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -59,14 +63,14 @@ def test_rest_pool_methods(api_client, method, item): else: url = "/taskmanager/rest/pools/" - method = getattr(api_client, method) - resp = method(url) + method_ = getattr(api_client, method) + resp = method_(url) LOG.debug(resp) assert resp.status_code == requests.codes["method_not_allowed"] @pytest.mark.parametrize("item", [True, False]) -def test_rest_pool_read(api_client, item): +def test_rest_pool_read(api_client: APIClient, item: bool) -> None: user = User.objects.get(username="test") api_client.force_authenticate(user=user) pool = create_pool() @@ -109,7 +113,7 @@ def test_rest_pool_read(api_client, item): assert 
datetime.timedelta(seconds=int(resp["cycle_time"])) == pool.cycle_time -def test_rest_pool_running_status(api_client): +def test_rest_pool_running_status(api_client: APIClient) -> None: user = User.objects.get(username="test") api_client.force_authenticate(user=user) pool = create_pool() diff --git a/server/taskmanager/tests/test_taskmanager.py b/server/taskmanager/tests/test_taskmanager.py index b66d300e9..f2c83b890 100644 --- a/server/taskmanager/tests/test_taskmanager.py +++ b/server/taskmanager/tests/test_taskmanager.py @@ -8,11 +8,17 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import logging +import typing import pytest import requests +from django.test.client import Client from django.urls import reverse +from pytest_django.fixtures import SettingsWrapper from . import create_pool @@ -21,15 +27,18 @@ @pytest.mark.parametrize("name", ["taskmanager:index", "taskmanager:pool-list-ui"]) -def test_views_no_login(name, client): +def test_views_no_login(name: str, client: Client) -> None: """Request without login hits the login redirect""" path = reverse(name) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_index_simple_get(client): +def test_index_simple_get(client: Client) -> None: """Index redirects""" client.login(username="test", password="test") response = client.get(reverse("taskmanager:index")) @@ -38,7 +47,7 @@ def test_index_simple_get(client): assert response["Location"] == reverse("taskmanager:pool-list-ui") -def test_view_simple_get(client): +def test_view_simple_get(client: Client) -> None: """No errors are thrown in template""" client.login(username="test", password="test") response = client.get(reverse("taskmanager:pool-list-ui")) @@ -46,15 +55,18 @@ def test_view_simple_get(client): assert response.status_code == requests.codes["ok"] -def test_detail_view_no_login(client): +def test_detail_view_no_login(client: Client) -> None: pool = create_pool() path = reverse("taskmanager:pool-view-ui", args=(pool.pk,)) response = client.get(path, follow=False) assert response.status_code == requests.codes["found"] - assert response.url == "/login/?next=" + path + assert ( + typing.cast(typing.Union[str, None], getattr(response, "url", None)) + == "/login/?next=" + path + ) -def test_detail_view_simple_get(client, settings): +def test_detail_view_simple_get(client: Client, settings: SettingsWrapper) -> None: settings.TC_EXTRA_POOLS = ["extra"] pool = create_pool() path = reverse("taskmanager:pool-view-ui", args=(pool.pk,)) diff --git a/server/taskmanager/tests/test_tasks_rest.py b/server/taskmanager/tests/test_tasks_rest.py index 4332c5db1..75ce9a757 100644 --- a/server/taskmanager/tests/test_tasks_rest.py +++ b/server/taskmanager/tests/test_tasks_rest.py @@ -8,14 +8,19 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
""" + +from __future__ import annotations + import itertools import json import logging +from collections.abc import Callable import pytest import requests from django.contrib.auth.models import User from django.utils import dateparse +from rest_framework.test import APIClient from taskmanager.models import Task @@ -25,7 +30,7 @@ pytestmark = pytest.mark.usefixtures("taskmanager_test") # pylint: disable=invalid-name -def test_rest_tasks_no_auth(api_client): +def test_rest_tasks_no_auth(api_client: APIClient) -> None: """must yield forbidden without authentication""" url = "/taskmanager/rest/tasks/" assert api_client.get(url).status_code == requests.codes["unauthorized"] @@ -35,7 +40,7 @@ def test_rest_tasks_no_auth(api_client): assert api_client.delete(url).status_code == requests.codes["unauthorized"] -def test_rest_tasks_no_perm(api_client): +def test_rest_tasks_no_perm(api_client: APIClient) -> None: """must yield forbidden without permission""" user = User.objects.get(username="test-noperm") api_client.force_authenticate(user=user) @@ -51,7 +56,7 @@ def test_rest_tasks_no_perm(api_client): ("method", "item"), itertools.product(["post", "put", "patch", "delete"], [True, False]), ) -def test_rest_task_methods(api_client, method, item): +def test_rest_task_methods(api_client: APIClient, method: str, item: bool) -> None: """post/put/patch/delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -62,14 +67,14 @@ def test_rest_task_methods(api_client, method, item): else: url = "/taskmanager/rest/tasks/" - method = getattr(api_client, method) - resp = method(url) + method_ = getattr(api_client, method) + resp = method_(url) LOG.debug(resp) assert resp.status_code == requests.codes["method_not_allowed"] @pytest.mark.parametrize("method", ["get", "put", "patch", "delete"]) -def test_rest_task_status_methods(api_client, method): +def test_rest_task_status_methods(api_client: APIClient, method: str) -> None: """post/put/patch/delete should not be allowed""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -77,8 +82,8 @@ def test_rest_task_status_methods(api_client, method): create_task(pool=pool) url = "/taskmanager/rest/tasks/update_status/" - method = getattr(api_client, method) - resp = method(url) + method_ = getattr(api_client, method) + resp = method_(url) LOG.debug(resp) assert resp.status_code == requests.codes["method_not_allowed"] @@ -132,7 +137,12 @@ def test_rest_task_status_methods(api_client, method): ), ], ) -def test_rest_task_status(api_client, make_data, result, status_data): +def test_rest_task_status( + api_client: APIClient, + make_data: Callable[..., dict[str, str]], + result: int, + status_data: str, +) -> None: """post should require well-formed parameters""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -147,7 +157,7 @@ def test_rest_task_status(api_client, make_data, result, status_data): assert task.status_data == status_data -def test_rest_task_status_unknown(api_client): +def test_rest_task_status_unknown(api_client: APIClient) -> None: """post should require well-formed parameters""" user = User.objects.get(username="test") api_client.force_authenticate(user=user) @@ -166,7 +176,7 @@ def test_rest_task_status_unknown(api_client): @pytest.mark.parametrize("item", [True, False]) -def test_rest_task_read(api_client, item): +def test_rest_task_read(api_client: APIClient, item: bool) -> None: user = User.objects.get(username="test") 
api_client.force_authenticate(user=user) pool = create_pool() diff --git a/server/taskmanager/tests/test_update_pools.py b/server/taskmanager/tests/test_update_pools.py index 7c099a5fb..8ace39da5 100644 --- a/server/taskmanager/tests/test_update_pools.py +++ b/server/taskmanager/tests/test_update_pools.py @@ -8,6 +8,9 @@ License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. """ + +from __future__ import annotations + import datetime import logging import os.path @@ -15,6 +18,8 @@ import pytest from dateutil.parser import isoparse +from pytest_mock import MockerFixture + from notifications.models import Notification from crashmanager.models import User as cmUser @@ -302,7 +307,12 @@ @pytest.mark.parametrize("pulse_data, expected", TASK_EVENT_DATA.values()) -def test_update_task_0(mocker, settings, pulse_data, expected): +def test_update_task_0( + mocker: MockerFixture, + settings, + pulse_data: str, + expected: dict[str, dict[str, object]], +) -> None: """test that Task events update the DB""" settings.TC_EXTRA_POOLS = ["extra"] settings.TC_ROOT_URL = "https://allizom.org/tc" @@ -331,7 +341,7 @@ def test_update_task_0(mocker, settings, pulse_data, expected): assert getattr(task_obj, field) == value -def test_update_pool_defns_0(mocker, settings): +def test_update_pool_defns_0(mocker: MockerFixture, settings) -> None: """test that Pool definition is read from GH""" settings.TC_FUZZING_CFG_STORAGE = os.path.join( os.path.dirname(__file__), "fixtures", "pool1" diff --git a/server/taskmanager/urls.py b/server/taskmanager/urls.py index 0815345fb..4ee850433 100644 --- a/server/taskmanager/urls.py +++ b/server/taskmanager/urls.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.conf.urls import include from django.urls import re_path from rest_framework import routers diff --git a/server/taskmanager/views.py b/server/taskmanager/views.py index 5b1b2085a..b63fdb61d 100644 --- a/server/taskmanager/views.py +++ b/server/taskmanager/views.py @@ -1,13 +1,23 @@ +from __future__ import annotations + import datetime import logging import re +from typing import Any +from django.http.request import HttpRequest +from django.http.response import ( + HttpResponse, + HttpResponsePermanentRedirect, + HttpResponseRedirect, +) from django.shortcuts import get_object_or_404, redirect, render from django.utils import timezone from rest_framework import mixins, status, viewsets from rest_framework.authentication import SessionAuthentication, TokenAuthentication from rest_framework.decorators import action from rest_framework.filters import OrderingFilter +from rest_framework.request import Request from rest_framework.response import Response from server.auth import CheckAppPermission @@ -35,17 +45,17 @@ @deny_restricted_users -def index(request): +def index(request: HttpRequest) -> HttpResponsePermanentRedirect | HttpResponseRedirect: return redirect("taskmanager:pool-list-ui") @deny_restricted_users -def list_pools(request): +def list_pools(request: HttpRequest) -> HttpResponse: return render(request, "pool/index.html", {}) @deny_restricted_users -def view_pool(request, pk): +def view_pool(request: HttpRequest, pk: int) -> HttpResponse: pool = get_object_or_404(Pool, pk=pk) return render( request, @@ -72,7 +82,7 @@ class PoolViewSet(viewsets.ReadOnlyModelViewSet): OrderingFilter, ] - def get_serializer(self, *args, **kwds): + def get_serializer(self, *args: Any, **kwds: Any): vue = self.request.query_params.get("vue", 
"false").lower() not in ( "false", "0", @@ -100,7 +110,7 @@ class TaskViewSet( OrderingFilter, ] - def get_serializer(self, *args, **kwds): + def get_serializer(self, *args: Any, **kwds: Any): vue = self.request.query_params.get("vue", "false").lower() not in ( "false", "0", @@ -112,7 +122,7 @@ def get_serializer(self, *args, **kwds): @action( detail=False, methods=["post"], authentication_classes=(TokenAuthentication,) ) - def update_status(self, request): + def update_status(self, request: Request) -> Response: if set(request.data.keys()) != {"client", "status_data"}: LOG.debug("request.data.keys(): %s", request.data.keys()) errors = {} diff --git a/setup.cfg b/setup.cfg index 52f8de684..169886edb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,7 @@ keywords = fuzz fuzzing security test testing install_requires = fasteners>=0.14.1 requests>=2.20.1 + typing_extensions>=4.0.0 packages = Collector CovReporter @@ -36,7 +37,7 @@ python_requires = >=3.7 [options.extras_require] dev = pre-commit - tox + tox~=3.28.0 server = boto3 celery~=4.4.0 diff --git a/setup.py b/setup.py index da804ebee..edd3c8492 100755 --- a/setup.py +++ b/setup.py @@ -5,6 +5,8 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. """setuptools install script""" +from __future__ import annotations + import site from setuptools import setup diff --git a/tox.ini b/tox.ini index 6c2653d75..1ec15e731 100644 --- a/tox.ini +++ b/tox.ini @@ -45,6 +45,40 @@ deps = commands = codecov +[testenv:mypy] +whitelist_externals = sed +commands = + sed -i -e 's/ "ec2spotmanager",/ # "ec2spotmanager",/' \ + -e 's/ "crashmanager",/ # "crashmanager",/' \ + -e 's/ "taskmanager",/ # "taskmanager",/' \ + -e 's/ "covmanager",/ # "covmanager",/' server/server/settings.py + # Ensure server/settings.py has its changes above reverted in the event of failure + python -c 'import argparse,subprocess,sys ; \ + parser = argparse.ArgumentParser() ; \ + parser.add_argument("-p", "--posargs"); \ + p_mypy = subprocess.run(["mypy", \ + "--explicit-package-bases", "--install-types", "--non-interactive", \ + parser.parse_args().posargs, \ + ]); \ + subprocess.run(["sed", "-i", \ + ("s/ # \\\"ec2spotmanager\\\",/ \\\"ec2spotmanager\\\",/;" \ + "s/ # \\\"crashmanager\\\",/ \\\"crashmanager\\\",/;" \ + "s/ # \\\"taskmanager\\\",/ \\\"taskmanager\\\",/;" \ + "s/ # \\\"covmanager\\\",/ \\\"covmanager\\\",/;"), \ + "server/server/settings.py", \ + ]); \ + sys.exit(p_mypy.returncode);' -p {posargs} +setenv = PYTHONPATH = {toxinidir}/server +deps = + boto3-stubs>=1.26.64 + celery-types>=0.14.0 + django-stubs>=1.14.0 + djangorestframework-stubs>=1.8.0 + ffpuppet~=0.9.2 + mypy==1.0.0 + typing-extensions>=4.4.0 +usedevelop = true + [testenv:pypi] skip_install = true deps =