Skip to content
209 changes: 207 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
import asyncio
from urllib.parse import urlparse, parse_qs
import socket
import warnings
import brotli
Expand Down Expand Up @@ -51,25 +53,40 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional
from typing import Any, Callable, MutableMapping, Optional
from collections.abc import Iterator

try:
from anyio import create_memory_object_stream, create_task_group
from anyio import create_memory_object_stream, create_task_group, EndOfStream
from mcp.types import (
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
)
from mcp.shared.message import SessionMessage
from httpx import (
ASGITransport,
Request as HttpxRequest,
Response as HttpxResponse,
AsyncByteStream,
AsyncClient,
)
except ImportError:
create_memory_object_stream = None
create_task_group = None
EndOfStream = None

JSONRPCMessage = None
JSONRPCNotification = None
JSONRPCRequest = None
SessionMessage = None

ASGITransport = None
HttpxRequest = None
HttpxResponse = None
AsyncByteStream = None
AsyncClient = None


SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json"

Expand Down Expand Up @@ -787,6 +804,194 @@ def inner(events):
return inner


@pytest.fixture()
def json_rpc_sse(is_structured_content: bool = True):
class StreamingASGITransport(ASGITransport):
"""
Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
tests involving SSE interactions to run in-process.
"""

def __init__(
self,
app: "Callable",
keep_sse_alive: "asyncio.Event",
) -> None:
self.keep_sse_alive = keep_sse_alive
super().__init__(app)

async def handle_async_request(
self, request: "HttpxRequest"
) -> "HttpxResponse":
scope = {
"type": "http",
"method": request.method,
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
"path": request.url.path,
"query_string": request.url.query,
}

is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
if not is_streaming_sse:
return await super().handle_async_request(request)

request_body = b""
if request.content:
request_body = await request.aread()

body_sender, body_receiver = create_memory_object_stream[bytes](0) # type: ignore

async def receive() -> "dict[str, Any]":
if self.keep_sse_alive.is_set():
return {"type": "http.disconnect"}

await self.keep_sse_alive.wait() # Keep alive :)
return {
"type": "http.request",
"body": request_body,
"more_body": False,
}

async def send(message: "MutableMapping[str, Any]") -> None:
if message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body == b"" and not more_body:
return

if body:
await body_sender.send(body)

if not more_body:
await body_sender.aclose()

async def run_app():
await self.app(scope, receive, send)

class StreamingBodyStream(AsyncByteStream): # type: ignore
def __init__(self, receiver):
self.receiver = receiver

async def __aiter__(self):
try:
async for chunk in self.receiver:
yield chunk
except EndOfStream: # type: ignore
pass

stream = StreamingBodyStream(body_receiver)
response = HttpxResponse(status_code=200, headers=[], stream=stream) # type: ignore

asyncio.create_task(run_app())
return response

def parse_sse_data_package(sse_chunk):
sse_text = sse_chunk.decode("utf-8")
json_str = sse_text.split("data: ")[1]
return json.loads(json_str)

async def inner(
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
):
context = {}

stream_complete = asyncio.Event()
endpoint_parsed = asyncio.Event()

# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
async with AsyncClient( # type: ignore
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
base_url="http://test",
) as client:

async def parse_stream():
async with client.stream("GET", "/sse") as stream:
# Read directly from stream.stream instead of aiter_bytes()
async for chunk in stream.stream:
if b"event: endpoint" in chunk:
sse_text = chunk.decode("utf-8")
url = sse_text.split("data: ")[1]

parsed = urlparse(url)
query_params = parse_qs(parsed.query)
context["session_id"] = query_params["session_id"][0]
endpoint_parsed.set()
continue

if (
is_structured_content
and b"event: message" in chunk
and b"structuredContent" in chunk
):
context["response"] = parse_sse_data_package(chunk)
break
elif (
"result" in parse_sse_data_package(chunk)
and "content" in parse_sse_data_package(chunk)["result"]
):
context["response"] = parse_sse_data_package(chunk)
break

stream_complete.set()

task = asyncio.create_task(parse_stream())
await endpoint_parsed.wait()

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
},
json={
"jsonrpc": "2.0",
"method": "initialize",
"params": {
"clientInfo": {"name": "test-client", "version": "1.0"},
"protocolVersion": "2025-11-25",
"capabilities": {},
},
"id": request_id,
},
)

# Notification response is mandatory.
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
},
)

await client.post(
f"/messages/?session_id={context['session_id']}",
headers={
"Content-Type": "application/json",
"mcp-session-id": context["session_id"],
},
json={
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": request_id,
},
)

await stream_complete.wait()
keep_sse_alive.set()

return task, context["session_id"], context["response"]

return inner


class MockServerRequestHandler(BaseHTTPRequestHandler):
def do_GET(self): # noqa: N802
# Process an HTTP GET request and return a response.
Expand Down
Loading
Loading