diff --git a/mcpauth/exceptions.py b/mcpauth/exceptions.py index 3f73acc..e949e65 100644 --- a/mcpauth/exceptions.py +++ b/mcpauth/exceptions.py @@ -34,7 +34,15 @@ def to_json(self, show_cause: bool = False) -> Record: data: Record = { "error": self.code.value if isinstance(self.code, Enum) else self.code, "error_description": self.message, - "cause": self.cause if show_cause and hasattr(self, "cause") else None, + "cause": ( + ( + {k: v for k, v in self.cause.model_dump().items() if v is not None} + if isinstance(self.cause, BaseModel) + else str(self.cause) + ) + if show_cause and hasattr(self, "cause") + else None + ), } return {k: v for k, v in data.items() if v is not None} @@ -99,8 +107,8 @@ class MCPAuthBearerAuthExceptionDetails(BaseModel): cause: Any = None uri: Optional[str] = None missing_scopes: Optional[List[str]] = None - expected: Optional[Union[str, Record]] = None - actual: Optional[Union[str, Record]] = None + expected: Any = None + actual: Any = None class MCPAuthBearerAuthException(MCPAuthException): @@ -124,15 +132,15 @@ def __init__( def to_json(self, show_cause: bool = False) -> Dict[str, Optional[str]]: # Matches the OAuth 2.0 exception response format at best effort - result = super().to_json(show_cause) + data = super().to_json(show_cause) if self.cause: - result.update( + data.update( { "error_uri": self.cause.uri, "missing_scopes": self.cause.missing_scopes, } ) - return result + return {k: v for k, v in data.items() if v is not None} class MCPAuthJwtVerificationExceptionCode(str, Enum): diff --git a/mcpauth/exceptioins_test.py b/mcpauth/exceptions_test.py similarity index 99% rename from mcpauth/exceptioins_test.py rename to mcpauth/exceptions_test.py index 87fa7cc..412ad5a 100644 --- a/mcpauth/exceptioins_test.py +++ b/mcpauth/exceptions_test.py @@ -28,7 +28,7 @@ def test_to_json(self): assert mcp_exception.to_json(show_cause=True) == { "error": "test_code", "error_description": "Test message", - "cause": exception, + "cause": str(exception), } def test_properties(self): diff --git a/mcpauth/middleware/create_bearer_auth.py b/mcpauth/middleware/create_bearer_auth.py new file mode 100644 index 0000000..152a410 --- /dev/null +++ b/mcpauth/middleware/create_bearer_auth.py @@ -0,0 +1,213 @@ +from typing import Any, Dict, List, Optional +from urllib.parse import urlparse +import logging +from pydantic import BaseModel +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response, JSONResponse +from starlette.datastructures import Headers + +from ..exceptions import ( + MCPAuthBearerAuthException, + MCPAuthJwtVerificationException, + MCPAuthAuthServerException, + MCPAuthConfigException, + BearerAuthExceptionCode, + MCPAuthBearerAuthExceptionDetails, +) +from ..types import VerifyAccessTokenFunction, Record + + +class BearerAuthConfig(BaseModel): + """ + Configuration for the Bearer auth handler. + + Attributes: + issuer: The expected issuer of the access token. + audience: The expected audience of the access token. + required_scopes: An array of required scopes that the access token must have. + show_error_details: Whether to show detailed error information in the response. + """ + + issuer: str + audience: Optional[str] = None + required_scopes: Optional[List[str]] = None + show_error_details: bool = False + + +def get_bearer_token_from_headers(headers: Headers) -> str: + """ + Extract the Bearer token from the request headers. + + Args: + headers: The HTTP request headers. + + Returns: + The Bearer token. + + Raises: + MCPAuthBearerAuthException: If the Authorization header is missing or invalid. + """ + + auth_header = headers.get("authorization") or headers.get("Authorization") + + print(f"Authorization header: {auth_header}") + + if not auth_header: + raise MCPAuthBearerAuthException(BearerAuthExceptionCode.MISSING_AUTH_HEADER) + + parts = auth_header.split(" ") + if len(parts) != 2 or parts[0].lower() != "bearer": + raise MCPAuthBearerAuthException( + BearerAuthExceptionCode.INVALID_AUTH_HEADER_FORMAT + ) + + token = parts[1] + if not token: + raise MCPAuthBearerAuthException(BearerAuthExceptionCode.MISSING_BEARER_TOKEN) + + return token + + +def _handle_error( + error: Exception, show_error_details: bool = False +) -> tuple[int, Dict[str, Any]]: + """ + Handle errors from the Bearer auth process. + + Args: + error: The exception that was caught. + show_error_details: Whether to include detailed error information in the response. + + Returns: + A tuple of (status_code, response_body). + """ + if isinstance(error, MCPAuthJwtVerificationException): + return 401, error.to_json(show_error_details) + + if isinstance(error, MCPAuthBearerAuthException): + if error.code == BearerAuthExceptionCode.MISSING_REQUIRED_SCOPES: + return 403, error.to_json(show_error_details) + return 401, error.to_json(show_error_details) + + if isinstance(error, (MCPAuthAuthServerException, MCPAuthConfigException)): + response: Record = { + "error": "server_error", + "error_description": "An error occurred with the authorization server.", + } + if show_error_details: + response["cause"] = error.to_json() + return 500, response + + # Re-raise other errors + raise error + + +def create_bearer_auth( + verify_access_token: VerifyAccessTokenFunction, config: BearerAuthConfig +) -> type[BaseHTTPMiddleware]: + """ + Creates a middleware function for handling Bearer auth. + + This middleware extracts the Bearer token from the `Authorization` header, verifies it using the + provided `verify_access_token` function, and checks the issuer, audience, and required scopes. + + Args: + verify_access_token: A function that takes a Bearer token and returns an `AuthInfo` object. + config: Configuration for the Bearer auth handler. + + Returns: + A middleware class that handles Bearer auth. + """ + + if not callable(verify_access_token): + raise TypeError( + "`verify_access_token` must be a function that takes a token and returns an `AuthInfo` object." + ) + + try: + result = urlparse(config.issuer) + if not all([result.scheme, result.netloc]): + raise ValueError("Invalid URL") + except: + raise TypeError("`issuer` must be a valid URL.") + + class BearerAuthMiddleware(BaseHTTPMiddleware): + """ + Middleware class that handles Bearer auth. + + This class is used to wrap the request handling process and apply Bearer auth checks. + """ + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + """ + Dispatch method that processes the request and applies Bearer auth checks. + + Args: + request: The HTTP request. + call_next: The next middleware or route handler to call. + + Returns: + The HTTP response after processing the request. + """ + try: + token = get_bearer_token_from_headers(request.headers) + auth_info = verify_access_token(token) + + if auth_info.issuer != config.issuer: + details = MCPAuthBearerAuthExceptionDetails( + expected=config.issuer, actual=auth_info.issuer + ) + raise MCPAuthBearerAuthException( + BearerAuthExceptionCode.INVALID_ISSUER, cause=details + ) + + if config.audience: + audience_matches = ( + config.audience == auth_info.audience + if isinstance(auth_info.audience, str) + else ( + isinstance(auth_info.audience, list) + and config.audience in auth_info.audience + ) + ) + if not audience_matches: + details = MCPAuthBearerAuthExceptionDetails( + expected=config.audience, actual=auth_info.audience + ) + raise MCPAuthBearerAuthException( + BearerAuthExceptionCode.INVALID_AUDIENCE, cause=details + ) + + if config.required_scopes: + missing_scopes = [ + scope + for scope in config.required_scopes + if scope not in auth_info.scopes + ] + if missing_scopes: + details = MCPAuthBearerAuthExceptionDetails( + missing_scopes=missing_scopes + ) + raise MCPAuthBearerAuthException( + BearerAuthExceptionCode.MISSING_REQUIRED_SCOPES, + cause=details, + ) + + # Attach auth info to the request + request.state.auth = auth_info + + # Call the next middleware or route handler + response = await call_next(request) + return response + + except Exception as error: + logging.error(f"Error during Bearer auth: {error}") + status_code, response_body = _handle_error( + error, config.show_error_details + ) + return JSONResponse(status_code=status_code, content=response_body) + + return BearerAuthMiddleware diff --git a/mcpauth/middleware/create_bearer_auth_test.py b/mcpauth/middleware/create_bearer_auth_test.py new file mode 100644 index 0000000..063f889 --- /dev/null +++ b/mcpauth/middleware/create_bearer_auth_test.py @@ -0,0 +1,598 @@ +import json +import pytest +from unittest.mock import MagicMock, AsyncMock +from starlette.requests import Request +from starlette.responses import Response, JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from mcpauth.types import AuthInfo, VerifyAccessTokenFunction +from datetime import timedelta, datetime + +from mcpauth.middleware.create_bearer_auth import ( + create_bearer_auth, + BearerAuthConfig, + BearerAuthExceptionCode, +) +from mcpauth.exceptions import ( + AuthServerExceptionCode, + MCPAuthJwtVerificationException, + MCPAuthAuthServerException, + MCPAuthConfigException, + MCPAuthJwtVerificationExceptionCode, +) + + +class TestHandleBearerAuth: + def test_should_return_middleware_class(self): + middleware = create_bearer_auth( + lambda _: None, # type: ignore + BearerAuthConfig(issuer="https://example.com"), + ) + assert callable(middleware) + + def test_should_throw_error_if_verify_access_token_is_not_a_function(self): + with pytest.raises( + TypeError, match=r"`verify_access_token` must be a function" + ): + create_bearer_auth( + "not a function", # type: ignore + BearerAuthConfig(issuer="https://example.com"), + ) + + def test_should_throw_error_if_issuer_is_not_a_valid_url(self): + with pytest.raises(TypeError, match=r"`issuer` must be a valid URL."): + create_bearer_auth( + lambda _: None, # type: ignore + BearerAuthConfig(issuer="not a valid url"), + ) + + +@pytest.mark.asyncio +class TestHandleBearerAuthMiddleware: + @pytest.fixture + def auth_config(self): + issuer = "https://example.com" + required_scopes = ["read", "write"] + audience = "test-audience" + + def verify_access_token(token: str) -> AuthInfo: + if token == "valid-token": + return AuthInfo( + issuer=issuer, + client_id="client-id", + scopes=["read", "write"], + token=token, + audience=audience, + expires_at=int((datetime.now() + timedelta(hours=1)).timestamp()), + subject="subject-id", + claims={"sub": "subject-id", "aud": audience, "iss": issuer}, + ) + raise MCPAuthJwtVerificationException( + MCPAuthJwtVerificationExceptionCode.INVALID_JWT + ) + + return ( + verify_access_token, + BearerAuthConfig( + issuer=issuer, + required_scopes=required_scopes, + audience=audience, + ), + ) + + @pytest.fixture + def middleware( + self, auth_config: tuple[VerifyAccessTokenFunction, BearerAuthConfig] + ): + MiddlewareClass = create_bearer_auth(auth_config[0], auth_config[1]) + return MiddlewareClass(app=MagicMock()) + + async def test_should_respond_with_error_if_request_does_not_have_bearer_token( + self, middleware: BaseHTTPMiddleware + ): + # Create mock request with no Authorization header + request = Request( + scope={ + "type": "http", + "headers": [], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": BearerAuthExceptionCode.MISSING_AUTH_HEADER.value, + "error_description": "Missing `Authorization` header. Please provide a valid bearer token.", + } + + async def test_should_respond_with_error_if_bearer_token_is_malformed( + self, middleware: BaseHTTPMiddleware + ): + # Test case 1: Invalid token format + request1 = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid token format")], + "method": "GET", + "path": "/", + } + ) + + response1 = await middleware.dispatch(request1, MagicMock()) + + assert response1.status_code == 401 + assert isinstance(response1, JSONResponse) and isinstance(response1.body, bytes) + response1_data = json.loads(response1.body.decode("utf-8")) + assert response1_data == { + "error": BearerAuthExceptionCode.INVALID_AUTH_HEADER_FORMAT.value, + "error_description": 'Invalid `Authorization` header format. Expected "Bearer ".', + } + + # Test case 2: Invalid header format + request2 = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"invalid-header")], + "method": "GET", + "path": "/", + } + ) + + response2 = await middleware.dispatch(request2, MagicMock()) + + assert response2.status_code == 401 + assert isinstance(response2, JSONResponse) and isinstance(response2.body, bytes) + response2_data = json.loads(response2.body.decode("utf-8")) + assert response2_data == { + "error": BearerAuthExceptionCode.INVALID_AUTH_HEADER_FORMAT.value, + "error_description": 'Invalid `Authorization` header format. Expected "Bearer ".', + } + + # Test case 3: Missing token + request3 = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer ")], + "method": "GET", + "path": "/", + } + ) + + response3 = await middleware.dispatch(request3, MagicMock()) + + assert response3.status_code == 401 + assert isinstance(response3, JSONResponse) and isinstance(response3.body, bytes) + response3_data = json.loads(response3.body.decode("utf-8")) + assert response3_data == { + "error": BearerAuthExceptionCode.MISSING_BEARER_TOKEN.value, + "error_description": "Missing bearer token in `Authorization` header. Please provide a valid token.", + } + + async def test_should_respond_with_error_if_bearer_token_is_not_valid( + self, + auth_config: tuple[VerifyAccessTokenFunction, BearerAuthConfig], + ): + mock_verify = MagicMock( + side_effect=MCPAuthJwtVerificationException( + MCPAuthJwtVerificationExceptionCode.INVALID_JWT + ) + ) + MiddlewareClass = create_bearer_auth(mock_verify, auth_config[1]) + middleware = MiddlewareClass(app=MagicMock()) + + mock_verify.side_effect = MCPAuthJwtVerificationException( + MCPAuthJwtVerificationExceptionCode.INVALID_JWT + ) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "invalid_jwt", + "error_description": "The provided JWT is invalid or malformed.", + } + mock_verify.assert_called_once_with("invalid-token") + + async def test_should_respond_with_error_if_issuer_does_not_match( + self, + auth_config: tuple[VerifyAccessTokenFunction, BearerAuthConfig], + ): + mock_verify = MagicMock() + mock_verify.return_value = AuthInfo( + issuer="https://wrong-issuer.com", + client_id="client-id", + scopes=["read", "write"], + token="valid-token", + audience=auth_config[1].audience, + expires_at=int((datetime.now() + timedelta(hours=1)).timestamp()), + subject="subject-id", + claims={ + "sub": "subject-id", + "aud": auth_config[1].audience, + "iss": "https://wrong-issuer.com", + }, + ) + + MiddlewareClass = create_bearer_auth(mock_verify, auth_config[1]) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "invalid_issuer", + "error_description": "The token issuer does not match the expected issuer.", + } + mock_verify.assert_called_once_with("valid-token") + + async def test_should_respond_with_error_if_audience_does_not_match( + self, + auth_config: tuple[VerifyAccessTokenFunction, BearerAuthConfig], + ): + mock_verify = MagicMock() + mock_verify.return_value = AuthInfo( + issuer=auth_config[1].issuer, + client_id="client-id", + scopes=["read", "write"], + token="valid-token", + audience="wrong-audience", + expires_at=int((datetime.now() + timedelta(hours=1)).timestamp()), + subject="subject-id", + claims={ + "sub": "subject-id", + "aud": "wrong-audience", + "iss": auth_config[1].issuer, + }, + ) + + MiddlewareClass = create_bearer_auth(mock_verify, auth_config[1]) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "invalid_audience", + "error_description": "The token audience does not match the expected audience.", + } + mock_verify.assert_called_once_with("valid-token") + + async def test_should_respond_with_error_if_audience_does_not_match_array_case( + self, + auth_config: tuple[VerifyAccessTokenFunction, BearerAuthConfig], + ): + mock_verify = MagicMock() + mock_verify.return_value = AuthInfo( + issuer=auth_config[1].issuer, + client_id="client-id", + scopes=["read", "write"], + token="valid-token", + audience=["wrong-audience"], + expires_at=int((datetime.now() + timedelta(hours=1)).timestamp()), + subject="subject-id", + claims={ + "sub": "subject-id", + "aud": ["wrong-audience"], + "iss": auth_config[1].issuer, + }, + ) + + MiddlewareClass = create_bearer_auth(mock_verify, auth_config[1]) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "invalid_audience", + "error_description": "The token audience does not match the expected audience.", + } + mock_verify.assert_called_once_with("valid-token") + + async def test_should_respond_with_error_if_required_scopes_are_not_present( + self, + auth_config: tuple[VerifyAccessTokenFunction, BearerAuthConfig], + ): + mock_verify = MagicMock() + mock_verify.return_value = AuthInfo( + issuer=auth_config[1].issuer, + client_id="client-id", + scopes=["read"], # Missing "write" scope + token="valid-token", + audience=auth_config[1].audience, + expires_at=int((datetime.now() + timedelta(hours=1)).timestamp()), + subject="subject-id", + claims={ + "sub": "subject-id", + "aud": auth_config[1].audience, + "iss": auth_config[1].issuer, + }, + ) + + MiddlewareClass = create_bearer_auth(mock_verify, auth_config[1]) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 403 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "missing_required_scopes", + "error_description": "The token does not contain the necessary scopes for this request.", + "missing_scopes": ["write"], + } + mock_verify.assert_called_once_with("valid-token") + + async def test_should_call_next_if_token_is_valid_and_has_correct_audience_and_scopes( + self, middleware: BaseHTTPMiddleware + ): + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + # Create a mock for the next_call + # Create a mock for the next_call with AsyncMock + next_call = AsyncMock() + next_call.return_value = Response(status_code=200) + + response = await middleware.dispatch(request, next_call) + + # Verify next was called + next_call.assert_called_once() + assert response.status_code == 200 + + async def test_should_override_existing_auth_property_on_request( + self, middleware: BaseHTTPMiddleware + ): + # Create request with existing auth attribute + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + # Set pre-existing auth property + setattr( + request.state, + "auth", + {"client_id": "old-client-id", "scopes": ["old-scope"]}, + ) + + # Create mock for next_call + next_call = AsyncMock() + next_call.return_value = Response(status_code=200) + + response = await middleware.dispatch(request, next_call) + + # Check that auth was overridden with new values + assert hasattr(request.state, "auth") + assert request.state.auth.issuer == "https://example.com" + assert request.state.auth.client_id == "client-id" + assert request.state.auth.scopes == ["read", "write"] + assert request.state.auth.token == "valid-token" + assert request.state.auth.audience == "test-audience" + + next_call.assert_called_once() + assert response.status_code == 200 + + async def test_should_handle_mcp_auth_server_error_and_config_error(self): + # Test MCPAuthAuthServerError with show_error_details enabled + mock_verify = MagicMock() + mock_verify.side_effect = MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause=Exception("Server configuration is invalid"), + ) + + config = BearerAuthConfig( + issuer="https://example.com", + required_scopes=[], + audience=None, + show_error_details=True, + ) + + MiddlewareClass = create_bearer_auth(mock_verify, config) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 500 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "server_error", + "error_description": "An error occurred with the authorization server.", + "cause": { + "error": "invalid_server_config", + "error_description": "The server configuration does not match the MCP specification.", + }, + } + + # Test MCPAuthConfigException + mock_verify_config = MagicMock() + mock_verify_config.side_effect = MCPAuthConfigException( + "invalid_config", "Configuration is invalid" + ) + + config_error_middleware_class = create_bearer_auth( + mock_verify_config, + BearerAuthConfig( + issuer="https://example.com", required_scopes=[], audience=None + ), + ) + config_error_middleware = config_error_middleware_class(app=MagicMock()) + + config_error_request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + config_error_response = await config_error_middleware.dispatch( + config_error_request, MagicMock() + ) + + assert config_error_response.status_code == 500 + assert isinstance(config_error_response, JSONResponse) and isinstance( + config_error_response.body, bytes + ) + config_error_response_data = json.loads( + config_error_response.body.decode("utf-8") + ) + assert config_error_response_data == { + "error": "server_error", + "error_description": "An error occurred with the authorization server.", + } + + async def test_should_throw_for_unexpected_errors(self): + mock_verify = MagicMock() + mock_verify.side_effect = Exception("Unexpected error") + + middleware_class = create_bearer_auth( + mock_verify, + BearerAuthConfig( + issuer="https://example.com", required_scopes=[], audience=None + ), + ) + middleware = middleware_class(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + with pytest.raises(Exception, match="Unexpected error"): + await middleware.dispatch(request, MagicMock()) + + async def test_should_show_error_details_for_bearer_auth_error(self): + issuer = "https://example.com" + required_scopes = ["read", "write"] + audience = "test-audience" + + mock_verify = MagicMock() + mock_verify.return_value = AuthInfo( + issuer=issuer + "1", # Different issuer + client_id="client-id", + scopes=required_scopes, + token="valid-token", + audience=audience, + expires_at=int((datetime.now() + timedelta(hours=1)).timestamp()), + subject="subject-id", + claims={"sub": "subject-id", "aud": audience, "iss": issuer + "1"}, + ) + + middleware_class = create_bearer_auth( + mock_verify, + BearerAuthConfig( + issuer=issuer, + required_scopes=required_scopes, + audience=audience, + show_error_details=True, + ), + ) + middleware = middleware_class(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert isinstance(response, JSONResponse) and isinstance(response.body, bytes) + response_data = json.loads(response.body.decode("utf-8")) + assert response_data == { + "error": "invalid_issuer", + "error_description": "The token issuer does not match the expected issuer.", + "cause": { + "expected": issuer, + "actual": issuer + "1", + }, + } + mock_verify.assert_called_once_with("valid-token") diff --git a/mcpauth/types.py b/mcpauth/types.py index 333a6be..e768b9c 100644 --- a/mcpauth/types.py +++ b/mcpauth/types.py @@ -1,4 +1,110 @@ -from typing import Any, Dict +from typing import Dict, List, Optional, Protocol, Union, Any +from pydantic import BaseModel Record = Dict[str, Any] + + +class AuthInfo(BaseModel): + """ + Authentication information extracted from tokens. + + These fields can be used in the MCP handlers to provide more context about the authenticated + identity. + """ + + token: str + """ + The raw access token received in the request. This is typically a JWT or opaque token that is + used to authenticate the request. + """ + + issuer: str + """ + The issuer of the access token, which is typically the OAuth / OIDC provider that issued the token. + This is usually a URL that identifies the authorization server. + + See Also: + - https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1 + - https://openid.net/specs/openid-connect-core-1_0.html#IssuerIdentifier + """ + + client_id: str + """ + The client ID of the OAuth client that the token was issued to. This is typically the client ID + registered with the OAuth / OIDC provider. + + Some providers may use 'application ID' or similar terms instead of 'client ID'. + """ + + scopes: List[str] + """ + The scopes (permissions) that the access token has been granted. Scopes define what actions the + token can perform on behalf of the user or client. Normally, you need to define these scopes in + the OAuth / OIDC provider and assign them to the `subject` of the token. + + The provider may support different mechanisms for defining and managing scopes, such as + role-based access control (RBAC) or fine-grained permissions. + """ + + expires_at: Optional[int] + """ + The expiration time of the access token, represented as a Unix timestamp (seconds since epoch). + """ + + subject: Optional[str] + """ + The `sub` (subject) claim of the token, which typically represents the user ID or principal + that the token is issued for. + + See Also: + - https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.2 + """ + + audience: Optional[Union[str, List[str]]] + """ + The `aud` (audience) claim of the token, which indicates the intended recipient(s) of the token. + + For OAuth / OIDC providers that support Resource Indicators (RFC 8707), this claim can be used + to specify the intended Resource Server (API) that the token is meant for. + + If the token is intended for multiple audiences, this can be a list of strings. + + See Also: + - https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3 + - https://datatracker.ietf.org/doc/html/rfc8707 + """ + + claims: Optional[Dict[str, Any]] + """ + The raw claims from the token, which can include any additional information provided by the + token issuer. + """ + + +class VerifyAccessTokenFunction(Protocol): + """ + Function type for verifying an access token. + + This function should throw an `MCPAuthJwtVerificationError` if the token is invalid, or return an + `AuthInfo` instance if the token is valid. + + For example, if you have a JWT verification function, it should at least check the token's + signature, validate its expiration, and extract the necessary claims to return an `AuthInfo` + instance. + + Note: + There's no need to verify the following fields in the token, as they will be checked + by the MCP handlers: + + - `iss` (issuer) + - `aud` (audience) + - `scope` (scopes) + """ + + def __call__(self, token: str) -> AuthInfo: + """ + :param token: The access token to verify. + :return: An `AuthInfo` instance containing the extracted authentication information. + """ + ... diff --git a/pyproject.toml b/pyproject.toml index b95cfc9..38bea3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,12 @@ keywords = [ "oauth2", "openid-connect", ] -dependencies = ["aiohttp>=3.11.18", "pydantic>=2.11.3", "pyjwt[crypto]>=2.9.0"] +dependencies = [ + "aiohttp>=3.11.18", + "pydantic>=2.11.3", + "pyjwt[crypto]>=2.9.0", + "starlette>=0.46.2", +] [project.urls] homepage = "https://mcp-auth.dev" diff --git a/uv.lock b/uv.lock index b924dc5..fdba472 100644 --- a/uv.lock +++ b/uv.lock @@ -135,6 +135,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload_time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anyio" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949, upload_time = "2025-03-17T00:02:54.77Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916, upload_time = "2025-03-17T00:02:52.713Z" }, +] + [[package]] name = "aresponses" version = "3.0.0" @@ -560,6 +575,7 @@ dependencies = [ { name = "aiohttp" }, { name = "pydantic" }, { name = "pyjwt", extra = ["crypto"] }, + { name = "starlette" }, ] [package.dev-dependencies] @@ -576,6 +592,7 @@ requires-dist = [ { name = "aiohttp", specifier = ">=3.11.18" }, { name = "pydantic", specifier = ">=2.11.3" }, { name = "pyjwt", extras = ["crypto"], specifier = ">=2.9.0" }, + { name = "starlette", specifier = ">=0.46.2" }, ] [package.metadata.requires-dev] @@ -1041,6 +1058,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/d0/def53b4a790cfb21483016430ed828f64830dd981ebe1089971cd10cab25/pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde", size = 23841, upload_time = "2025-04-05T14:07:49.641Z" }, ] +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload_time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload_time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "starlette" +version = "0.46.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846, upload_time = "2025-04-13T13:56:17.942Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037, upload_time = "2025-04-13T13:56:16.21Z" }, +] + [[package]] name = "tomli" version = "2.2.1"