Log response body in api.raw_request for all status codes (#SKY-7987) (#4779)
Co-authored-by: Suchintan Singh <suchintan@skyvern.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
|
||||
@@ -24,7 +25,29 @@ _SENSITIVE_ENDPOINTS = {
|
||||
"POST /v1/credentials/azure_credential/create",
|
||||
}
|
||||
# --- Logging limits and placeholders ---------------------------------------
_MAX_BODY_LENGTH = 1000  # logged bodies longer than this are truncated
_MAX_RESPONSE_READ_BYTES = 1024 * 1024  # 1 MB — skip logging bodies larger than this
_BINARY_PLACEHOLDER = "<binary>"  # logged in place of non-text / unreadable payloads
_REDACTED = "****"  # replaces sensitive values and whole sensitive-endpoint bodies
_LOGGABLE_CONTENT_TYPES = {"text/", "application/json"}  # content-type prefixes safe to log
_STREAMING_CONTENT_TYPE = "text/event-stream"  # SSE responses are never buffered for logging

# Exact field names that are always redacted. Use a set for O(1) lookup
# instead of regex substring matching to avoid false positives like
# credential_id, author, page_token, etc.
_SENSITIVE_FIELDS: set[str] = {
    "password",
    "secret",
    "token",
    "api_key",
    "apikey",
    "api-key",
    "credential",
    "access_key",
    "private_key",
    "auth",
    "authorization",
    "secret_key",
}
|
||||
|
||||
|
||||
def _sanitize_headers(headers: typing.Mapping[str, str]) -> dict[str, str]:
|
||||
@@ -38,7 +61,7 @@ def _sanitize_headers(headers: typing.Mapping[str, str]) -> dict[str, str]:
|
||||
|
||||
def _sanitize_body(request: Request, body: bytes, content_type: str | None) -> str:
|
||||
if f"{request.method.upper()} {request.url.path.rstrip('/')}" in _SENSITIVE_ENDPOINTS:
|
||||
return "****"
|
||||
return _REDACTED
|
||||
if not body:
|
||||
return ""
|
||||
if content_type and not (content_type.startswith("text/") or content_type.startswith("application/json")):
|
||||
@@ -52,18 +75,76 @@ def _sanitize_body(request: Request, body: bytes, content_type: str | None) -> s
|
||||
return text
|
||||
|
||||
|
||||
async def _get_response_body_str(response: Response) -> str:
|
||||
def _is_sensitive_key(key: str) -> bool:
    """Return True when *key* exactly matches a known sensitive field name.

    Matching is case-insensitive and exact (no substring matching), so
    e.g. ``Password`` is sensitive but ``password_updated_at`` is not.
    """
    normalized = key.lower()
    return normalized in _SENSITIVE_FIELDS
|
||||
|
||||
|
||||
def _redact_sensitive_fields(obj: typing.Any, _depth: int = 0) -> typing.Any:
    """Redact dict values whose *key name* exactly matches a known sensitive field.

    Uses exact-match (case-insensitive) rather than substring/regex to avoid
    false positives on fields like ``credential_id``, ``author``, or
    ``page_token`` which contain sensitive substrings but are not secrets.
    """
    if _depth > 20:
        # Recursion cap reached: still redact sensitive keys at this level,
        # but do not descend any further into nested values.
        if not isinstance(obj, dict):
            return obj
        return {key: _REDACTED if _is_sensitive_key(key) else value for key, value in obj.items()}
    if isinstance(obj, dict):
        redacted: dict = {}
        for key, value in obj.items():
            if _is_sensitive_key(key):
                redacted[key] = _REDACTED
            else:
                redacted[key] = _redact_sensitive_fields(value, _depth + 1)
        return redacted
    if isinstance(obj, list):
        return [_redact_sensitive_fields(element, _depth + 1) for element in obj]
    # Scalars (str, int, None, ...) pass through unchanged.
    return obj
|
||||
|
||||
|
||||
def _is_loggable_content_type(content_type: str | None) -> bool:
    """Return True when *content_type* is textual enough to be logged."""
    if not content_type:
        # Missing header: assume text so useful bodies are not dropped.
        return True
    return content_type.startswith(tuple(_LOGGABLE_CONTENT_TYPES))
|
||||
|
||||
|
||||
def _sanitize_response_body(request: Request, body_str: str | None, content_type: str | None) -> str:
    """Return a log-safe rendering of a response body.

    Bodies of sensitive endpoints are fully redacted; binary or unread
    (``None``) bodies become a placeholder; JSON bodies have sensitive
    fields redacted; anything longer than ``_MAX_BODY_LENGTH`` is truncated.
    """
    endpoint = f"{request.method.upper()} {request.url.path.rstrip('/')}"
    if endpoint in _SENSITIVE_ENDPOINTS:
        return _REDACTED
    if body_str is None:
        # None signals the body was binary/oversized and was not read.
        return _BINARY_PLACEHOLDER
    if body_str == "":
        return ""
    if not _is_loggable_content_type(content_type):
        return _BINARY_PLACEHOLDER
    try:
        text = json.dumps(_redact_sensitive_fields(json.loads(body_str)))
    except (json.JSONDecodeError, TypeError):
        # Not JSON — log the raw text as-is (still subject to truncation).
        text = body_str
    if len(text) > _MAX_BODY_LENGTH:
        text = text[:_MAX_BODY_LENGTH] + "...[truncated]"
    return text
|
||||
|
||||
|
||||
async def _get_response_body_str(response: Response) -> str | None:
    """Read and reconstitute the response body for logging.

    The response's body iterator is fully drained and then replaced with a
    single-chunk iterator so the response can still be sent downstream.
    Returns ``None`` when the body exceeds ``_MAX_RESPONSE_READ_BYTES`` so
    huge payloads are not logged; bytes that are not valid UTF-8 are
    rendered via their ``repr``.
    """
    # Drain the iterator into a list and join once — avoids the quadratic
    # cost of repeated ``bytes +=`` accumulation on large bodies.
    chunks = [chunk async for chunk in response.body_iterator]
    response_body = b"".join(chunks)
    # Restore the body so the client still receives the full response.
    response.body_iterator = iterate_in_threadpool(iter([response_body]))

    if len(response_body) > _MAX_RESPONSE_READ_BYTES:
        return None

    try:
        return response_body.decode("utf-8")
    except UnicodeDecodeError:
        # Binary payload: log the bytes repr instead of dropping it entirely.
        return str(response_body)
|
||||
|
||||
|
||||
async def log_raw_request_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
|
||||
@@ -88,13 +169,17 @@ async def log_raw_request_middleware(request: Request, call_next: Callable[[Requ
|
||||
|
||||
if response.status_code >= 500:
|
||||
log_method = LOG.error
|
||||
error_body = await _get_response_body_str(response)
|
||||
elif response.status_code >= 400:
|
||||
log_method = LOG.warning
|
||||
error_body = await _get_response_body_str(response)
|
||||
else:
|
||||
log_method = LOG.info
|
||||
error_body = None
|
||||
|
||||
resp_content_type = response.headers.get("content-type", "")
|
||||
if _STREAMING_CONTENT_TYPE in resp_content_type:
|
||||
response_body = "<streaming>"
|
||||
else:
|
||||
raw_response_body = await _get_response_body_str(response)
|
||||
response_body = _sanitize_response_body(request, raw_response_body, resp_content_type)
|
||||
|
||||
log_method(
|
||||
"api.raw_request",
|
||||
@@ -103,7 +188,9 @@ async def log_raw_request_middleware(request: Request, call_next: Callable[[Requ
|
||||
status_code=response.status_code,
|
||||
body=body_text,
|
||||
headers=sanitized_headers,
|
||||
error_body=error_body,
|
||||
response_body=response_body,
|
||||
# backwards-compat: keep error_body for existing Datadog queries
|
||||
error_body=response_body if response.status_code >= 400 else None,
|
||||
duration_seconds=time.monotonic() - start_time,
|
||||
)
|
||||
return response
|
||||
|
||||
259
tests/unit/test_request_logging.py
Normal file
259
tests/unit/test_request_logging.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.forge.request_logging import (
|
||||
_BINARY_PLACEHOLDER,
|
||||
_MAX_BODY_LENGTH,
|
||||
_REDACTED,
|
||||
_is_loggable_content_type,
|
||||
_is_sensitive_key,
|
||||
_redact_sensitive_fields,
|
||||
_sanitize_response_body,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_sensitive_key — documents exactly which field names are redacted
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsSensitiveKey:
    """These tests serve as living documentation of the redaction rules.

    If you need to add or remove a field, update ``_SENSITIVE_FIELDS`` in
    ``request_logging.py`` and add a corresponding test case here.
    """

    # Case-insensitivity is exercised via the Password/PASSWORD variants.
    @pytest.mark.parametrize(
        "key",
        [
            "password",
            "Password",
            "PASSWORD",
            "secret",
            "token",
            "api_key",
            "apikey",
            "api-key",
            "credential",
            "access_key",
            "private_key",
            "auth",
            "authorization",
            "secret_key",
        ],
    )
    def test_sensitive_keys_are_redacted(self, key: str) -> None:
        assert _is_sensitive_key(key) is True, f"Expected '{key}' to be sensitive"

    # Exact-match semantics: keys that merely CONTAIN a sensitive word
    # (credential_id, page_token, author, ...) must stay visible.
    @pytest.mark.parametrize(
        "key",
        [
            # Suffixed IDs / metadata — should NOT be redacted
            "credential_id",
            "credential_type",
            "token_type",
            "token_count",
            "access_key_id",
            # Pagination cursors
            "next_token",
            "page_token",
            "cursor_token",
            # Author / authentication metadata
            "author",
            "authenticated",
            "authenticated_at",
            "authorization_url",
            "auth_method",
            # Other safe fields
            "secret_name",
            "password_updated_at",
            "api_key_id",
        ],
    )
    def test_non_sensitive_keys_are_preserved(self, key: str) -> None:
        assert _is_sensitive_key(key) is False, f"Expected '{key}' to NOT be sensitive"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _redact_sensitive_fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRedactSensitiveFields:
    """Behavioural tests for recursive redaction of sensitive JSON fields."""

    def test_redacts_password(self) -> None:
        data = {"username": "alice", "password": "secret123"}
        result = _redact_sensitive_fields(data)
        assert result["username"] == "alice"
        assert result["password"] == _REDACTED

    def test_redacts_nested_keys(self) -> None:
        data = {"user": {"api_key": "key123", "name": "bob"}}
        result = _redact_sensitive_fields(data)
        assert result["user"]["api_key"] == _REDACTED
        assert result["user"]["name"] == "bob"

    def test_redacts_in_lists(self) -> None:
        data = [{"token": "abc"}, {"name": "ok"}]
        result = _redact_sensitive_fields(data)
        assert result[0]["token"] == _REDACTED
        assert result[1]["name"] == "ok"

    def test_redacts_various_sensitive_keys(self) -> None:
        data = {
            "access_key": "a",
            "private_key": "b",
            "credential": "c",
            "secret": "d",
            "apikey": "e",
            "api-key": "f",
            "api_key": "g",
            "Authorization": "h",
        }
        result = _redact_sensitive_fields(data)
        for key in data:
            assert result[key] == _REDACTED, f"Expected {key} to be redacted"

    def test_preserves_non_sensitive_suffixed_keys(self) -> None:
        """Fields like credential_id and page_token must NOT be redacted."""
        data = {
            "credential_id": "cred_123",
            "credential_type": "oauth",
            "page_token": "abc",
            "author": "alice",
            "token_count": 42,
        }
        result = _redact_sensitive_fields(data)
        assert result == data

    def test_depth_limit_prevents_crash(self) -> None:
        # Build a 30-level-deep dict, well beyond the 20-level recursion cap.
        deep: dict = {}
        current = deep
        for _ in range(30):
            current["nested"] = {}
            current = current["nested"]
        current["password"] = "should_not_crash"

        result = _redact_sensitive_fields(deep)
        assert result is not None  # should not raise RecursionError

    def test_depth_limit_still_redacts_keys_at_boundary(self) -> None:
        """Sensitive keys at the depth boundary must still be redacted."""
        # depth 0: top dict, depth 1: "level" value, depths 2-20: 19 "next" dicts, depth 21: leaf
        deep: dict = {"level": {}}
        current = deep["level"]
        for _ in range(19):
            current["next"] = {}
            current = current["next"]
        current["password"] = "leak_me"
        current["safe"] = "visible"

        result = _redact_sensitive_fields(deep)
        node = result["level"]
        for _ in range(19):
            node = node["next"]
        assert node["password"] == _REDACTED
        assert node["safe"] == "visible"

    def test_preserves_non_sensitive_values(self) -> None:
        data = {"status": "ok", "count": 42, "items": [1, 2, 3]}
        result = _redact_sensitive_fields(data)
        assert result == data

    def test_handles_non_dict_non_list(self) -> None:
        # Scalars pass through unchanged regardless of type.
        assert _redact_sensitive_fields("hello") == "hello"
        assert _redact_sensitive_fields(42) == 42
        assert _redact_sensitive_fields(None) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_loggable_content_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsLoggableContentType:
    """Only textual content types (or a missing header) should be logged."""

    def test_json_is_loggable(self) -> None:
        assert _is_loggable_content_type("application/json") is True
        assert _is_loggable_content_type("application/json; charset=utf-8") is True

    def test_text_is_loggable(self) -> None:
        assert _is_loggable_content_type("text/plain") is True
        assert _is_loggable_content_type("text/html") is True

    def test_binary_is_not_loggable(self) -> None:
        assert _is_loggable_content_type("application/octet-stream") is False
        assert _is_loggable_content_type("image/png") is False

    def test_none_defaults_to_loggable(self) -> None:
        # Missing content-type header is treated as text.
        assert _is_loggable_content_type(None) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_response_body
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(method: str = "GET", path: str = "/api/v1/test") -> MagicMock:
|
||||
request = MagicMock()
|
||||
request.method = method
|
||||
request.url.path = path
|
||||
return request
|
||||
|
||||
|
||||
class TestSanitizeResponseBody:
    """Tests for converting response bodies into log-safe strings."""

    def test_sensitive_endpoint_fully_redacted(self) -> None:
        request = _make_request("POST", "/api/v1/credentials")
        result = _sanitize_response_body(request, '{"token": "abc"}', "application/json")
        assert result == _REDACTED

    def test_empty_body(self) -> None:
        request = _make_request()
        assert _sanitize_response_body(request, "", "application/json") == ""

    def test_none_body_returns_binary_placeholder(self) -> None:
        # None signals the body was not read (binary or oversized).
        request = _make_request()
        assert _sanitize_response_body(request, None, "application/json") == _BINARY_PLACEHOLDER

    def test_binary_content_type_returns_placeholder(self) -> None:
        request = _make_request()
        result = _sanitize_response_body(request, "some bytes", "application/octet-stream")
        assert result == _BINARY_PLACEHOLDER

    def test_json_fields_are_redacted(self) -> None:
        request = _make_request()
        body = json.dumps({"user": "alice", "password": "hunter2", "api_key": "sk-123"})
        result = _sanitize_response_body(request, body, "application/json")
        parsed = json.loads(result)
        assert parsed["user"] == "alice"
        assert parsed["password"] == _REDACTED
        assert parsed["api_key"] == _REDACTED

    def test_json_preserves_non_sensitive_suffixed_keys(self) -> None:
        """credential_id and page_token in responses must remain visible for debugging."""
        request = _make_request()
        body = json.dumps({"credential_id": "cred_123", "page_token": "abc", "author": "bob"})
        result = _sanitize_response_body(request, body, "application/json")
        parsed = json.loads(result)
        assert parsed["credential_id"] == "cred_123"
        assert parsed["page_token"] == "abc"
        assert parsed["author"] == "bob"

    def test_non_json_body_returned_as_is(self) -> None:
        request = _make_request()
        result = _sanitize_response_body(request, "plain text response", "text/plain")
        assert result == "plain text response"

    def test_truncates_long_body(self) -> None:
        request = _make_request()
        long_body = "x" * (_MAX_BODY_LENGTH + 500)
        result = _sanitize_response_body(request, long_body, "text/plain")
        assert result.endswith("...[truncated]")
        assert len(result) == _MAX_BODY_LENGTH + len("...[truncated]")

    def test_sensitive_endpoint_trailing_slash(self) -> None:
        # Trailing slashes are stripped before the sensitive-endpoint check.
        request = _make_request("POST", "/api/v1/credentials/")
        result = _sanitize_response_body(request, '{"data": "value"}', "application/json")
        assert result == _REDACTED
|
||||
Reference in New Issue
Block a user