Source code for samgis_web.web.middlewares
import time
from typing import Callable
import structlog
from asgi_correlation_id.context import correlation_id
from fastapi import Request, Response
from uvicorn.protocols.utils import get_path_with_query_string
from samgis_core import app_logger
[docs]
async def logging_middleware(request: Request, call_next: Callable) -> Response:
"""
Logging middleware to inject a correlation id in a fastapi application. Requires:
- structlog.stdlib logger
- setup_logging (samgis_core.utilities.session_logger package)
- CorrelationIdMiddleware (asgi_correlation_id package)
See tests/web/test_middlewares.py for an example based on a real fastapi application.
Args:
request: fastapi Request
call_next: next callable function
Returns:
fastapi Response
"""
structlog.contextvars.clear_contextvars()
# These context vars will be added to all log entries emitted during the request
request_id = correlation_id.get()
app_logger.debug(f"request_id:{request_id}.")
structlog.contextvars.bind_contextvars(request_id=request_id)
start_time = time.perf_counter_ns()
# If the call_next raises an error, we still want to return our own 500 response,
# so we can add headers to it (process time, request ID...)
response = Response(status_code=500)
try:
response = await call_next(request)
except Exception:
# TODO: Validate that we don't swallow exceptions (unit test?)
structlog.stdlib.get_logger("api.error").exception("Uncaught exception")
raise
finally:
process_time = time.perf_counter_ns() - start_time
status_code = response.status_code
url = get_path_with_query_string(request.scope)
client_host = request.client.host
client_port = request.client.port
http_method = request.method
http_version = request.scope["http_version"]
# Recreate the Uvicorn access log format, but add all parameters as structured information
app_logger.info(
f"""{client_host}:{client_port} - "{http_method} {url} HTTP/{http_version}" {status_code}""",
http={
"url": str(request.url),
"status_code": status_code,
"method": http_method,
"request_id": request_id,
"version": http_version,
},
network={"client": {"ip": client_host, "port": client_port}},
duration=process_time,
)
response.headers["X-Process-Time"] = str(process_time / 10 ** 9)
return response