Skip to content

Commit

Permalink
de-couple from TrackedRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
quinnmil committed Jan 14, 2025
1 parent a99de8b commit eeefff4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 87 deletions.
27 changes: 10 additions & 17 deletions src/scout_apm/core/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import random
from typing import Dict, Optional, Tuple

from scout_apm.core.tracked_request import TrackedRequest

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -110,20 +108,12 @@ def _get_operation_type_and_name(
else:
return None, None

def get_effective_sample_rate(self, request: TrackedRequest) -> int:
def get_effective_sample_rate(
self, operation: str, is_ignored: bool = False
) -> int:
"""
Determines the effective sample rate for a given operation.
Priority order (highest to lowest):
1. Exact matches in sample_endpoints/sample_jobs
2. Exact matches in ignore lists (returns 0)
3. Prefix matches in sample_endpoints/sample_jobs
4. Legacy ignore patterns (returns 0)
5. Request-level ignore (returns 0)
6. Operation-specific default rate
7. Global sample rate
"""
operation = request.operation
op_type, name = self._get_operation_type_and_name(operation)

if not op_type or not name:
Expand Down Expand Up @@ -155,8 +145,9 @@ def get_effective_sample_rate(self, request: TrackedRequest) -> int:
if self._is_legacy_ignored(name):
return 0

# Check if request is explicitly ignored via tag
if request.is_ignored():
# Check if request is explicitly ignored via the
# is_ignored() tracked_request method.
if is_ignored:
return 0

# Use operation-specific default rate if available
Expand All @@ -166,7 +157,7 @@ def get_effective_sample_rate(self, request: TrackedRequest) -> int:
# Fall back to global sample rate
return self.sample_rate

def should_sample(self, request: TrackedRequest) -> bool:
def should_sample(self, operation: str, is_ignored: bool) -> bool:
"""
Determines if an operation should be sampled.
If no sampling is enabled, always return True.
Expand All @@ -180,4 +171,6 @@ def should_sample(self, request: TrackedRequest) -> bool:
"""
if not self._any_sampling():
return True
return random.randint(1, 100) <= self.get_effective_sample_rate(request)
return random.randint(1, 100) <= self.get_effective_sample_rate(
operation, is_ignored
)
114 changes: 44 additions & 70 deletions tests/unit/core/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from scout_apm.core.config import ScoutConfig
from scout_apm.core.sampler import Sampler
from scout_apm.core.tracked_request import TrackedRequest


@pytest.fixture
Expand Down Expand Up @@ -37,63 +36,49 @@ def sampler(config):
return Sampler(config)


@pytest.fixture
def tracked_request():
return TrackedRequest()


def test_should_sample_endpoint_always(sampler, tracked_request):
tracked_request.operation = "Controller/users"
assert sampler.should_sample(tracked_request) is True
def test_should_sample_endpoint_always(sampler):
assert sampler.should_sample("Controller/users", False) is True


def test_should_sample_endpoint_never(sampler, tracked_request):
tracked_request.operation = "Controller/health/check"
assert sampler.should_sample(tracked_request) is False
tracked_request.operation = "Controller/users/test"
assert sampler.should_sample(tracked_request) is False
def test_should_sample_endpoint_never(sampler):
assert sampler.should_sample("Controller/health/check", False) is False
assert sampler.should_sample("Controller/users/test", False) is False


def test_should_sample_endpoint_ignored(sampler, tracked_request):
tracked_request.operation = "Controller/metrics"
assert sampler.should_sample(tracked_request) is False
def test_should_sample_endpoint_ignored(sampler):
assert sampler.should_sample("Controller/metrics", False) is False


def test_should_sample_endpoint_partial(sampler, tracked_request):
tracked_request.operation = "Controller/test/endpoint"
def test_should_sample_endpoint_partial(sampler):
with mock.patch("random.randint", return_value=10):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Controller/test/endpoint", False) is True
with mock.patch("random.randint", return_value=30):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Controller/test/endpoint", False) is False


def test_should_sample_job_always(sampler, tracked_request):
tracked_request.operation = "Job/critical-job"
assert sampler.should_sample(tracked_request) is True
def test_should_sample_job_always(sampler):
assert sampler.should_sample("Job/critical-job", False) is True


def test_should_sample_job_never(sampler, tracked_request):
tracked_request.operation = "Job/test-job"
assert sampler.should_sample(tracked_request) is False
def test_should_sample_job_never(sampler):
assert sampler.should_sample("Job/test-job", False) is False


def test_should_sample_job_partial(sampler, tracked_request):
tracked_request.operation = "Job/batch-process"
def test_should_sample_job_partial(sampler):
with mock.patch("random.randint", return_value=10):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Job/batch-process", False) is True
with mock.patch("random.randint", return_value=40):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Job/batch-process", False) is False


def test_should_sample_unknown_operation(sampler, tracked_request):
tracked_request.operation = "Unknown/operation"
def test_should_sample_unknown_operation(sampler):
with mock.patch("random.randint", return_value=10):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Unknown/operation", False) is True
with mock.patch("random.randint", return_value=60):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Unknown/operation", False) is False


def test_should_sample_no_sampling_enabled(config, tracked_request):
def test_should_sample_no_sampling_enabled(config):
config.set(
sample_rate=100, # Return config to defaults
sample_endpoints={},
Expand All @@ -104,49 +89,43 @@ def test_should_sample_no_sampling_enabled(config, tracked_request):
job_sample_rate=None,
)
sampler = Sampler(config)
tracked_request.operation = "Controller/any_endpoint"
assert sampler.should_sample(tracked_request) is True
tracked_request.operation = "Job/any_job"
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Controller/any_endpoint", False) is True
assert sampler.should_sample("Job/any_job", False) is True


def test_should_sample_endpoint_default_rate(sampler, tracked_request):
tracked_request.operation = "Controller/unspecified"
def test_should_sample_endpoint_default_rate(sampler):
with mock.patch("random.randint", return_value=60):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Controller/unspecified", False) is True
with mock.patch("random.randint", return_value=80):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Controller/unspecified", False) is False


def test_should_sample_job_default_rate(sampler, tracked_request):
tracked_request.operation = "Job/unspecified-job"
def test_should_sample_job_default_rate(sampler):
with mock.patch("random.randint", return_value=30):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Job/unspecified-job", False) is True
with mock.patch("random.randint", return_value=50):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Job/unspecified-job", False) is False


def test_should_sample_endpoint_fallback_to_global_rate(config, tracked_request):
def test_should_sample_endpoint_fallback_to_global_rate(config):
config.set(endpoint_sample_rate=None)
sampler = Sampler(config)
tracked_request.operation = "Controller/unspecified"
with mock.patch("random.randint", return_value=40):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Controller/unspecified", False) is True
with mock.patch("random.randint", return_value=60):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Controller/unspecified", False) is False


def test_should_sample_job_fallback_to_global_rate(config, tracked_request):
def test_should_sample_job_fallback_to_global_rate(config):
config.set(job_sample_rate=None)
sampler = Sampler(config)
tracked_request.operation = "Job/unspecified-job"
with mock.patch("random.randint", return_value=40):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Job/unspecified-job", False) is True
with mock.patch("random.randint", return_value=60):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Job/unspecified-job", False) is False


def test_should_handle_legacy_ignore_with_specific_sampling(config, tracked_request):
def test_should_handle_legacy_ignore_with_specific_sampling(config):
"""Test that specific sampling rates override legacy ignore patterns."""
config.set(
ignore=["foo"],
Expand All @@ -157,18 +136,16 @@ def test_should_handle_legacy_ignore_with_specific_sampling(config, tracked_requ
sampler = Sampler(config)

# foo/bar should be sampled at 50%
tracked_request.operation = "Controller/foo/bar"
with mock.patch("random.randint", return_value=40):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Controller/foo/bar", False) is True
with mock.patch("random.randint", return_value=60):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Controller/foo/bar", False) is False

# foo/other should be ignored (0% sampling)
tracked_request.operation = "Controller/foo/other"
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Controller/foo/other", False) is False


def test_prefix_matching_precedence(config, tracked_request):
def test_prefix_matching_precedence(config):
"""Test that longer prefix matches take precedence."""
config.set(
sample_endpoints={
Expand All @@ -180,16 +157,13 @@ def test_prefix_matching_precedence(config, tracked_request):
sampler = Sampler(config)

# Regular API endpoint should be ignored
tracked_request.operation = "Controller/api/status"
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Controller/api/status", False) is False

# Users API should be sampled at 50%
tracked_request.operation = "Controller/api/users/list"
with mock.patch("random.randint", return_value=40):
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Controller/api/users/list", False) is True
with mock.patch("random.randint", return_value=60):
assert sampler.should_sample(tracked_request) is False
assert sampler.should_sample("Controller/api/users/list", False) is False

# VIP users API should always be sampled
tracked_request.operation = "Controller/api/users/vip/list"
assert sampler.should_sample(tracked_request) is True
assert sampler.should_sample("Controller/api/users/vip/list", False) is True

0 comments on commit eeefff4

Please sign in to comment.