Skip to content

Commit

Permalink
chg: Use new annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafiot committed Jan 15, 2024
1 parent a45f917 commit 9a0340a
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 187 deletions.
16 changes: 4 additions & 12 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
[mypy]
check_untyped_defs = true
ignore_errors = false
ignore_missing_imports = false
strict_optional = true
no_implicit_optional = true
warn_unused_ignores = true
warn_redundant_casts = true
warn_unused_configs = true
warn_unreachable = true

show_error_context = true
pretty = true
strict = True
warn_return_any = False
show_error_context = True
pretty = True

[mypy-docs.source.*]
ignore_errors = True
217 changes: 106 additions & 111 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pypdns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from .api import PyPDNS, PDNSRecord, TypedPDNSRecord # noqa
from .errors import PDNSError, RateLimitError, UnauthorizedError, ForbiddenError, ServerError # noqa

__all__ = ['PyPDNS', 'PDNSRecord', 'TypedPDNSRecord', 'PDNSError', 'RateLimitError', 'UnauthorizedError', 'ForbiddenError', 'ServerError']

def main():

def main() -> None:
parser = argparse.ArgumentParser(description='Triggers a request againse CIRCL Passive DNS.')
parser.add_argument('--username', required=True, help='The username of you account.')
parser.add_argument('--password', required=True, help='The password of you account.')
Expand Down
118 changes: 60 additions & 58 deletions pypdns/api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import json
import logging

from datetime import datetime
from functools import cached_property
from importlib.metadata import version
from typing import Optional, Tuple, List, Dict, Union, Any, TypedDict, overload, Literal, Generator
from typing import Any, TypedDict, overload, Literal, Generator

import requests
from requests import Session, Response
Expand All @@ -17,7 +18,8 @@

try:
import requests_cache
from requests_cache import CachedSession, CachedResponse
from requests_cache import CachedSession
from requests_cache.models import CachedResponse # type: ignore[attr-defined]
HAS_CACHE = True
except ImportError:
HAS_CACHE = False
Expand All @@ -32,20 +34,20 @@ class TypedPDNSRecord(TypedDict, total=False):

rrname: str
rrtype: str
rdata: Union[str, List[str]]
rdata: str | list[str]
time_first: int
time_last: int
count: Optional[int]
bailiwick: Optional[str]
sensor_id: Optional[str]
zone_time_first: Optional[int]
zone_time_last: Optional[int]
origin: Optional[str]
time_first_ms: Optional[int]
time_last_ms: Optional[int]
time_first_rfc3339: Optional[str]
time_last_rfc3339: Optional[str]
meta: Optional[Dict[Any, Any]]
count: int | None
bailiwick: str | None
sensor_id: str | None
zone_time_first: int | None
zone_time_last: int | None
origin: str | None
time_first_ms: int | None
time_last_ms: int | None
time_first_rfc3339: str | None
time_last_rfc3339: str | None
meta: dict[Any, Any] | None


class PDNSRecord:
Expand All @@ -54,11 +56,11 @@ class PDNSRecord:
'''
__slots__ = ('_record', )

def __init__(self, record: Dict[str, Optional[Union[str, int, bool, List[str], Dict[Any, Any]]]]):
def __init__(self, record: dict[str, str | int | bool | list[str] | dict[Any, Any] | None]):
self._record = record

@property
def raw(self) -> Dict[str, Optional[Union[str, int, bool, List[str], Dict[Any, Any]]]]:
def raw(self) -> dict[str, str | int | bool | list[str] | dict[Any, Any] | None]:
'''The raw record'''
return self._record

Expand Down Expand Up @@ -159,7 +161,7 @@ def rrtype(self) -> str:
return self.record['rrtype']

@property
def rdata(self) -> Union[str, List[str]]:
def rdata(self) -> str | list[str]:
'''The resource records of the queried resource'''
return self.record['rdata']

Expand Down Expand Up @@ -195,96 +197,98 @@ def time_last_datetime(self) -> datetime:
return datetime.fromtimestamp(self.time_last)

@property
def count(self) -> Optional[int]:
def count(self) -> int | None:
'''How many authoritative DNS answers were received at the Passive DNS Server's
collectors with exactly the given set of values as answers
'''
return self.record.get('count')

@property
def bailiwick(self) -> Optional[str]:
def bailiwick(self) -> str | None:
'''The best estimate of the apex of the zone where this data is authoritative'''
return self.record.get('bailiwick')

@property
def sensor_id(self) -> Optional[str]:
def sensor_id(self) -> str | None:
'''The sensor information where the record was seen.'''
return self.record.get('sensor_id')

@property
def zone_time_first(self) -> Optional[int]:
def zone_time_first(self) -> int | None:
'''The first time that the unique tuple (rrname, rrtype, rdata) record
has been seen via master file import
'''
return self.record.get('zone_time_first')

@property
def zone_time_last(self) -> Optional[int]:
def zone_time_last(self) -> int | None:
'''The last time that the unique tuple (rrname, rrtype, rdata) record
has been seen via master file import
'''
return self.record.get('zone_time_last')

@property
def origin(self) -> Optional[str]:
def origin(self) -> str | None:
'''The resource origin of the Passive DNS response'''
return self.record.get('origin')

@property
def time_first_ms(self) -> Optional[int]:
def time_first_ms(self) -> int | None:
'''The first time that the record / unique tuple (rrname, rrtype, rdata)
has been seen by the passive DNS, in miliseconds since 1st of January 1970 (UTC).
'''
return self.record.get('time_first_ms')

@property
def time_last_ms(self) -> Optional[int]:
def time_last_ms(self) -> int | None:
'''The first time that the record / unique tuple (rrname, rrtype, rdata)
has been seen by the passive DNS, in miliseconds since 1st of January 1970 (UTC).
'''
return self.record.get('time_last_ms')

@property
def time_first_rfc3339(self) -> Optional[str]:
def time_first_rfc3339(self) -> str | None:
return self.record.get('time_first_rfc3339')

@property
def time_last_rfc3339(self) -> Optional[str]:
def time_last_rfc3339(self) -> str | None:
return self.record.get('time_last_rfc3339')

@property
def meta(self) -> Optional[Dict[Any, Any]]:
def meta(self) -> dict[Any, Any] | None:
return self.record.get('meta')


class PyPDNS(object):
class PyPDNS:

def __init__(self, url: str='https://www.circl.lu/pdns/query',
basic_auth: Optional[Tuple[str, str]]=None,
auth_token: Optional[str]=None,
basic_auth: tuple[str, str] | None=None,
auth_token: str | None=None,
enable_cache: bool=False, cache_expire_after: int=604800,
cache_file: str='/tmp/pdns.cache',
https_proxy_string: Optional[str]=None,
useragent: Optional[str]=None,
disable_active_query: bool=False):
https_proxy_string: str | None=None,
useragent: str | None=None,
disable_active_query: bool=False,
*, proxies: dict[str, str] | None=None):
'''Connector to Passive DNS
:param url: The URL of the service
:param basic_auth: HTTP basic auth to cnnect to the service: ("username", "password")
:param auth_token: HTTP basic auth but the token
:param enable_cache: Cache responses locally
:param cache_file: The file to cache the responses to
:param https_proxy_string: The HTTP proxy to connect to the service
:param https_proxy_string: The HTTP proxy to connect to the service (deprecated, use proxies instead)
:param useragent: User Agent to submit to the server
:param disable_active_query: THe passive DNS will attempt to resolve the request by default. Set to True if you don't want that.
:param proxies: The proxies to use to connect to Passive DNS - More details: https://requests.readthedocs.io/en/latest/user/advanced/#proxies
'''

self.url = url

if enable_cache and not HAS_CACHE:
raise PDNSError('Please install requests_cache if you want to use the caching capabilities.')

self.session: Union[CachedSession, Session]
self.session: CachedSession | Session
if enable_cache is True:
requests_cache.install_cache(cache_file, backend='sqlite', expire_after=cache_expire_after)
self.session = requests_cache.CachedSession()
Expand All @@ -305,12 +309,14 @@ def __init__(self, url: str='https://www.circl.lu/pdns/query',
self.session.headers.update({'dribble-disable-active-query': '1'})

if https_proxy_string is not None:
proxy = {'https': https_proxy_string}
self.session.proxies.update(proxy)
proxies = {'https': https_proxy_string}

if proxies:
self.session.proxies.update(proxies)

def iter_query(self, q: str,
filter_rrtype: Optional[str]=None,
break_on_errors: bool=False) -> Generator[PDNSRecord, None, Optional[Dict[str, Union[str, int]]]]:
filter_rrtype: str | None=None,
break_on_errors: bool=False) -> Generator[PDNSRecord, None, dict[str, str | int] | None]:
'''Iterate over all the recording matching your request, useful if there are a lot.
Note: the order is non-deterministic.
Expand All @@ -325,9 +331,7 @@ def iter_query(self, q: str,
while True:
if cursor > 0:
query_headers['dribble-paginate-cursor'] = str(cursor)
response: Union[Response, CachedResponse] = self.session.get(f'{self.url}/{q}',
timeout=15,
headers=query_headers)
response: Response | CachedResponse = self.session.get(f'{self.url}/{q}', timeout=15, headers=query_headers)
if response.status_code != 200:
self._handle_http_error(response)
if break_on_errors:
Expand All @@ -353,18 +357,16 @@ def iter_query(self, q: str,

def _query(self, q: str, sort_by: str = 'time_last',
*,
filter_rrtype: Optional[str]=None) -> Tuple[List[Dict[str, Optional[Union[str, int, bool, List[str], Dict[Any, Any]]]]],
Dict[str, Union[str, int]]]:
filter_rrtype: str | None=None) -> tuple[list[dict[str, str | int | bool | list[str] | dict[Any, Any] | None]],
dict[str, str | int]]:
'''Internal method running a non-paginated query, can be sorted.'''
logger.debug("start query() q=[%s]", q)
if sort_by not in sort_choice:
raise PDNSError(f'You can only sort by {", ".join(sort_choice)}')
query_headers = {}
if filter_rrtype:
query_headers['dribble-filter-rrtype'] = filter_rrtype
response: Union[Response, CachedResponse] = self.session.get(f'{self.url}/{q}',
timeout=15,
headers=query_headers)
response: Response | CachedResponse = self.session.get(f'{self.url}/{q}', timeout=15, headers=query_headers)
if response.status_code != 200:
self._handle_http_error(response)
errors = self._handle_dribble_errors(response)
Expand All @@ -387,23 +389,23 @@ def _query(self, q: str, sort_by: str = 'time_last',
def rfc_query(self, q: str, /,
*,
sort_by: str = 'time_last',
filter_rrtype: Optional[str]= None,
with_errors: Literal[True]) -> Tuple[List[PDNSRecord], Dict[str, Union[str, int]]]:
filter_rrtype: str | None= None,
with_errors: Literal[True]) -> tuple[list[PDNSRecord], dict[str, str | int]]:
pass

@overload
def rfc_query(self, q: str, /,
*,
sort_by: str = 'time_last',
filter_rrtype: Optional[str]= None,
with_errors: Literal[False]) -> List[PDNSRecord]:
filter_rrtype: str | None= None,
with_errors: Literal[False]) -> list[PDNSRecord]:
pass

def rfc_query(self, q: str, /,
*,
sort_by: str = 'time_last',
filter_rrtype: Optional[str]= None,
with_errors: bool=False) -> Union[List[PDNSRecord], Tuple[List[PDNSRecord], Dict[str, Union[str, int]]]]:
filter_rrtype: str | None= None,
with_errors: bool=False) -> list[PDNSRecord] | tuple[list[PDNSRecord], dict[str, str | int]]:
'''Triggers a non-paginated query, can be sorted but will raise an error if the response is too big.
:param q: The query
Expand All @@ -417,7 +419,7 @@ def rfc_query(self, q: str, /,
return to_return_records
return to_return_records, errors

def query(self, q: str, sort_by: str = 'time_last', timeout: Optional[int] = None) -> List[Dict]:
def query(self, q: str, sort_by: str = 'time_last', timeout: int | None = None) -> list[dict[str, Any]]:
'''This method (almost) returns the response from the server but turns the times into python datetime.
It was a bad design decision hears ago. Use rfc_query instead for something saner.
This method is deprecated.
Expand All @@ -428,13 +430,13 @@ def query(self, q: str, sort_by: str = 'time_last', timeout: Optional[int] = Non
record['time_last'] = datetime.fromtimestamp(record['time_last']) # type: ignore
return records

def _handle_dribble_errors(self, response: requests.Response) -> Dict[str, Union[str, int]]:
def _handle_dribble_errors(self, response: requests.Response) -> dict[str, str | int]:
if 'x-dribble-errors' in response.headers:
return json.loads(response.headers['x-dribble-errors'])
return {}

@staticmethod
def _handle_http_error(response: requests.Response):
def _handle_http_error(response: requests.Response) -> None:
if response.status_code == 401:
raise UnauthorizedError("Not authenticated: is authentication correct?")
if response.status_code == 403:
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ Sphinx = [
]

[tool.poetry.group.dev.dependencies]
mypy = "^1.7.0"
types-requests = "^2.31.0.10"
pytest = "^7.4.3"
mypy = "^1.8.0"
types-requests = "^2.31.0.20240106"
pytest = "^7.4.4"

[tool.poetry.extras]
docs = ["Sphinx"]
Expand Down
3 changes: 1 addition & 2 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest
from pypdns import PyPDNS, UnauthorizedError


class TestBasic(unittest.TestCase):

def test_not_auth(self):
def test_not_auth(self) -> None:
x = PyPDNS(basic_auth=('username', 'yourpassword'))
with self.assertRaises(UnauthorizedError):
x.query('www.microsoft.com')
Expand Down

0 comments on commit 9a0340a

Please sign in to comment.