"""Microsoft Entra ID validator + M2M + Agent Identity shim (primary IdP).

The shim is intentionally minimal — JWT validation, the
client_credentials grant, and the two-step ``fmi_path`` exchange that
mints per-Agent-Identity tokens for the `alphaswarm_admin` service.

First-login provisioning and the EntraTenantLink wizard remain in
``alphaswarm/auth/providers/msal_entra.py`` per AGENTS rule 44.

Issuer / JWKS URL templates follow the Microsoft Entra v2.0 endpoint
conventions:

- ``issuer``    = ``https://login.microsoftonline.com/<tenant>/v2.0``
- ``jwks_url``  = ``https://login.microsoftonline.com/<tenant>/discovery/v2.0/keys``
- ``token_url`` = ``https://login.microsoftonline.com/<tenant>/oauth2/v2.0/token``

``<tenant>`` is one of:

- ``common``         - any Entra tenant + personal Microsoft accounts
- ``organizations``  - any Entra tenant (B2B / external enterprise)
- ``consumers``      - personal Microsoft accounts only
- ``<tenant_id>``    - single-tenant (UUID or verified domain)

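For the AlphaSwarm control plane the recommended default is
``organizations`` (B2B enterprise customers signing in with work or
school accounts from any Entra tenant; personal Microsoft accounts are
not accepted) with ``audience = "api://alphaswarm-controller"``
configured on the Entra app registration.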

Agent Identity (entra-agent-id skill) two-step exchange:

  Step 1: client_credentials with scope=api://AzureADTokenExchange/.default
          + FIC client assertion (a client_secret is accepted for local
          dev only) -> parent_token
  Step 2: client_credentials with scope=<resource>/.default + fmi_path
          + fmi_target_id + requested_token_use=on_behalf_of
          + assertion=parent_token + the same Blueprint credential
          -> agent_token

The minted agent_token carries the Agent Identity object id as ``sub``
and (when the Agent Identity holds an app-role assignment for the
target resource) the matching ``roles`` claim.
"""
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlencode

import httpx

from alphaswarm_core.auth.jwt_validator import (
    JwtValidationError,
    JwtValidator,
    JwtValidatorConfig,
)
from alphaswarm_core.auth.providers.protocol import (
    IdentityProviderShimConfig,
    M2MGrant,
)

#: Base URL of the Microsoft Entra authority. The issuer, JWKS and token
#: endpoint URLs built in this module are all rooted here.
_ENTRA_HOST = "https://login.microsoftonline.com"

#: Audience for the Step 1 parent token in the Agent Identity fmi_path
#: exchange; ``acquire_agent_token`` requests it as
#: ``api://AzureADTokenExchange/.default``. Per the entra-agent-id skill this
#: is the canonical Entra WIF audience — the parent token is never used as a
#: bearer; only as the ``assertion`` argument to Step 2.
AGENT_IDENTITY_PARENT_AUDIENCE = "api://AzureADTokenExchange"

#: Default fmi_path namespace prefix for AlphaSwarm Agent Identities.
#: By convention the full fmi_path is ``alphaswarm-admin-<env>`` (e.g.
#: ``alphaswarm-admin-prod``). Nothing in this module builds or enforces the
#: prefix: ``acquire_agent_token`` forwards its ``fmi_path`` argument verbatim.
#: NOTE(review): the statements that the Blueprint accepts any fmi_path under
#: this prefix and that Entra uses it only as an audit label are unverified
#: here — confirm against the Entra Agent ID documentation.
AGENT_IDENTITY_FMI_PATH_PREFIX = "alphaswarm-admin"


@dataclass(frozen=True, slots=True)
class AgentTokenResult:
    """Immutable result of the two-step fmi_path Agent Identity exchange.

    Carries the same token fields that are passed to :class:`M2MGrant`
    (``access_token``, ``expires_at``, ``token_type``, ``issued_at``,
    ``scope``) so the broker layer can treat the two interchangeably. It adds
    Agent-Identity provenance so audit rows can record which Agent Identity
    minted which token.
    """

    # Access token returned by the Step 2 (agent) exchange.
    access_token: str
    # Deadline in ``time.time()`` seconds; ``acquire_agent_token`` sets it
    # from Entra's ``expires_in`` minus a refresh margin.
    expires_at: float
    # Token type reported by Entra (``acquire_agent_token`` defaults it to
    # "Bearer" when absent).
    token_type: str
    # ``time.time()`` at which the Step 2 response was processed.
    issued_at: float
    # Agent Identity object id, sent as ``fmi_target_id`` in Step 2.
    agent_identity_id: str
    # Client id of the Blueprint app that authenticated both steps.
    blueprint_app_id: str
    # The ``fmi_path`` sent in Step 2.
    fmi_path: str
    # Resource audience requested in Step 2 (before ``/.default`` is added).
    audience: str
    # Scopes Entra reported as granted, or the caller's requested extra
    # scopes when the response has no ``scope`` field.
    scope: tuple[str, ...] = ()


def msal_entra_jwt_validator_config(
    *,
    tenant: str,
    audience: str,
    leeway_seconds: int = 60,
    jwks_ttl_seconds: int = 600,
    expected_claim_namespaces: tuple[str, ...] = (),
) -> JwtValidatorConfig:
    """Return the :class:`JwtValidatorConfig` for one Entra v2.0 tenant.

    Args:
        tenant: Tenant segment: a UUID, ``common``, ``organizations`` or
            ``consumers``. Whitespace and surrounding slashes are trimmed,
            and a blank value falls back to ``organizations``.
        audience: Expected ``aud`` claim, i.e. the Entra app-registration
            ``api://<app-id-uri>`` resource id.
        leeway_seconds: Passed through to the validator config.
        jwks_ttl_seconds: Passed through to the validator config.
        expected_claim_namespaces: Passed through to the validator config.

    Only ``RS256`` is allowed. The tenant's v2.0 JWKS URL is supplied as
    ``jwks_url_override``.

    NOTE(review): Entra stamps ``iss`` with the concrete tenant GUID. For the
    ``common`` / ``organizations`` / ``consumers`` segments the issuer built
    here will not match tokens literally — confirm how ``JwtValidator``
    handles multi-tenant issuers.
    """
    segment = tenant.strip().strip("/")
    if not segment:
        segment = "organizations"
    authority = f"{_ENTRA_HOST}/{segment}"
    return JwtValidatorConfig(
        issuer=f"{authority}/v2.0",
        audience=audience,
        algorithms=("RS256",),
        leeway_seconds=leeway_seconds,
        jwks_ttl_seconds=jwks_ttl_seconds,
        jwks_url_override=f"{authority}/discovery/v2.0/keys",
        expected_claim_namespaces=expected_claim_namespaces,
    )


@dataclass(frozen=True, slots=True)
class _MsalEntraConfig(IdentityProviderShimConfig):
    """Entra-specific shim config: the base shim config plus the tenant segment.

    ``MsalEntraValidator`` stores the normalised tenant segment here, next to
    the issuer and JWKS URLs it derived from that segment.
    """

    # Normalised tenant path segment (e.g. a tenant UUID or
    # ``organizations``); defaults to ``organizations``.
    tenant_segment: str = "organizations"


class MsalEntraValidator:
    """Microsoft Entra ID validator + M2M + Agent Identity shim.

    Owns a single :class:`JwtValidator` and a single
    :class:`httpx.AsyncClient` for the token endpoint. Both are created
    lazily on first use and released by :meth:`close`.

    Every token-endpoint failure (transport error, HTTP error status,
    non-JSON body, missing ``access_token``) surfaces as a
    :class:`JwtValidationError` with a machine-readable ``code``.
    """

    provider_alias = "msal_entra"

    def __init__(
        self,
        *,
        tenant: str,
        audience: str,
        leeway_seconds: int = 60,
        jwks_ttl_seconds: int = 600,
        expected_claim_namespaces: tuple[str, ...] = (),
        http_timeout_seconds: float = 10.0,
    ) -> None:
        """Configure the shim for a single Entra tenant segment.

        Args:
            tenant: Tenant UUID, ``common``, ``organizations`` or
                ``consumers``. Whitespace and surrounding slashes are
                trimmed, and a blank value falls back to ``organizations``.
            audience: Expected ``aud`` of tokens validated by
                :meth:`jwt_validator`.
            leeway_seconds: Forwarded to the JWT validator config.
            jwks_ttl_seconds: Forwarded to the JWT validator config.
            expected_claim_namespaces: Forwarded to the JWT validator config.
            http_timeout_seconds: Timeout for token-endpoint requests.
        """
        tenant_segment = tenant.strip().strip("/") or "organizations"
        self._tenant_segment = tenant_segment
        self._http_timeout = http_timeout_seconds
        # NOTE(review): Entra stamps ``iss`` with the concrete tenant GUID, so
        # for ``common`` / ``organizations`` / ``consumers`` this issuer will
        # not match tokens literally — confirm multi-tenant issuer handling.
        self._config = _MsalEntraConfig(
            provider_alias=self.provider_alias,
            issuer=f"{_ENTRA_HOST}/{tenant_segment}/v2.0",
            audience=audience,
            jwks_url=f"{_ENTRA_HOST}/{tenant_segment}/discovery/v2.0/keys",
            leeway_seconds=leeway_seconds,
            jwks_ttl_seconds=jwks_ttl_seconds,
            algorithms=("RS256",),
            expected_claim_namespaces=expected_claim_namespaces,
            tenant_segment=tenant_segment,
        )
        self._validator: JwtValidator | None = None
        self._http: httpx.AsyncClient | None = None

    @property
    def config(self) -> IdentityProviderShimConfig:
        """Return the shim config built in ``__init__``."""
        return self._config

    def token_endpoint(self) -> str:
        """Return the tenant's OAuth 2.0 v2.0 token endpoint URL."""
        return f"{_ENTRA_HOST}/{self._tenant_segment}/oauth2/v2.0/token"

    def jwt_validator(self) -> JwtValidator:
        """Return the tenant's :class:`JwtValidator`, building it on first call."""
        if self._validator is None:
            self._validator = JwtValidator(
                msal_entra_jwt_validator_config(
                    tenant=self._tenant_segment,
                    audience=self._config.audience,
                    leeway_seconds=self._config.leeway_seconds,
                    jwks_ttl_seconds=self._config.jwks_ttl_seconds,
                    expected_claim_namespaces=self._config.expected_claim_namespaces,
                )
            )
        return self._validator

    async def acquire_m2m_grant(
        self,
        *,
        client_id: str,
        client_secret: str,
        audience: str,
        scopes: tuple[str, ...] = (),
        extra: dict[str, Any] | None = None,
    ) -> M2MGrant:
        """Execute a ``grant_type=client_credentials`` exchange.

        Entra v2.0 expects the resource as a scope of the form
        ``<audience>/.default``. The optional ``scopes`` argument is
        appended verbatim for callers that want narrower grants.
        NOTE(review): Entra documents client_credentials as accepting only
        ``/.default`` scopes — confirm extra scopes are actually accepted.

        ``extra`` entries are stringified and override the default form
        fields on key collision.

        Raises:
            JwtValidationError: ``token_endpoint_unreachable`` on transport
                failure, ``m2m_grant_failed`` on any HTTP status >= 400, and
                ``m2m_grant_invalid_response`` when the body is not a JSON
                object or carries no ``access_token``.
        """
        body: dict[str, str] = {
            "grant_type": "client_credentials",
            "client_id": client_id,
            "client_secret": client_secret,
            "scope": self._build_scope(audience, scopes),
        }
        if extra:
            body.update({str(k): str(v) for k, v in extra.items()})
        try:
            response = await self._post_form(body)
        except httpx.HTTPError as exc:
            raise JwtValidationError(
                f"entra token endpoint unreachable: {exc}",
                code="token_endpoint_unreachable",
            ) from exc
        if response.status_code >= 400:
            raise JwtValidationError(
                f"entra client_credentials failed: HTTP {response.status_code}",
                code="m2m_grant_failed",
            )
        data = self._decode_token_response(
            response,
            message="entra response missing access_token",
            code="m2m_grant_invalid_response",
        )
        now = time.time()
        return M2MGrant(
            access_token=str(data["access_token"]),
            expires_at=self._expires_at(now, data.get("expires_in")),
            token_type=str(data.get("token_type") or "Bearer"),
            issued_at=now,
            scope=self._parse_scope(data.get("scope"), scopes),
        )

    async def acquire_agent_token(
        self,
        *,
        blueprint_app_id: str,
        agent_identity_id: str,
        audience: str,
        fmi_path: str,
        client_secret: str | None = None,
        federated_assertion: str | None = None,
        scopes: tuple[str, ...] = (),
    ) -> AgentTokenResult:
        """Execute the two-step ``fmi_path`` Agent Identity exchange.

        Per the entra-agent-id skill:

        Step 1 — mint the parent token for the Blueprint app
                 (client_credentials with scope=api://AzureADTokenExchange/.default).
                 Either ``client_secret`` or ``federated_assertion``
                 (a WIF token) MUST be supplied; FIC is the
                 production default per the entra-agent-id skill's
                 "Best Practices" #5.

        Step 2 — exchange the parent token for an Agent Identity token
                 (client_credentials with fmi_path + fmi_target_id +
                 requested_token_use=on_behalf_of + assertion=parent_token).

        Both steps target the SAME token endpoint
        (``/{tenant}/oauth2/v2.0/token``) and authenticate the Blueprint
        with the same credential. A non-empty ``federated_assertion`` is
        preferred over ``client_secret`` when both are given.

        NOTE(review): the Step 2 parameter set mirrors the skill document;
        confirm it against the current Microsoft Entra Agent ID docs.

        Returns:
            An :class:`AgentTokenResult` with the Step 2 token and provenance.

        Raises:
            JwtValidationError: when required inputs are missing, on
                transport failure, on any HTTP error (HTTP 429 gets a
                dedicated ``agent_token_<stage>_rate_limited`` code), or when
                a response is not a JSON object carrying ``access_token``.
        """
        if not blueprint_app_id:
            raise JwtValidationError(
                "blueprint_app_id is required for agent token exchange",
                code="agent_token_blueprint_missing",
            )
        if not agent_identity_id:
            raise JwtValidationError(
                "agent_identity_id is required for agent token exchange",
                code="agent_token_identity_missing",
            )
        if not client_secret and not federated_assertion:
            raise JwtValidationError(
                "either client_secret or federated_assertion must be supplied for the parent token",
                code="agent_token_credential_missing",
            )

        # The Blueprint authenticates both steps with the same credential.
        credential = self._client_credential(client_secret, federated_assertion)

        # ----- Step 1: parent token -----
        parent_token = await self._post_token(
            {
                "grant_type": "client_credentials",
                "client_id": blueprint_app_id,
                "scope": f"{AGENT_IDENTITY_PARENT_AUDIENCE}/.default",
                **credential,
            },
            stage="parent",
        )

        # ----- Step 2: agent token via fmi_path -----
        agent_token = await self._post_token(
            {
                "grant_type": "client_credentials",
                "client_id": blueprint_app_id,
                "scope": self._build_scope(audience, scopes),
                "fmi_path": fmi_path,
                "fmi_target_id": agent_identity_id,
                "requested_token_use": "on_behalf_of",
                "assertion": str(parent_token["access_token"]),
                **credential,
            },
            stage="agent",
        )
        now = time.time()
        return AgentTokenResult(
            access_token=str(agent_token["access_token"]),
            expires_at=self._expires_at(now, agent_token.get("expires_in")),
            token_type=str(agent_token.get("token_type") or "Bearer"),
            issued_at=now,
            agent_identity_id=agent_identity_id,
            blueprint_app_id=blueprint_app_id,
            fmi_path=fmi_path,
            audience=audience,
            scope=self._parse_scope(agent_token.get("scope"), scopes),
        )

    async def _post_token(
        self,
        body: dict[str, str],
        *,
        stage: str,
    ) -> dict[str, Any]:
        """POST one Agent Identity step; return the decoded JSON body.

        ``stage`` ("parent" / "agent") is embedded in every error message
        and ``code`` so audit logs show which step failed.
        """
        try:
            response = await self._post_form(body)
        except httpx.HTTPError as exc:
            raise JwtValidationError(
                f"entra token endpoint unreachable ({stage}): {exc}",
                code=f"agent_token_{stage}_unreachable",
            ) from exc
        if response.status_code == 429:
            raise JwtValidationError(
                f"entra rate-limited at {stage} step (HTTP 429)",
                code=f"agent_token_{stage}_rate_limited",
            )
        if response.status_code >= 400:
            raise JwtValidationError(
                f"entra {stage} token request failed: HTTP {response.status_code}",
                code=f"agent_token_{stage}_failed",
            )
        return self._decode_token_response(
            response,
            message=f"entra {stage} response missing access_token",
            code=f"agent_token_{stage}_invalid_response",
        )

    async def _post_form(self, body: dict[str, str]) -> httpx.Response:
        """POST *body* form-encoded to the token endpoint.

        Transport errors (``httpx.HTTPError``) propagate to the caller, which
        maps them to its own error code.
        """
        http = await self._ensure_http()
        return await http.post(
            self.token_endpoint(),
            content=urlencode(body).encode("utf-8"),
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

    @staticmethod
    def _decode_token_response(
        response: Any,
        *,
        message: str,
        code: str,
    ) -> dict[str, Any]:
        """Return the JSON object body of a successful token response.

        Raises ``JwtValidationError(message, code=code)`` when the body is not
        valid JSON, is not a JSON object, or lacks a non-empty
        ``access_token``. This keeps a proxy's HTML error page from escaping
        as a raw ``ValueError``.
        """
        try:
            data = response.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            data = None
        if not isinstance(data, dict) or not data.get("access_token"):
            raise JwtValidationError(message, code=code)
        return data

    @staticmethod
    def _client_credential(
        client_secret: str | None,
        federated_assertion: str | None,
    ) -> dict[str, str]:
        """Return the form fields that authenticate the Blueprint app.

        A non-empty ``federated_assertion`` is sent as a JWT-bearer client
        assertion. Otherwise ``client_secret`` is sent (dev / local only per
        entra-agent-id best practices #5). The check uses truthiness rather
        than ``is not None`` so an empty assertion never shadows a usable
        secret.
        """
        if federated_assertion:
            return {
                "client_assertion_type": (
                    "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
                ),
                "client_assertion": federated_assertion,
            }
        return {"client_secret": client_secret or ""}

    @staticmethod
    def _expires_at(now: float, expires_in: Any) -> float:
        """Return the ``time.time()`` deadline to treat as the token's expiry.

        ``expires_in`` is Entra's raw value. A missing, zero or non-numeric
        value falls back to 3600 seconds. The deadline is 30 s early so
        callers refresh in time, with a floor of 60 s of reported validity.
        It is never later than the lifetime Entra actually granted, so a
        short-lived token is not reported as valid after it expires.
        """
        try:
            lifetime = int(expires_in or 3600)
        except (TypeError, ValueError, OverflowError):
            lifetime = 3600
        return now + min(lifetime, max(60, lifetime - 30))

    @staticmethod
    def _parse_scope(granted_scope: Any, fallback: tuple[str, ...]) -> tuple[str, ...]:
        """Split Entra's space-delimited ``scope``; use *fallback* if absent."""
        if not granted_scope:
            return fallback
        return tuple(str(granted_scope).split())

    async def close(self) -> None:
        """Release the JWT validator and HTTP client; safe to call repeatedly.

        Both handles are detached before closing, so a failure while closing
        the validator still closes the HTTP client. That first error is then
        re-raised.
        """
        validator, self._validator = self._validator, None
        http, self._http = self._http, None
        try:
            if validator is not None:
                await validator.close()
        finally:
            if http is not None:
                await http.aclose()

    async def _ensure_http(self) -> httpx.AsyncClient:
        """Return the shared token-endpoint client, creating it on first use."""
        if self._http is None:
            self._http = httpx.AsyncClient(
                timeout=self._http_timeout,
                headers={"Accept": "application/json"},
            )
        return self._http

    @staticmethod
    def _build_scope(audience: str, scopes: tuple[str, ...]) -> str:
        """Return ``<audience>/.default`` followed by any extra *scopes*.

        A trailing slash on *audience* is dropped, and ``/.default`` is not
        duplicated when *audience* already ends with it.
        """
        normalised = audience.rstrip("/")
        if not normalised.endswith("/.default"):
            normalised = f"{normalised}/.default"
        parts = [normalised]
        parts.extend(scopes)
        return " ".join(parts)


# Public API of this module. ``_ENTRA_HOST`` and ``_MsalEntraConfig`` are
# deliberately left out (private implementation details).
__all__ = [
    "AGENT_IDENTITY_FMI_PATH_PREFIX",
    "AGENT_IDENTITY_PARENT_AUDIENCE",
    "AgentTokenResult",
    "MsalEntraValidator",
    "msal_entra_jwt_validator_config",
]
