Skip to content
83 changes: 83 additions & 0 deletions tests/integrations/mcp/streaming_asgi_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import asyncio
from httpx import ASGITransport, Request, Response, AsyncByteStream
import anyio

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Callable, MutableMapping


class StreamingASGITransport(ASGITransport):
"""
Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
tests involving SSE interactions to run in-process.
"""

def __init__(
self,
app: "Callable",
keep_sse_alive: "asyncio.Event",
) -> None:
self.keep_sse_alive = keep_sse_alive
super().__init__(app)

async def handle_async_request(self, request: "Request") -> "Response":
scope = {
"type": "http",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"path": request.url.path,
"query_string": request.url.query,
}

is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
if not is_streaming_sse:
return await super().handle_async_request(request)

request_body = b""
if request.content:
request_body = await request.aread()

body_sender, body_receiver = anyio.create_memory_object_stream[bytes](0)

async def receive() -> "dict[str, Any]":
if self.keep_sse_alive.is_set():
return {"type": "http.disconnect"}

await self.keep_sse_alive.wait() # Keep alive :)
return {"type": "http.request", "body": request_body, "more_body": False}

async def send(message: "MutableMapping[str, Any]") -> None:
if message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body == b"" and not more_body:
return

if body:
await body_sender.send(body)

if not more_body:
await body_sender.aclose()

async def run_app():
await self.app(scope, receive, send)

class StreamingBodyStream(AsyncByteStream):
def __init__(self, receiver):
self.receiver = receiver

async def __aiter__(self):
try:
async for chunk in self.receiver:
yield chunk
except anyio.EndOfStream:
pass

stream = StreamingBodyStream(body_receiver)
response = Response(status_code=200, headers=[], stream=stream)

asyncio.create_task(run_app())
return response
186 changes: 148 additions & 38 deletions tests/integrations/mcp/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
that the integration properly instruments MCP handlers with Sentry spans.
"""

from urllib.parse import urlparse, parse_qs
import anyio
import asyncio
import httpx
from .streaming_asgi_transport import StreamingASGITransport

import pytest
import json
from unittest import mock
Expand All @@ -32,9 +38,10 @@ async def __call__(self, *args, **kwargs):
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel import Server
from mcp.server.lowlevel.server import request_ctx
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

from starlette.routing import Mount
from starlette.routing import Mount, Route, Response
from starlette.applications import Starlette

try:
Expand Down Expand Up @@ -66,39 +73,103 @@ def reset_request_ctx():
pass


class MockRequestContext:
"""Mock MCP request context"""

def __init__(self, request_id=None, session_id=None, transport="stdio"):
self.request_id = request_id
if transport in ("http", "sse"):
self.request = MockHTTPRequest(session_id, transport)
else:
self.request = None
class MockTextContent:
"""Mock TextContent object"""

def __init__(self, text):
self.text = text

class MockHTTPRequest:
"""Mock HTTP request for SSE/StreamableHTTP transport"""

def __init__(self, session_id=None, transport="http"):
self.headers = {}
self.query_params = {}
async def json_rpc_sse(
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
):
context = {}

stream_complete = asyncio.Event()
endpoint_parsed = asyncio.Event()

# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
async with httpx.AsyncClient(
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
base_url="http://test",
) as client:

async def parse_stream():
async with client.stream("GET", "/sse") as stream:
# Read directly from stream.stream instead of aiter_bytes()
async for chunk in stream.stream:
if b"event: endpoint" in chunk:
sse_text = chunk.decode("utf-8")
url = sse_text.split("data: ")[1]

parsed = urlparse(url)
query_params = parse_qs(parsed.query)
context["session_id"] = query_params["session_id"][0]
endpoint_parsed.set()
continue

if b"event: message" in chunk and b"structuredContent" in chunk:
sse_text = chunk.decode("utf-8")

json_str = sse_text.split("data: ")[1]
context["response"] = json.loads(json_str)
break

stream_complete.set()

task = asyncio.create_task(parse_stream())
await endpoint_parsed.wait()

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
},
json={
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": "2025-11-25",
"capabilities": {},
},
"id": request_id,
},
)

if transport == "sse":
# SSE transport uses query parameter
if session_id:
self.query_params["session_id"] = session_id
else:
# StreamableHTTP transport uses header
if session_id:
self.headers["mcp-session-id"] = session_id
# Notification response is mandatory.
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
},
)

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": request_id,
},
)

class MockTextContent:
"""Mock TextContent object"""
await stream_complete.wait()
keep_sse_alive.set()

def __init__(self, text):
self.text = text
return task, context["session_id"], context["response"]


def test_integration_patches_server(sentry_init):
Expand Down Expand Up @@ -985,7 +1056,8 @@ def test_tool_complex(tool_name, arguments):
assert span["data"]["mcp.request.argument.number"] == "42"


def test_sse_transport_detection(sentry_init, capture_events):
@pytest.mark.asyncio
async def test_sse_transport_detection(sentry_init, capture_events):
"""Test that SSE transport is correctly detected via query parameter"""
sentry_init(
integrations=[MCPIntegration()],
Expand All @@ -994,29 +1066,67 @@ def test_sse_transport_detection(sentry_init, capture_events):
events = capture_events()

server = Server("test-server")
sse = SseServerTransport("/messages/")

# Set up mock request context with SSE transport
mock_ctx = MockRequestContext(
request_id="req-sse", session_id="session-sse-123", transport="sse"
sse_connection_closed = asyncio.Event()

async def handle_sse(request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
async with anyio.create_task_group() as tg:

async def run_server():
await server.run(
streams[0], streams[1], server.create_initialization_options()
)

tg.start_soon(run_server)

sse_connection_closed.set()
return Response()

app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse, methods=["GET"]),
Mount("/messages/", app=sse.handle_post_message),
],
)
request_ctx.set(mock_ctx)

@server.call_tool()
def test_tool(tool_name, arguments):
async def test_tool(tool_name, arguments):
return {"result": "success"}

with start_transaction(name="mcp tx"):
result = test_tool("sse_tool", {})
keep_sse_alive = asyncio.Event()
app_task, session_id, result = await json_rpc_sse(
app,
method="tools/call",
params={
"name": "sse_tool",
"arguments": {},
},
request_id="req-sse",
keep_sse_alive=keep_sse_alive,
)

assert result == {"result": "success"}
await sse_connection_closed.wait()
await app_task

(tx,) = events
assert result["result"]["structuredContent"] == {"result": "success"}

transactions = [
event
for event in events
if event["type"] == "transaction" and event["transaction"] == "/sse"
]
assert len(transactions) == 1
tx = transactions[0]
span = tx["spans"][0]

# Check that SSE transport is detected
assert span["data"][SPANDATA.MCP_TRANSPORT] == "sse"
assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp"
assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-sse-123"
assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id


def test_streamable_http_transport_detection(
Expand Down
Loading