アットランタイム

Python の httpx.Auth に対応した AWS Signature Version 4 を実装する

はじめに

https の Auth に対応した AWS Signature Version 4 を実装しました。

背景

Strands Agents で AI エージェントを実装しており、金銭的な理由から AWS Lambda を採用することにしました。 AI エージェントに加えて、MCP サーバーも実装しており同様に AWS Lambda 上で構築することにしました。 料金、また構築の容易さと、現状 Streamable HTTP transport 対応するためには Lambda URL 一択でした。

Strands Agent の実装として、Streamable HTTP transport の MCP を使用する場合、MCP サーバーとして Lambda URL を指定します。

2. Streamable HTTP¶

from mcp.client.streamable_http import streamablehttp_client
from strands import Agent
from strands.tools.mcp.mcp_client import MCPClient

streamable_http_mcp_client = MCPClient(lambda: streamablehttp_client("http://localhost:8000/mcp"))

# Create an agent with MCP tools
with streamable_http_mcp_client:
    # Get the tools from the MCP server
    tools = streamable_http_mcp_client.list_tools_sync()

    # Create an agent with these tools
    agent = Agent(tools=tools)

ここで、セキュリティの観点、MCP サーバーは AI エージェント Lambda からのみアクセス可能にしたいところです。 Lambda URL の認証タイプとして None を使用すると URL を知る人は誰でもアクセス可能となり、特に料金に対する懸念がありました。

Lambda URL では認証タイプとして IAM があるので、IAM 認証情報で署名されたリクエストを送信できれば要件を満たせそうでした。

streamablehttp_client の実装

streamablehttp_client は HTTP ヘッダーと、httpx.Auth を引数として受け取れます。 署名をする場合、httpx.Auth に対応した AWS Signature Version 4 を使用するのがよさそうです。

実装すれば、次のように IAM 認証を使用して MCP サーバーを実行できるようになります。

is_running_on_lambda = os.getenv("AWS_LAMBDA_FUNCTION_NAME", None) is not None
mcp_db_client = MCPClient(
    lambda: streamablehttp_client(
        url=DB_MCP_FULL_URL,
        auth=AWSv4Auth("lambda", os.getenv("AWS_REGION", "")) if is_running_on_lambda else None,
    )
)

実装

gist をはっておきます。 通常のリクエストとストリームの両方に対応しているはずです。

テスト

作成した MCP エンドポイントや通常のエンドポイントに対して動作したので、一般的なユースケースでは動作しそうです。

import pytest
import asyncio
import httpx
from main import AWSv4Auth


class TestAWSv4Auth:
    """AWSv4Auth クラスのテストスイート - 必要な機能に特化"""

    @pytest.fixture
    def auth(self):
        """テスト用の認証インスタンス"""
        return AWSv4Auth(service="lambda", region="ap-northeast-1")

    @pytest.fixture
    def general_url(self):
        """通常のテスト用URL"""
        return f"https://{URL_1}.lambda-url.ap-northeast-1.on.aws/ping"

    @pytest.fixture
    def mcp_url(self):
        """MCPエンドポイント用URL"""
        return f"https://{URL_2}.lambda-url.ap-northeast-1.on.aws/mcp"

    # === GETリクエストのテスト ===

    def test_basic_get_request(self, auth, general_url):
        """基本的なGETリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            response = client.get(general_url)
            assert response.status_code == 200
            assert len(response.text) > 0

    def test_get_with_query_params(self, auth, general_url):
        """クエリパラメータ付きGETリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            params = {"message": "Hello World!", "symbols": "!@#$%^&*()"}
            response = client.get(general_url, params=params)
            assert response.status_code == 200
            assert len(response.text) > 0

    def test_get_with_japanese_query_params(self, auth, general_url):
        """日本語クエリパラメータ付きGETリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            params = {"message": "こんにちは", "name": "田中さん"}
            response = client.get(general_url, params=params)
            assert response.status_code == 200
            assert len(response.text) > 0

    # === POSTリクエストのテスト ===

    def test_post_json_request(self, auth, general_url):
        """JSONデータを送信するPOSTリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            data = {"message": "Hello", "timestamp": "2025-08-15"}
            response = client.post(general_url, json=data)
            assert response.status_code == 200
            assert len(response.text) > 0

    def test_post_with_query_params(self, auth, general_url):
        """クエリパラメータ付きPOSTリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            params = {"action": "create", "version": "v1"}
            data = {"content": "test data"}
            response = client.post(general_url, params=params, json=data)
            assert response.status_code == 200
            assert len(response.text) > 0

    def test_post_mcp_endpoint(self, auth, mcp_url):
        """MCPエンドポイントへのPOSTリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            headers = {
                "Accept": "application/json, text/event-stream",
                "Content-Type": "application/json",
            }
            json_rpc = {
                "jsonrpc": "2.0",
                "method": "initialize",
                "params": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {},
                    "clientInfo": {"name": "test-client", "version": "1.0.0"},
                },
                "id": 1,
            }
            response = client.post(mcp_url, json=json_rpc, headers=headers)
            assert response.status_code in [200, 201, 202]
            assert len(response.text) > 0

    # === ストリーミングリクエストのテスト ===

    def test_streaming_get_request(self, auth, general_url):
        """ストリーミングGETリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            with client.stream("GET", general_url) as response:
                assert response.status_code == 200

                content = b""
                chunk_count = 0
                for chunk in response.iter_bytes():
                    content += chunk
                    chunk_count += 1

                assert len(content) > 0
                assert chunk_count > 0
                # レスポンスがデコード可能であることを確認
                decoded_content = content.decode()
                assert len(decoded_content) > 0

    def test_streaming_post_request(self, auth, general_url):
        """ストリーミングPOSTリクエストのテスト"""
        with httpx.Client(auth=auth) as client:
            data = {"message": "streaming test", "data": list(range(10))}
            with client.stream("POST", general_url, json=data) as response:
                assert response.status_code == 200

                content = b""
                chunk_count = 0
                for chunk in response.iter_bytes():
                    content += chunk
                    chunk_count += 1

                assert len(content) > 0
                assert chunk_count > 0

    # === 非同期リクエストのテスト ===

    @pytest.mark.asyncio
    async def test_async_get_request(self, auth, general_url):
        """非同期GETリクエストのテスト"""
        async with httpx.AsyncClient(auth=auth) as client:
            response = await client.get(general_url, params={"async": "true"})
            assert response.status_code == 200
            assert len(response.text) > 0

    @pytest.mark.asyncio
    async def test_async_post_request(self, auth, general_url):
        """非同期POSTリクエストのテスト"""
        async with httpx.AsyncClient(auth=auth) as client:
            data = {"message": "async test", "timestamp": "2025-08-15"}
            response = await client.post(general_url, json=data)
            assert response.status_code == 200
            assert len(response.text) > 0

    @pytest.mark.asyncio
    async def test_async_streaming_get_request(self, auth, general_url):
        """非同期ストリーミングGETリクエストのテスト"""
        async with httpx.AsyncClient(auth=auth) as client:
            async with client.stream("GET", general_url) as response:
                assert response.status_code == 200

                content = b""
                chunk_count = 0
                async for chunk in response.aiter_bytes():
                    content += chunk
                    chunk_count += 1

                assert len(content) > 0
                assert chunk_count > 0
                # レスポンスがデコード可能であることを確認
                decoded_content = content.decode()
                assert len(decoded_content) > 0

    @pytest.mark.asyncio
    async def test_async_streaming_post_request(self, auth, general_url):
        """非同期ストリーミングPOSTリクエストのテスト"""
        async with httpx.AsyncClient(auth=auth) as client:
            data = {"message": "async streaming test", "data": list(range(5))}
            async with client.stream("POST", general_url, json=data) as response:
                assert response.status_code == 200

                content = b""
                chunk_count = 0
                async for chunk in response.aiter_bytes():
                    content += chunk
                    chunk_count += 1

                assert len(content) > 0
                assert chunk_count > 0

    # === クエリパラメータの詳細テスト ===

    def test_query_params_special_characters(self, auth, general_url):
        """特殊文字を含むクエリパラメータのテスト"""
        with httpx.Client(auth=auth) as client:
            params = {
                "special": "!@#$%^&*()",
                "spaces": "hello world",
                "unicode": "こんにちは世界",
            }
            response = client.get(general_url, params=params)
            assert response.status_code == 200
            assert len(response.text) > 0

    def test_query_params_multiple_values(self, auth, general_url):
        """複数の値を持つクエリパラメータのテスト"""
        with httpx.Client(auth=auth) as client:
            # httpxでは同じキーに複数の値を設定する場合はリストを使用
            params = [
                ("tags", "python"),
                ("tags", "aws"),
                ("tags", "lambda"),
                ("category", "test"),
            ]
            response = client.get(general_url, params=params)
            assert response.status_code == 200
            assert len(response.text) > 0


if __name__ == "__main__":
    # pytestを直接実行する場合
    pytest.main([__file__, "-v"])