from collections.abc import Callable, Mapping
from operator import itemgetter
from typing import Any

from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.runnables import RouterRunnable, Runnable
from langchain_core.runnables.base import RunnableBindingBase
from typing_extensions import TypedDict


class OpenAIFunction(TypedDict):
    """Typed dictionary describing a single function exposed to ChatOpenAI.

    All three keys (``name``, ``description``, ``parameters``) are required.
    """

    name: str
    """Name identifying the function."""
    description: str
    """Human-readable description of the function."""
    parameters: dict
    """Dictionary holding the function's parameters."""


class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):  # type: ignore[no-redef]
    """Runnable that dispatches a function call to the runnable of the same name.

    The bound pipeline parses the incoming message with
    ``JsonOutputFunctionsParser(args_only=False)``, takes the parsed ``"name"``
    as the routing key and the parsed ``"arguments"`` as the input, and hands
    both to a ``RouterRunnable`` built from the supplied mapping.
    """

    functions: list[OpenAIFunction] | None

    def __init__(
        self,
        runnables: Mapping[
            str,
            Runnable[dict, Any] | Callable[[dict], Any],
        ],
        functions: list[OpenAIFunction] | None = None,
    ):
        """Build the routing pipeline and bind it to this runnable.

        Args:
            runnables: Mapping from function name to the runnable (or plain
                callable) that should handle calls to that function.
            functions: Optional function descriptions. When given, they are
                validated against ``runnables`` before the router is built.

        Raises:
            ValueError: If ``functions`` is given and its length differs from
                the number of runnables, or if any function's ``"name"`` is
                not a key of ``runnables``.
        """
        if functions is not None:
            # Cheap size comparison first, then a per-name membership check.
            if len(runnables) != len(functions):
                msg = "The number of functions does not match the number of runnables."
                raise ValueError(msg)
            for spec in functions:
                if spec["name"] not in runnables:
                    msg = "One or more function names are not found in runnables."
                    raise ValueError(msg)

        # Stage 1: turn the message's function call into a dict with name/arguments.
        parse_call = JsonOutputFunctionsParser(args_only=False)
        # Stage 2: reshape that dict into the key/input pair passed to the router.
        to_router_input = {"key": itemgetter("name"), "input": itemgetter("arguments")}
        # Stage 3: dispatch to the runnable registered under the selected key.
        dispatch = RouterRunnable(runnables)

        super().__init__(
            bound=parse_call | to_router_input | dispatch,
            kwargs={},
            functions=functions,
        )
