From 89e278dcc327c6d65b325d413b1b7ae68d48abe6 Mon Sep 17 00:00:00 2001
From: mohammedtarek
Date: Sat, 2 May 2026 22:53:26 +0300
Subject: [PATCH 1/3] feat: add OTel metrics to OpenTelemetryMiddleware

Add counters and histograms to OpenTelemetryMiddleware:
- tasks_sent: producer-side counter per task name
- task_success / task_errors: consumer-side counters with retry_error
  attribute
- task_execution_time: histogram using result.execution_time
- task_wait_time: histogram measuring queue time from send to receive
  via UTC timestamps in labels

Add tests covering all instruments, retry_error attribute paths, and
queue time correctness.
---
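Notes: a minimal end-to-end sketch of how these instruments surface,
using the TaskiqInstrumentor API exercised by the tests below and an
in-memory reader (a real deployment would swap in an OTLP exporter;
the broker and task here are illustrative only):

    import asyncio

    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import InMemoryMetricReader

    from taskiq import InMemoryBroker
    from taskiq.instrumentation import TaskiqInstrumentor

    broker = InMemoryBroker()

    @broker.task
    async def add(a: int, b: int) -> int:
        return a + b

    async def main() -> None:
        reader = InMemoryMetricReader()
        provider = MeterProvider(metric_readers=[reader])
        # Installs the middleware with a meter from the given provider.
        TaskiqInstrumentor().instrument_broker(broker, meter_provider=provider)
        await add.kiq(1, 2)  # producer side records tasks_sent
        await broker.wait_all()  # worker side records success/time metrics
        # tasks_sent, task_success, task_execution_time and task_wait_time
        # should now carry data points labelled with the task name.
        print(reader.get_metrics_data())

    asyncio.run(main())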
 .../middlewares/opentelemetry_middleware.py   |  88 ++++++++-
 tests/opentelemetry/taskiq_test_tasks.py      |   6 +
 tests/opentelemetry/test_metrics.py           | 169 ++++++++++++++++++
 3 files changed, 262 insertions(+), 1 deletion(-)
 create mode 100644 tests/opentelemetry/test_metrics.py

diff --git a/taskiq/middlewares/opentelemetry_middleware.py b/taskiq/middlewares/opentelemetry_middleware.py
index 6ebbe10e..3780184e 100644
--- a/taskiq/middlewares/opentelemetry_middleware.py
+++ b/taskiq/middlewares/opentelemetry_middleware.py
@@ -1,5 +1,6 @@
 import logging
 from contextlib import AbstractContextManager
+from datetime import datetime, timezone
 from importlib.metadata import version
 from typing import Any, TypeVar
 
@@ -59,6 +60,9 @@
 _TASK_RETRY_REASON_KEY = "taskiq.retry.reason"
 _TASK_NAME_KEY = "taskiq.task_name"
 
+_TASK_QUEUE_TIME_KEY = "_taskiq_queue_time"
+_TASK_RECEIVED_TIME_KEY = "_taskiq_broker_receive_time"
+
 
 def set_attributes_from_context(span: Span, context: dict[str, Any]) -> None:
     """Helper to extract meta values from a Taskiq Context."""
@@ -170,6 +174,37 @@ def __init__(
             if meter is None
             else meter
         )
+        # Create metrics
+        # 1- Number of tasks sent. producer (Counter)
+        self.n_tasks_sent_counter = self._meter.create_counter(
+            name="tasks_sent",
+            unit="1",
+            description="Number of tasks sent from the producer side",
+        )
+        # 2- Number of errors by task name. consumer (Counter)
+        self.n_errors_counter = self._meter.create_counter(
+            name="task_errors",
+            unit="1",
+            description="Number of errors raised",
+        )
+        # 3- Number of task successes. consumer (Counter)
+        self.n_success_counter = self._meter.create_counter(
+            name="task_success",
+            unit="1",
+            description="Number of tasks completed successfully",
+        )
+        # 4- Task execution time. consumer (Histogram)
+        self.execution_time_hist = self._meter.create_histogram(
+            "task_execution_time",
+            unit="s",
+            description="Time to finish executing tasks",
+        )
+        # 5- Task wait time. both (Histogram)
+        self.task_wait_time = self._meter.create_histogram(
+            "task_wait_time",
+            unit="s",
+            description="Time the tasks waited before executing",
+        )
 
     def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
         """
@@ -193,7 +228,7 @@ def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
         activation.__enter__()
         attach_context(message, span, activation, None, is_publish=True)
         inject(message.labels)
-
+        message.labels[_TASK_QUEUE_TIME_KEY] = datetime.now(timezone.utc).timestamp()
         return message
 
     def post_send(self, message: TaskiqMessage) -> None:
         """
@@ -214,6 +249,7 @@ def post_send(self, message: TaskiqMessage) -> None:
 
         activation.__exit__(None, None, None)
         detach_context(message, is_publish=True)
+        self.n_tasks_sent_counter.add(1, attributes={"task_name": message.task_name})
 
     def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
         """
@@ -236,6 +272,7 @@ def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
         activation = trace.use_span(span, end_on_exit=True)
         activation.__enter__()  # pylint: disable=E1101
         attach_context(message, span, activation, token)
+        message.labels[_TASK_RECEIVED_TIME_KEY] = datetime.now(timezone.utc).timestamp()
         return message
 
     def post_save(  # pylint: disable=R6301
@@ -313,3 +350,52 @@ def on_error(
         }
         span.record_exception(exception)
         span.set_status(Status(**status_kwargs))  # type: ignore[arg-type]
+
+    def post_execute(
+        self,
+        message: "TaskiqMessage",
+        result: "TaskiqResult[Any]",
+    ) -> None:
+        """
+        This function tracks the number of failed and successful executions.
+
+        :param message: received message.
+        :param result: result of the execution.
+        """
+        if result.is_err:
+            retry_on_error = message.labels.get("retry_on_error")
+            if isinstance(retry_on_error, str):
+                retry_on_error = retry_on_error.lower() == "true"
+
+            if retry_on_error is None:
+                retry_on_error = False
+
+            if retry_on_error:
+                # Mark the error as one that will trigger a retry
+                self.n_errors_counter.add(
+                    1,
+                    attributes={"retry_error": True, "task_name": message.task_name},
+                )
+            else:
+                self.n_errors_counter.add(
+                    1,
+                    attributes={"retry_error": False, "task_name": message.task_name},
+                )
+        else:
+            self.n_success_counter.add(
+                1,
+                attributes={"task_name": message.task_name},
+            )
+        self.execution_time_hist.record(
+            result.execution_time,
+            attributes={
+                "task_name": message.task_name,
+            },
+        )
+        task_receive_time = message.labels.get(_TASK_RECEIVED_TIME_KEY)
+        task_send_time = message.labels.get(_TASK_QUEUE_TIME_KEY)
+        if task_receive_time is not None and task_send_time is not None:
+            self.task_wait_time.record(
+                amount=task_receive_time - task_send_time,
+                attributes={"task_name": message.task_name},
+            )
diff --git a/tests/opentelemetry/taskiq_test_tasks.py b/tests/opentelemetry/taskiq_test_tasks.py
index d910313b..af2198f1 100644
--- a/tests/opentelemetry/taskiq_test_tasks.py
+++ b/tests/opentelemetry/taskiq_test_tasks.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any
 
 from opentelemetry import baggage
@@ -26,3 +27,8 @@ async def task_raises() -> None:
 @broker.task
 async def task_returns_baggage() -> Any:
     return dict(baggage.get_all())
+
+
+@broker.task
+async def task_does_processing(wait_time: float) -> None:
+    await asyncio.sleep(wait_time)
diff --git a/tests/opentelemetry/test_metrics.py b/tests/opentelemetry/test_metrics.py
new file mode 100644
index 00000000..3ce82de1
--- /dev/null
+++ b/tests/opentelemetry/test_metrics.py
@@ -0,0 +1,169 @@
+import asyncio
+from typing import Any
+
+from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics.export import InMemoryMetricReader
+from opentelemetry.test.test_base import TestBase
+
+from taskiq.instrumentation import TaskiqInstrumentor
+
+from .taskiq_test_tasks import (
+    broker,
+    task_add,
+    task_does_processing,
+    task_raises,
+)
+
+
+class TestTaskiqOTelMetrics(TestBase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.reader = InMemoryMetricReader()
+        self.meter_provider = MeterProvider(metric_readers=[self.reader])
+        TaskiqInstrumentor().instrument_broker(
+            broker,
+            meter_provider=self.meter_provider,
+        )
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        TaskiqInstrumentor().uninstrument_broker(broker)
+
+    def _get_data_points(self, metric_name: str) -> list[Any]:
+        metrics = self.reader.get_metrics_data()
+        if metrics is None:
+            return []
+        return [
+            point
+            for rm in metrics.resource_metrics
+            for sm in rm.scope_metrics
+            for metric in sm.metrics
+            if metric.name == metric_name
+            for point in metric.data.data_points
+        ]
+
+    def test_metrics_exist(self) -> None:
+        async def test() -> None:
+            await task_add.kiq(1, 2)
+            await task_raises.kiq()
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        metrics = self.reader.get_metrics_data()
+        self.assertIsNotNone(metrics)
+        expected = {
+            "task_errors",
+            "tasks_sent",
+            "task_success",
+            "task_execution_time",
+            "task_wait_time",
+        }
+        found = {
+            metric.name
+            for rm in metrics.resource_metrics  # type: ignore[union-attr]
+            for sm in rm.scope_metrics
+            for metric in sm.metrics
+        }
+        self.assertSetEqual(found.intersection(expected), expected)
+
+    def test_success_counter(self) -> None:
+        async def test() -> None:
+            for _ in range(3):
+                await task_add.kiq(1, 2)
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("task_success")
+        self.assertEqual(len(points), 1)
+        self.assertEqual(points[0].value, 3)
+
+    def test_error_counter_no_retry(self) -> None:
+        async def test() -> None:
+            for _ in range(3):
+                await task_raises.kiq()
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("task_errors")
+        no_retry_points = [
+            p for p in points if p.attributes.get("retry_error") is False
+        ]
+        self.assertEqual(len(no_retry_points), 1)
+        self.assertEqual(no_retry_points[0].value, 3)
+
+    def test_error_counter_with_retry(self) -> None:
+        async def test() -> None:
+            for _ in range(3):
+                await task_raises.kicker().with_labels(retry_on_error="true").kiq()
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("task_errors")
+        retry_points = [p for p in points if p.attributes.get("retry_error") is True]
+        self.assertEqual(len(retry_points), 1)
+        self.assertEqual(retry_points[0].value, 3)
+
+    def test_execution_time_histogram(self) -> None:
+        async def test() -> None:
+            for _ in range(3):
+                await task_does_processing.kiq(0.01)
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("task_execution_time")
+        self.assertEqual(len(points), 1)
+        self.assertEqual(points[0].count, 3)
+        self.assertGreater(points[0].sum, 0)
+
+    def test_task_wait_time_histogram(self) -> None:
+        async def test() -> None:
+            await task_does_processing.kiq(0.01)
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("task_wait_time")
+        self.assertEqual(len(points), 1)
+        self.assertEqual(points[0].count, 1)
+        self.assertGreaterEqual(points[0].sum, 0)
+
+    def test_queue_time(self) -> None:
+        async def test() -> None:
+            for _ in range(3):
+                await task_add.kiq(1, 2)
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("task_wait_time")
+        # all 3 tasks share the same task_name so they aggregate into one data point
+        self.assertEqual(len(points), 1)
+        point = points[0]
+        # 3 tasks recorded
+        self.assertEqual(point.count, 3)
+        # queue time must be non-negative; a negative value means timestamps
+        # were not written/read correctly
+        self.assertGreaterEqual(point.sum, 0)
+        self.assertGreaterEqual(point.min, 0)
+        # task_name attribute must be present and correct
+        self.assertEqual(
+            point.attributes.get("task_name"),
+            "tests.opentelemetry.taskiq_test_tasks:task_add",
+        )
+
+    def test_tasks_sent_counter(self) -> None:
+        async def test() -> None:
+            for _ in range(3):
+                await task_add.kiq(1, 2)
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("tasks_sent")
+        self.assertEqual(len(points), 1)
+        self.assertEqual(points[0].value, 3)

From c6270e84c80edb6a76e5574b5e8f4d6d9e0f69e8 Mon Sep 17 00:00:00 2001
From: mohammedtarek
Date: Sat, 2 May 2026 23:45:27 +0300
Subject: [PATCH 2/3] feat: add CPU and memory utilization gauges

Add worker_cpu_utilization and worker_memory_utilization observable
gauges backed by psutil. The callbacks report data points only from
worker processes.
---
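Notes: the gauges are observable instruments, so the SDK runs the
callbacks at collection time, and a callback that yields nothing (as in
non-worker processes here) reports no data points. One caveat worth
knowing: psutil's cpu_percent() compares against the previous call, so
the first observation after startup reads 0.0. A self-contained sketch
of the same pattern, independent of this middleware:

    from typing import Generator

    import psutil
    from opentelemetry.metrics import CallbackOptions, Observation
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import InMemoryMetricReader

    process = psutil.Process()

    def observe_rss(options: CallbackOptions) -> Generator[Observation, None, None]:
        # Yield nothing here to skip reporting (e.g. outside workers).
        yield Observation(process.memory_info().rss)

    reader = InMemoryMetricReader()
    meter = MeterProvider(metric_readers=[reader]).get_meter("sketch")
    meter.create_observable_gauge(
        "worker_memory_utilization",
        callbacks=[observe_rss],
        unit="By",
    )
    # Collection triggers the callback.
    print(reader.get_metrics_data())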
 .../middlewares/opentelemetry_middleware.py | 29 +++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/taskiq/middlewares/opentelemetry_middleware.py b/taskiq/middlewares/opentelemetry_middleware.py
index 3780184e..f723dcf4 100644
--- a/taskiq/middlewares/opentelemetry_middleware.py
+++ b/taskiq/middlewares/opentelemetry_middleware.py
@@ -2,7 +2,8 @@ import logging
 from contextlib import AbstractContextManager
 from datetime import datetime, timezone
 from importlib.metadata import version
-from typing import Any, TypeVar
+from typing import Any, Generator, TypeVar
+import psutil
 
 from packaging.version import Version, parse
 
 try:
@@ -17,7 +18,7 @@
 
 from opentelemetry import context as context_api
 from opentelemetry import trace
-from opentelemetry.metrics import Meter, MeterProvider, get_meter
+from opentelemetry.metrics import Meter, MeterProvider, Observation, get_meter
 from opentelemetry.propagate import extract, inject
 from opentelemetry.semconv.trace import SpanAttributes
 from opentelemetry.trace import Span, Tracer, TracerProvider
@@ -205,6 +206,30 @@ def __init__(
             unit="s",
             description="Time the tasks waited before executing",
         )
+        # Worker resource metrics: CPU and memory utilization
+        self._process = psutil.Process()
+        # 6- CPU utilization
+        self.worker_cpu_utilization = self._meter.create_observable_gauge(
+            "worker_cpu_utilization",
+            callbacks=[self._observe_cpu],
+            unit="%",
+            description="Worker CPU utilization percentage. Only for worker processes",
+        )
+        # 7- Memory utilization
+        self.worker_memory_utilization = self._meter.create_observable_gauge(
+            "worker_memory_utilization",
+            callbacks=[self._observe_memory],
+            unit="By",
+            description="Worker memory utilization in bytes. Only for worker processes",
+        )
+
+    def _observe_memory(self, _options: Any) -> Generator[Observation, None, None]:
+        if self.broker and self.broker.is_worker_process:
+            yield Observation(self._process.memory_info().rss)
+
+    def _observe_cpu(self, _options: Any) -> Generator[Observation, None, None]:
+        if self.broker and self.broker.is_worker_process:
+            yield Observation(self._process.cpu_percent())
 
     def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
         """

From 8bf420d9fc10e3d45f139817782eec88e52674cc Mon Sep 17 00:00:00 2001
From: mohammedtarek
Date: Sun, 3 May 2026 00:55:56 +0300
Subject: [PATCH 3/3] feat: add worker queue and active task OTel metrics

- Add worker_active_tasks UpDownCounter driven by pre/post_execute hooks
- Add worker_prefetched_tasks UpDownCounter via on_prefetch_queue_add/remove
  hooks in receiver
- Add tests for all new metrics, including the worker_cpu_utilization and
  worker_memory_utilization gauges introduced in the previous commit
---
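Notes: the receiver discovers the two queue hooks by duck typing
(hasattr), so any middleware can opt in without a base-class change. A
hypothetical middleware using the same hooks to track queue depth:

    from taskiq import TaskiqMiddleware

    class QueueDepthTracker(TaskiqMiddleware):
        """Counts messages sitting in the worker prefetch queue."""

        def __init__(self) -> None:
            super().__init__()
            self.depth = 0

        def on_prefetch_queue_add(self) -> None:
            self.depth += 1

        def on_prefetch_queue_remove(self) -> None:
            self.depth -= 1

The worker_active_tasks UpDownCounter is expected to read zero between
tasks: pre_execute adds 1 and post_execute subtracts 1, which is what
test_active_tasks_counter asserts below.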
 .../middlewares/opentelemetry_middleware.py | 39 +++++++++++--
 taskiq/receiver/receiver.py                 | 13 +++++
 tests/opentelemetry/test_metrics.py         | 55 +++++++++++++++++++
 3 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/taskiq/middlewares/opentelemetry_middleware.py b/taskiq/middlewares/opentelemetry_middleware.py
index f723dcf4..f90c1299 100644
--- a/taskiq/middlewares/opentelemetry_middleware.py
+++ b/taskiq/middlewares/opentelemetry_middleware.py
@@ -1,10 +1,11 @@
 import logging
+from collections.abc import Generator
 from contextlib import AbstractContextManager
 from datetime import datetime, timezone
 from importlib.metadata import version
-from typing import Any, Generator, TypeVar
-import psutil
+from typing import Any, TypeVar
 
+import psutil
 from packaging.version import Version, parse
 
 try:
@@ -223,11 +224,24 @@ def __init__(
             description="Worker memory utilization in bytes. Only for worker processes",
         )
 
-    def _observe_memory(self, _options: Any) -> Generator[Observation, None, None]:
+        # 8- Number of tasks executing
+        self.number_of_broker_active_tasks = self._meter.create_up_down_counter(
+            "worker_active_tasks",
+            unit="1",
+            description="Number of tasks currently executing in the worker.",
+        )
+        # 9- Number of tasks prefetched
+        self.number_of_broker_prefetched_tasks = self._meter.create_up_down_counter(
+            "worker_prefetched_tasks",
+            unit="1",
+            description="Number of tasks currently prefetched in the worker.",
+        )
+
+    def _observe_memory(self, options: Any) -> Generator[Observation, None, None]:
         if self.broker and self.broker.is_worker_process:
             yield Observation(self._process.memory_info().rss)
 
-    def _observe_cpu(self, _options: Any) -> Generator[Observation, None, None]:
+    def _observe_cpu(self, options: Any) -> Generator[Observation, None, None]:
         if self.broker and self.broker.is_worker_process:
             yield Observation(self._process.cpu_percent())
@@ -298,6 +312,10 @@ def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
         activation.__enter__()  # pylint: disable=E1101
         attach_context(message, span, activation, token)
         message.labels[_TASK_RECEIVED_TIME_KEY] = datetime.now(timezone.utc).timestamp()
+        self.number_of_broker_active_tasks.add(
+            1,
+            attributes={"task_name": message.task_name},
+        )
         return message
 
     def post_save(  # pylint: disable=R6301
@@ -424,3 +442,16 @@ def post_execute(
                 amount=task_receive_time - task_send_time,
                 attributes={"task_name": message.task_name},
             )
+
+        self.number_of_broker_active_tasks.add(
+            -1,
+            attributes={"task_name": message.task_name},
+        )
+
+    def on_prefetch_queue_add(self) -> None:
+        """This hook is called after a task is added to the worker prefetch queue."""
+        self.number_of_broker_prefetched_tasks.add(1)
+
+    def on_prefetch_queue_remove(self) -> None:
+        """This hook is called after a task is removed from the worker prefetch queue."""
+        self.number_of_broker_prefetched_tasks.add(-1)
diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py
index 99298af2..6649886a 100644
--- a/taskiq/receiver/receiver.py
+++ b/taskiq/receiver/receiver.py
@@ -383,6 +383,12 @@ async def prefetcher(
                 current_message = asyncio.create_task(iterator.__anext__())  # type: ignore
                 fetched_tasks += 1
                 await queue.put(message)
+                # Custom hooks for OTel and any future instrumentations
+                for middleware in reversed(self.broker.middlewares):
+                    if hasattr(middleware, "on_prefetch_queue_add"):
+                        await maybe_awaitable(
+                            middleware.on_prefetch_queue_add(),  # type: ignore
+                        )
             except (asyncio.CancelledError, StopAsyncIteration):
                 break
         # We don't want to fetch new messages if we are shutting down.
@@ -434,6 +440,13 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
                     logger.info("No more tasks to wait for. Shutting down.")
                     break
 
+            # Custom hooks for OTel and any future instrumentations
+            for middleware in reversed(self.broker.middlewares):
+                if hasattr(middleware, "on_prefetch_queue_remove"):
+                    await maybe_awaitable(
+                        middleware.on_prefetch_queue_remove(),  # type: ignore
+                    )
+
             task = asyncio.create_task(
                 self.callback(message=message, raise_err=False),
             )
diff --git a/tests/opentelemetry/test_metrics.py b/tests/opentelemetry/test_metrics.py
index 3ce82de1..d12291fc 100644
--- a/tests/opentelemetry/test_metrics.py
+++ b/tests/opentelemetry/test_metrics.py
@@ -6,6 +6,7 @@
 from opentelemetry.test.test_base import TestBase
 
 from taskiq.instrumentation import TaskiqInstrumentor
+from taskiq.middlewares.opentelemetry_middleware import OpenTelemetryMiddleware
 
 from .taskiq_test_tasks import (
     broker,
@@ -58,6 +59,7 @@ async def test() -> None:
             "task_success",
             "task_execution_time",
             "task_wait_time",
+            "worker_active_tasks",
         }
         found = {
             metric.name
@@ -167,3 +169,56 @@ async def test() -> None:
         points = self._get_data_points("tasks_sent")
         self.assertEqual(len(points), 1)
         self.assertEqual(points[0].value, 3)
+
+    def test_active_tasks_counter(self) -> None:
+        async def test() -> None:
+            for _ in range(3):
+                await task_add.kiq(1, 2)
+            await broker.wait_all()
+
+        asyncio.run(test())
+
+        points = self._get_data_points("worker_active_tasks")
+        # all 3 tasks share the same task_name so they aggregate into one data point
+        self.assertEqual(len(points), 1)
+        # net zero: pre_execute incremented, post_execute decremented for each task
+        self.assertEqual(points[0].value, 0)
+        self.assertIn("task_name", points[0].attributes)
+        self.assertEqual(
+            points[0].attributes.get("task_name"),
+            "tests.opentelemetry.taskiq_test_tasks:task_add",
+        )
+
+    def test_prefetch_queue_counter(self) -> None:
+        middleware = next(
+            m for m in broker.middlewares if isinstance(m, OpenTelemetryMiddleware)
+        )
+        middleware.on_prefetch_queue_add()
+        middleware.on_prefetch_queue_add()
+        middleware.on_prefetch_queue_add()
+        middleware.on_prefetch_queue_remove()
+
+        points = self._get_data_points("worker_prefetched_tasks")
+        self.assertEqual(len(points), 1)
+        self.assertEqual(points[0].value, 2)
+
+    def test_worker_resource_metrics_when_worker_process(self) -> None:
+        middleware = next(
+            m for m in broker.middlewares if isinstance(m, OpenTelemetryMiddleware)
+        )
+        middleware.set_broker(broker)
+        broker.is_worker_process = True
+        try:
+            metrics_data = self.reader.get_metrics_data()
+            self.assertIsNotNone(metrics_data)
+            found = {
+                metric.name
+                for rm in metrics_data.resource_metrics  # type: ignore[union-attr]
+                for sm in rm.scope_metrics
+                for metric in sm.metrics
+            }
+            self.assertIn("worker_cpu_utilization", found)
+            self.assertIn("worker_memory_utilization", found)
+        finally:
+            broker.is_worker_process = False
+            middleware.set_broker(None)  # type: ignore[arg-type]