From fa51a75f11d4e2b91f7426cefcfdd8787e9f1470 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Suliga?= Date: Sun, 1 Dec 2024 20:07:43 +0100 Subject: [PATCH 1/2] Add deadlock prevention MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Detects and prevents deadlocks during the library misuse, eg. by injecting code into the critical sections that itself might want to obtain the relevant lock. A follow up to #1076. Signed-off-by: Przemysław Suliga --- prometheus_client/errors.py | 3 +++ prometheus_client/metrics.py | 9 ++++---- prometheus_client/registry.py | 4 ++-- prometheus_client/utils.py | 40 +++++++++++++++++++++++++++++++++++ prometheus_client/values.py | 9 ++++---- tests/test_core.py | 29 ++++++++++++++++++++++++- 6 files changed, 81 insertions(+), 13 deletions(-) create mode 100644 prometheus_client/errors.py diff --git a/prometheus_client/errors.py b/prometheus_client/errors.py new file mode 100644 index 00000000..d9906584 --- /dev/null +++ b/prometheus_client/errors.py @@ -0,0 +1,3 @@ + +class PrometheusClientRuntimeError(RuntimeError): + pass diff --git a/prometheus_client/metrics.py b/prometheus_client/metrics.py index 9b251274..e1d3e1fd 100644 --- a/prometheus_client/metrics.py +++ b/prometheus_client/metrics.py @@ -1,5 +1,4 @@ import os -from threading import Lock import time import types from typing import ( @@ -13,7 +12,7 @@ from .metrics_core import Metric from .registry import Collector, CollectorRegistry, REGISTRY from .samples import Exemplar, Sample -from .utils import floatToGoString, INF +from .utils import floatToGoString, INF, WarnLock from .validation import ( _validate_exemplar, _validate_labelnames, _validate_metric_name, ) @@ -120,7 +119,7 @@ def __init__(self: T, if self._is_parent(): # Prepare the fields needed for child metrics. - self._lock = Lock() + self._lock = WarnLock() self._metrics: Dict[Sequence[str], T] = {} if self._is_observable(): @@ -673,7 +672,7 @@ class Info(MetricWrapperBase): def _metric_init(self): self._labelname_set = set(self._labelnames) - self._lock = Lock() + self._lock = WarnLock() self._value = {} def info(self, val: Dict[str, str]) -> None: @@ -735,7 +734,7 @@ def __init__(self, def _metric_init(self) -> None: self._value = 0 - self._lock = Lock() + self._lock = WarnLock() def state(self, state: str) -> None: """Set enum metric state.""" diff --git a/prometheus_client/registry.py b/prometheus_client/registry.py index 694e4bd8..7163cb67 100644 --- a/prometheus_client/registry.py +++ b/prometheus_client/registry.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod import copy -from threading import Lock from typing import Dict, Iterable, List, Optional from .metrics_core import Metric +from .utils import WarnLock # Ideally this would be a Protocol, but Protocols are only available in Python >= 3.8. @@ -30,7 +30,7 @@ def __init__(self, auto_describe: bool = False, target_info: Optional[Dict[str, self._collector_to_names: Dict[Collector, List[str]] = {} self._names_to_collectors: Dict[str, Collector] = {} self._auto_describe = auto_describe - self._lock = Lock() + self._lock = WarnLock() self._target_info: Optional[Dict[str, str]] = {} self.set_target_info(target_info) diff --git a/prometheus_client/utils.py b/prometheus_client/utils.py index 0d2b0948..d913415b 100644 --- a/prometheus_client/utils.py +++ b/prometheus_client/utils.py @@ -1,4 +1,7 @@ import math +from threading import Lock, RLock + +from .errors import PrometheusClientRuntimeError INF = float("inf") MINUS_INF = float("-inf") @@ -22,3 +25,40 @@ def floatToGoString(d): mantissa = f'{s[0]}.{s[1:dot]}{s[dot + 1:]}'.rstrip('0.') return f'{mantissa}e+0{dot - 1}' return s + + +class WarnLock: + """A wrapper around RLock and Lock that prevents deadlocks. + + Raises a RuntimeError when it detects attempts to re-enter the critical + section from a single thread. Intended to be used as a context manager. + """ + error_msg = ( + 'Attempt to enter a non reentrant context from a single thread.' + ' It is possible that the client code is trying to register or update' + ' metrics from within metric registration code or from a signal handler' + ' while metrics are being registered or updated.' + ' This is unsafe and cannot be allowed. It would result in a deadlock' + ' if this exception was not raised.' + ) + + def __init__(self): + self._rlock = RLock() + self._lock = Lock() + + def __enter__(self): + self._rlock.acquire() + if not self._lock.acquire(blocking=False): + self._rlock.release() + raise PrometheusClientRuntimeError(self.error_msg) + + def __exit__(self, exc_type, exc_value, traceback): + self._lock.release() + self._rlock.release() + + def _locked(self): + # For use in tests. + if self._rlock.acquire(blocking=False): + self._rlock.release() + return False + return True diff --git a/prometheus_client/values.py b/prometheus_client/values.py index 6ff85e3b..20c9978a 100644 --- a/prometheus_client/values.py +++ b/prometheus_client/values.py @@ -1,8 +1,8 @@ import os -from threading import Lock import warnings from .mmap_dict import mmap_key, MmapedDict +from .utils import WarnLock class MutexValue: @@ -13,7 +13,7 @@ class MutexValue: def __init__(self, typ, metric_name, name, labelnames, labelvalues, help_text, **kwargs): self._value = 0.0 self._exemplar = None - self._lock = Lock() + self._lock = WarnLock() def inc(self, amount): with self._lock: @@ -47,10 +47,9 @@ def MultiProcessValue(process_identifier=os.getpid): files = {} values = [] pid = {'value': process_identifier()} - # Use a single global lock when in multi-processing mode - # as we presume this means there is no threading going on. + # Use a single global lock when in multi-processing mode. # This avoids the need to also have mutexes in __MmapDict. - lock = Lock() + lock = WarnLock() class MmapedValue: """A float protected by a mutex backed by a per-process mmaped file.""" diff --git a/tests/test_core.py b/tests/test_core.py index f28c9abc..0ba4c783 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -13,6 +13,7 @@ StateSetMetricFamily, Summary, SummaryMetricFamily, UntypedMetricFamily, ) from prometheus_client.decorator import getargspec +from prometheus_client.errors import PrometheusClientRuntimeError from prometheus_client.metrics import _get_use_created from prometheus_client.validation import ( disable_legacy_validation, enable_legacy_validation, @@ -134,6 +135,19 @@ def test_exemplar_too_long(self): 'y123456': '7+15 characters', }) + def test_single_thread_deadlock_detection(self): + counter = self.counter + + class Tracked(float): + def __radd__(self, other): + counter.inc(10) + return self + other + + expected_msg = 'Attempt to enter a non reentrant context from a single thread.' + self.assertRaisesRegex( + PrometheusClientRuntimeError, expected_msg, counter.inc, Tracked(100) + ) + class TestDisableCreated(unittest.TestCase): def setUp(self): @@ -1004,7 +1018,20 @@ def test_restricted_registry_does_not_yield_while_locked(self): m = Metric('target', 'Target metadata', 'info') m.samples = [Sample('target_info', {'foo': 'bar'}, 1)] for _ in registry.restricted_registry(['target_info', 's_sum']).collect(): - self.assertFalse(registry._lock.locked()) + self.assertFalse(registry._lock._locked()) + + def test_registry_deadlock_detection(self): + registry = CollectorRegistry(auto_describe=True) + + class RecursiveCollector: + def collect(self): + Counter('x', 'help', registry=registry) + return [CounterMetricFamily('c_total', 'help', value=1)] + + expected_msg = 'Attempt to enter a non reentrant context from a single thread.' + self.assertRaisesRegex( + PrometheusClientRuntimeError, expected_msg, registry.register, RecursiveCollector() + ) if __name__ == '__main__': From c0d5593faaee7d13511203bcf31f0dd8644759bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Suliga?= Date: Sun, 8 Dec 2024 16:35:17 +0100 Subject: [PATCH 2/2] Drop the use of double locking in values and metrics modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It's hard to justify the overhead of double locking there. Signed-off-by: Przemysław Suliga --- prometheus_client/metrics.py | 9 +++++---- prometheus_client/values.py | 6 +++--- tests/test_core.py | 13 ------------- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/prometheus_client/metrics.py b/prometheus_client/metrics.py index e1d3e1fd..9b251274 100644 --- a/prometheus_client/metrics.py +++ b/prometheus_client/metrics.py @@ -1,4 +1,5 @@ import os +from threading import Lock import time import types from typing import ( @@ -12,7 +13,7 @@ from .metrics_core import Metric from .registry import Collector, CollectorRegistry, REGISTRY from .samples import Exemplar, Sample -from .utils import floatToGoString, INF, WarnLock +from .utils import floatToGoString, INF from .validation import ( _validate_exemplar, _validate_labelnames, _validate_metric_name, ) @@ -119,7 +120,7 @@ def __init__(self: T, if self._is_parent(): # Prepare the fields needed for child metrics. - self._lock = WarnLock() + self._lock = Lock() self._metrics: Dict[Sequence[str], T] = {} if self._is_observable(): @@ -672,7 +673,7 @@ class Info(MetricWrapperBase): def _metric_init(self): self._labelname_set = set(self._labelnames) - self._lock = WarnLock() + self._lock = Lock() self._value = {} def info(self, val: Dict[str, str]) -> None: @@ -734,7 +735,7 @@ def __init__(self, def _metric_init(self) -> None: self._value = 0 - self._lock = WarnLock() + self._lock = Lock() def state(self, state: str) -> None: """Set enum metric state.""" diff --git a/prometheus_client/values.py b/prometheus_client/values.py index 20c9978a..ed4e75cd 100644 --- a/prometheus_client/values.py +++ b/prometheus_client/values.py @@ -1,8 +1,8 @@ import os +from threading import Lock import warnings from .mmap_dict import mmap_key, MmapedDict -from .utils import WarnLock class MutexValue: @@ -13,7 +13,7 @@ class MutexValue: def __init__(self, typ, metric_name, name, labelnames, labelvalues, help_text, **kwargs): self._value = 0.0 self._exemplar = None - self._lock = WarnLock() + self._lock = Lock() def inc(self, amount): with self._lock: @@ -49,7 +49,7 @@ def MultiProcessValue(process_identifier=os.getpid): pid = {'value': process_identifier()} # Use a single global lock when in multi-processing mode. # This avoids the need to also have mutexes in __MmapDict. - lock = WarnLock() + lock = Lock() class MmapedValue: """A float protected by a mutex backed by a per-process mmaped file.""" diff --git a/tests/test_core.py b/tests/test_core.py index 0ba4c783..5b12e1c6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -135,19 +135,6 @@ def test_exemplar_too_long(self): 'y123456': '7+15 characters', }) - def test_single_thread_deadlock_detection(self): - counter = self.counter - - class Tracked(float): - def __radd__(self, other): - counter.inc(10) - return self + other - - expected_msg = 'Attempt to enter a non reentrant context from a single thread.' - self.assertRaisesRegex( - PrometheusClientRuntimeError, expected_msg, counter.inc, Tracked(100) - ) - class TestDisableCreated(unittest.TestCase): def setUp(self):