from typing import Any, Optional, List, Dict, Generator, AsyncGenerator
import warnings
import json
from fastapi import Request
from orichain import error_explainer
from orichain.llm import (
openai_llm,
anthropicbedrock_llm,
anthropic_llm,
awsbedrock_llm,
azureopenai_llm,
gcp_gemini_llm,
gcp_vertex_llm,
togetherai_llm,
)
# Model and provider used when the caller does not specify them.
DEFAULT_MODEL = "gpt-5-mini"
DEFAULT_MODEL_PROVIDER = "OpenAI"

# Model ids listed identically for both the OpenAI and AzureOpenAI providers.
_OPENAI_FAMILY_MODELS = (
    "gpt-4o",
    "gpt-4-turbo",
    "gpt-4-turbo-preview",
    "gpt-4o-mini",
    "gpt-4",
    "gpt-4.1",
    "gpt-4.1-mini",
    "gpt-4.1-nano",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
)

# Model ids listed identically for both the GoogleGemini and GoogleVertexAI providers.
_GEMINI_FAMILY_MODELS = (
    "gemini-1.5-pro",
    "gemini-1.5-flash-8b",
    "gemini-1.5-flash",
    "gemini-2.0-flash-lite",
    "gemini-2.0-flash",
    "gemini-2.5-flash-lite-preview-06-17",
    "gemini-2.5-flash",
    "gemini-2.5-pro",
)

# Provider name -> model ids that are known to the library. Each provider gets
# its own list object (hence the list(...) copies) so mutating one entry never
# affects another provider.
SUPPORTED_MODELS = {
    "OpenAI": list(_OPENAI_FAMILY_MODELS),
    "AzureOpenAI": list(_OPENAI_FAMILY_MODELS),
    "AnthropicBedrock": [
        "anthropic.claude-3-haiku-20240307-v1:0",
        "us.anthropic.claude-3-haiku-20240307-v1:0",
        "us-gov.anthropic.claude-3-haiku-20240307-v1:0",
        "eu.anthropic.claude-3-haiku-20240307-v1:0",
        "apac.anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-5-haiku-20241022-v1:0",
        "us.anthropic.claude-3-5-haiku-20241022-v1:0",
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "us.anthropic.claude-3-sonnet-20240229-v1:0",
        "eu.anthropic.claude-3-sonnet-20240229-v1:0",
        "apac.anthropic.claude-3-sonnet-20240229-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
        "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
        "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
        "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
        "apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
        "anthropic.claude-3-5-sonnet-20241022-v2:0",
        "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        "anthropic.claude-3-7-sonnet-20250219-v1:0",
        "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        "anthropic.claude-sonnet-4-20250514-v1:0",
        "us.anthropic.claude-sonnet-4-20250514-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "us.anthropic.claude-3-opus-20240229-v1:0",
        "anthropic.claude-opus-4-20250514-v1:0",
        "us.anthropic.claude-opus-4-20250514-v1:0",
        "anthropic.claude-opus-4-1-20250805-v1:0",
        "us.anthropic.claude-opus-4-1-20250805-v1:0",
    ],
    "Anthropic": [
        "claude-3-haiku-20240307",
        "claude-3-5-haiku-20241022",
        "claude-3-5-haiku-latest",
        "claude-3-sonnet-20240229",
        "claude-3-5-sonnet-latest",
        "claude-3-7-sonnet-20250219",
        "claude-3-7-sonnet-latest",
        "claude-sonnet-4-0",
        "claude-sonnet-4-20250514",
        "claude-3-opus-latest",
        "claude-opus-4-0",
        "claude-opus-4-20250514",
        "claude-opus-4-1",
        "claude-opus-4-1-20250805",
    ],
    "AWSBedrock": [
        "cohere.command-text-v14",
        "cohere.command-light-text-v14",
        "cohere.command-r-v1:0",
        "cohere.command-r-plus-v1:0",
        "meta.llama3-8b-instruct-v1:0",
        "meta.llama3-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "us.meta.llama3-1-8b-instruct-v1:0",
        "meta.llama3-1-70b-instruct-v1:0",
        "us.meta.llama3-1-70b-instruct-v1:0",
        "meta.llama3-1-405b-instruct-v1:0",
        "meta.llama3-2-1b-instruct-v1:0",
        "us.meta.llama3-2-1b-instruct-v1:0",
        "eu.meta.llama3-2-1b-instruct-v1:0",
        "meta.llama3-2-3b-instruct-v1:0",
        "us.meta.llama3-2-3b-instruct-v1:0",
        "eu.meta.llama3-2-3b-instruct-v1:0",
        "meta.llama3-2-11b-instruct-v1:0",
        "us.meta.llama3-2-11b-instruct-v1:0",
        "meta.llama3-2-90b-instruct-v1:0",
        "us.meta.llama3-2-90b-instruct-v1:0",
        "meta.llama3-3-70b-instruct-v1:0",
        "us.meta.llama3-3-70b-instruct-v1:0",
        "meta.llama4-maverick-17b-instruct-v1:0",
        "us.meta.llama4-maverick-17b-instruct-v1:0",
        "meta.llama4-scout-17b-instruct-v1:0",
        "us.meta.llama4-scout-17b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
        "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral.mistral-large-2402-v1:0",
        "mistral.mistral-large-2407-v1:0",
        "mistral.mistral-small-2402-v1:0",
        "amazon.titan-text-express-v1",
        "amazon.titan-text-lite-v1",
        "amazon.titan-text-premier-v1:0",
        "amazon.nova-pro-v1:0",
        "us.amazon.nova-pro-v1:0",
        "amazon.nova-lite-v1:0",
        "us.amazon.nova-lite-v1:0",
        "amazon.nova-micro-v1:0",
        "us.amazon.nova-micro-v1:0",
    ],
    "GoogleGemini": list(_GEMINI_FAMILY_MODELS),
    "GoogleVertexAI": list(_GEMINI_FAMILY_MODELS),
    "TogetherAI": [
        "Alibaba-NLP/gte-modernbert-base",
        "arcee-ai/AFM-4.5B",
        "arcee-ai/coder-large",
        "arcee-ai/maestro-reasoning",
        "arcee-ai/virtuoso-large",
        "arcee_ai/arcee-spotlight",
        "arize-ai/qwen-2-1.5b-instruct",
        "black-forest-labs/FLUX.1-canny",
        "black-forest-labs/FLUX.1-depth",
        "black-forest-labs/FLUX.1-dev",
        "black-forest-labs/FLUX.1-dev-lora",
        "black-forest-labs/FLUX.1-kontext-dev",
        "black-forest-labs/FLUX.1-kontext-max",
        "black-forest-labs/FLUX.1-kontext-pro",
        "black-forest-labs/FLUX.1-krea-dev",
        "black-forest-labs/FLUX.1-pro",
        "black-forest-labs/FLUX.1-redux",
        "black-forest-labs/FLUX.1-schnell",
        "black-forest-labs/FLUX.1-schnell-Free",
        "black-forest-labs/FLUX.1.1-pro",
        "cartesia/sonic",
        "cartesia/sonic-2",
        "deepcogito/cogito-v2-preview-deepseek-671b",
        "deepcogito/cogito-v2-preview-llama-109B-MoE",
        "deepcogito/cogito-v2-preview-llama-405B",
        "deepcogito/cogito-v2-preview-llama-70B",
        "deepseek-ai/DeepSeek-R1-0528-tput",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
        "deepseek-ai/DeepSeek-V3",
        "google/gemma-2-27b-it",
        "google/gemma-3-27b-it",
        "google/gemma-3n-E4B-it",
        "intfloat/multilingual-e5-large-instruct",
        "lgai/exaone-3-5-32b-instruct",
        "lgai/exaone-deep-32b",
        "marin-community/marin-8b-instruct",
        "meta-llama/Llama-2-70b-hf",
        "meta-llama/Llama-3-70b-chat-hf",
        "meta-llama/Llama-3-8b-chat-hf",
        "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
        "meta-llama/Llama-3.2-3B-Instruct-Turbo",
        "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
        "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
        "meta-llama/Llama-Guard-4-12B",
        "meta-llama/Llama-Vision-Free",
        "meta-llama/LlamaGuard-2-8b",
        "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
        "meta-llama/Meta-Llama-3-8B-Instruct-Lite",
        "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
        "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        "meta-llama/Meta-Llama-Guard-3-8B",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "mistralai/Mistral-7B-Instruct-v0.3",
        "mistralai/Mistral-Small-24B-Instruct-2501",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "mixedbread-ai/Mxbai-Rerank-Large-V2",
        "moonshotai/Kimi-K2-Instruct",
        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
        "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
        "openai/gpt-oss-20b",
        "openai/whisper-large-v3",
        "perplexity-ai/r1-1776",
        "Qwen/Qwen2-72B-Instruct",
        "Qwen/Qwen2.5-72B-Instruct-Turbo",
        "Qwen/Qwen2.5-7B-Instruct-Turbo",
        "Qwen/Qwen2.5-Coder-32B-Instruct",
        "Qwen/Qwen2.5-VL-72B-Instruct",
        "Qwen/Qwen3-235B-A22B-fp8-tput",
        "Qwen/Qwen3-235B-A22B-Instruct-2507-tput",
        "Qwen/Qwen3-235B-A22B-Thinking-2507",
        "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8",
        "Qwen/QwQ-32B",
        "Salesforce/Llama-Rank-V1",
        "scb10x/scb10x-llama3-1-typhoon2-70b-instruct",
        "scb10x/scb10x-typhoon-2-1-gemma3-12b",
        "togethercomputer/m2-bert-80M-32k-retrieval",
        "togethercomputer/MoA-1",
        "togethercomputer/MoA-1-Turbo",
        "togethercomputer/Refuel-Llm-V2",
        "togethercomputer/Refuel-Llm-V2-Small",
        "Virtue-AI/VirtueGuard-Text-Lite",
        "zai-org/GLM-4.5-Air-FP8",
    ],
}
[docs]
class LLM(object):
"""Synchronous Language Model class for interacting with various LLM providers.
This class provides a unified interface to interact with different language models
from providers such as OpenAI, AWS Bedrock, Google Gemini and Vertex AI, Anthropic, and Azure OpenAI.
"""
# Class-level defaults and lookup tables shared by all instances.
default_model_provider = DEFAULT_MODEL_PROVIDER
default_model = DEFAULT_MODEL
supported_models = SUPPORTED_MODELS
# Provider name -> synchronous handler class exposed by that provider's module.
model_handler = {
    provider_name: provider_module.Generate
    for provider_name, provider_module in (
        ("OpenAI", openai_llm),
        ("AWSBedrock", awsbedrock_llm),
        ("AnthropicBedrock", anthropicbedrock_llm),
        ("Anthropic", anthropic_llm),
        ("AzureOpenAI", azureopenai_llm),
        ("GoogleGemini", gcp_gemini_llm),
        ("GoogleVertexAI", gcp_vertex_llm),
        ("TogetherAI", togetherai_llm),
    )
}
[docs]
def __init__(self, **kwds: Any) -> None:
"""Initialize the Language Model class with the required parameters.
Args:
- model_name (str, optional): Name of the model to be used. Default: "gpt-4.1-mini"
- provider (str, optional): Name of the model provider. Default: "OpenAI". Allowed values:
- OpenAI
- AzureOpenAI
- AWSBedrock
- GoogleGemini
- GoogleVertexAI
- AnthropicBedrock
- Anthropic
- TogetherAI
**Authentication Arguments by provider:**
**OpenAI models:**
- api_key (str): OpenAI API key.
- timeout (Timeout, optional): Request timeout parameter like connect, read, write. Default: 60.0, 5.0, 10.0, 2.0
- max_retries (int, optional): Number of retries for the request. Default: 2
**AWS Bedrock models:**
- aws_access_key (str): AWS access key.
- aws_secret_key (str): AWS secret key.
- aws_region (str): AWS region name.
- prompt_caching (bool, optional): Whether to use prompt caching. Default: True
- config (botocore.config.Config, optional):
- connect_timeout (float or int, optional): The time in seconds till a timeout exception is thrown when attempting to make a connection. Default: 60
- read_timeout: (float or int, optional): The time in seconds till a timeout exception is thrown when attempting to read from a connection. Default: 60
- region_name (str, optional): region name Note: If specifing config you need to still pass region_name even if you have already passed in aws_region
- max_pool_connections: The maximum number of connections to keep in a connection pool. Defualt: 10
- retries (Dict, optional):
- total_max_attempts: Number of retries for the request. Default: 2
**Google Gemini models:**
- api_key (str): Gemini API key
- http_options (types.HttpOptions, optional): HTTP options to be used in each of the requests. Default is None
- debug_config (DebugConfig, optional): Configuration options that change client network behavior when testing. Default is None
**Google Vertex AI models:**
- api_key (str): Vertex AI API key
- credentials (google.auth.credentials.Credentials): The credentials to use for authentication when calling the Vertex AI APIs.
- project (str): The Google Cloud project ID to use for quota.
- location (str): The location to send API requests to (for example, us-central1).
- http_options (types.HttpOptions, optional): HTTP options to be used in each of the requests. Default is None
- debug_config (DebugConfig, optional): Configuration options that change client network behavior when testing. Default is None
**Anthropic models:**
- api_key (str): Anthropic API key.
- timeout (Timeout, optional): Request timeout parameter like connect, read, write. Default: 60.0, 5.0, 10.0, 2.0
- max_retries (int, optional): Number of retries for the request. Default: 2
- prompt_caching (bool, optional): Whether to use prompt caching. Default: True
**Azure OpenAI models:**
- api_key (str): Azure OpenAI API key.
- azure_endpoint (str): Azure OpenAI endpoint.
- api_version (str): Azure OpenAI API version.
- timeout (Timeout, optional): Request timeout parameter like connect, read, write. Default: 60.0, 5.0, 10.0, 2.0
- max_retries (int, optional): Number of retries for the request. Default: 2
**TogetherAI models:**
- api_key (str): TogetherAI API key.
- timeout (float or int, optional): Request timeout in seconds. Default: 60
- max_retries (int, optional): Number of retries for the request. Default: 2
Raises:
- ValueError: If an unsupported model is specified.
- KeyError: If required parameters are not provided.
- TypeError: If an invalid type is provided for a parameter.
Warns:
- UserWarning: If the model name is not provided, it defaults to the default model.
"""
# Set model name and model provider, defaulting if not provided
if not kwds.get("model_name"):
warnings.warn(
f"\nNo 'model_name' specified, hence defaulting to {self.default_model}",
UserWarning,
)
if not kwds.get("provider"):
warnings.warn(
f"\nNo 'provider' specified, hence defaulting to {self.default_model_provider}",
UserWarning,
)
self.model_name = kwds.get("model_name", self.default_model)
self.model_provider = kwds.get("provider", self.default_model_provider)
# Validating model name and model provider name
if self.model_provider not in self.model_handler:
raise ValueError(
f"\nUnsupported model provider: {self.model_provider}\nSupported providers are:"
f"\n- " + "\n- ".join(list(self.model_handler.keys()))
)
elif self.model_name not in self.supported_models.get(self.model_provider):
warnings.warn(
f"\nModel {self.model_name} for provider {self.model_provider} is not supported by Orichain. Supported models for {self.model_provider} are: [{', '.join(self.supported_models.get(self.model_provider))}] \nUsing an unsupported model may lead to unexpected issues. Please verify that you are using the correct 'model_name' and 'provider'",
UserWarning,
)
# Initialize the appropriate model handler
self.model = self.model_handler.get(self.model_provider)(**kwds)
[docs]
def __call__(
self,
user_message: str,
matched_sentence: Optional[List[str]] = None,
system_prompt: Optional[str] = None,
chat_hist: Optional[List[Dict[str, str]]] = None,
sampling_paras: Optional[Dict] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
extra_metadata: Optional[Dict] = None,
do_json: bool = False,
**kwds: Any,
) -> Dict:
"""Generate a synchronous response from the language model.
Args:
- user_message (str): The user's input message.
- system_prompt (str, optional): System prompt to guide the model's behavior.
- chat_hist (List[Dict[str, str]], optional): Chat history for context.
- sampling_paras (Dict, optional): Parameters for sampling (temperature, top_p, etc.).
- model_name (str, optional): Specifies the model to use. If not provided, the default is the model set during class instantiation.
- do_json (bool, optional): Whether to return a JSON response. Default: False.
- tools (List[Dict], optional): List of tools to be used by the model. Example format
[{"name": "tool name", "description": "tool description", "parameters": {"type": "object", "properties": {"arg_1": {"type": "string", "description": "An example argument for the tool."}}, "required": ["arg_1"]}}, .....]
- tool_choice (str, optional): Defines tool usage:
- "auto" (default) lets the model decide
- "none" disables tools (not supported on AWSBedrock/AnthropicBedrock and TogetherAI)
- "required" forces tool use (unsupported on AzureOpenAI < 2024-06-01)
- provide a tool name to call it directly.
- matched_sentence (List[str], optional): A list of matched text chunks for context. Not used internally, but included in the response under the matched_sentence key.
- extra_metadata (Dict, optional): Additional metadata to be included in the response.
**Generation Arguments by provider:**
**AWS Bedrock models:**
- additional_model_fields (Dict, optional): additionalModelRequestFields passed to the client in the request body.
**Google Gemini & Vertex AI models:**
- config (google.genai.types.GenerateContentConfig, optional): Optional model configuration parameters provided to the client.chats.create API.
- response_mime_type (str, optional): Output response mimetype of the generated candidate text. Supported mimetype: "text/plain" (Default), "application/json" (if do_json=True)
**Anthropic & AnthropicBedrock models:**
- timeout (httpx.Timeout, optional): - timeout (httpx.Timeout, optional): Request timeout parameter like connect, read, write. Default is 60.0, 5.0, 10.0, 2.0
Returns:
Dict: The model's response with tool calls and metadata.
"""
try:
# Handle model switching if a different model is specified in kwds
if self._model_n_model_type_validator(**kwds):
model_name = kwds.pop("model_name", self.model_name)
else:
model_name = self.model_name
# Default empty dictionaries
sampling_paras = sampling_paras or {}
extra_metadata = extra_metadata or {}
# Generate the response
result = self.model(
model_name=model_name,
user_message=user_message,
system_prompt=system_prompt,
chat_hist=chat_hist,
sampling_paras=sampling_paras,
tools=tools,
tool_choice=tool_choice,
do_json=do_json,
**kwds,
)
# Add user message and matched sentence to the response
if "error" not in result:
result.update({"message": user_message})
if matched_sentence:
result.update({"matched_sentence": matched_sentence})
# Add extra metadata to the response
if extra_metadata:
result["metadata"].update(extra_metadata)
return result
except Exception as e:
error_explainer(e)
return {"error": 500, "reason": str(e)}
[docs]
def stream(
    self,
    user_message: str,
    matched_sentence: Optional[List[str]] = None,
    system_prompt: Optional[str] = None,
    chat_hist: Optional[List[Dict[str, str]]] = None,
    sampling_paras: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    tool_choice: Optional[str] = None,
    extra_metadata: Optional[Dict] = None,
    do_json: bool = False,
    do_sse: bool = True,
    **kwds: Any,
) -> Generator:
    """Stream responses from the language model.

    Args:
        - user_message (str): The user's input message.
        - system_prompt (str, optional): System prompt to guide the model's behavior.
        - chat_hist (List[Dict[str, str]], optional): Chat history for context.
        - sampling_paras (Dict, optional): Parameters for sampling (temperature, top_p, etc.).
        - model_name (str, optional): Specifies the model to use. If not provided, or if it is
          rejected by the model validator, the model set during class instantiation is used.
        - do_json (bool, optional): Whether to return JSON responses. Default: False.
        - do_sse (bool, optional): Whether to format responses as Server-Sent Events. Default: True.
        - tools (List[Dict], optional): List of tools to be used by the model. Example format
          [{"name": "tool name", "description": "tool description", "parameters": {"type": "object", "properties": {"arg_1": {"type": "string", "description": "An example argument for the tool."}}, "required": ["arg_1"]}}, .....]
        - tool_choice (str, optional): Defines tool usage:
            - "auto" (default) lets the model decide
            - "none" disables tools (not supported on AWSBedrock/AnthropicBedrock)
            - "required" forces tool use (unsupported on AzureOpenAI < 2024-06-01)
            - provide a tool name to call it directly.
        - matched_sentence (List[str], optional): A list of matched text chunks for context. Not used internally, but included in the response under the matched_sentence key.
        - extra_metadata (Dict, optional): Additional metadata merged into the final response's "metadata" entry.

        **Generation Arguments by provider:**

        **AWS Bedrock models:**
        - additional_model_fields (Dict, optional): additionalModelRequestFields passed to the client in the request body.

        **Google Gemini & Vertex AI models:**
        - config (google.genai.types.GenerateContentConfig, optional): Optional model configuration parameters provided to the client.chats.create API.
        - response_mime_type (str, optional): Output response mimetype of the generated candidate text. Supported mimetype: "text/plain" (Default), "application/json" (if do_json=True)

    Yields:
        Text chunks (str) followed by dictionary chunks containing the complete response,
        including tool calls and metadata. With ``do_sse=True`` every chunk is wrapped as an SSE
        message (event "text" for strings, event "body" for dictionaries); with ``do_sse=False``
        chunks are yielded as-is. Exceptions are reported as a final
        ``{"error": 500, "reason": <message>}`` chunk in the same format.
    """
    try:
        # Ask the validator whether a per-call model override may be used.
        use_requested_model = self._model_n_model_type_validator(**kwds)
        # 'model_name' is passed explicitly below, so it must be removed from
        # kwds; leaving it there makes the handler call fail with
        # "multiple values for keyword argument 'model_name'".
        requested_model = kwds.pop("model_name", None)
        model_name = requested_model if use_requested_model else self.model_name

        # Default empty dictionaries
        sampling_paras = sampling_paras or {}
        extra_metadata = extra_metadata or {}

        # Stream responses from the model
        chunks = self.model.streaming(
            model_name=model_name,
            user_message=user_message,
            system_prompt=system_prompt,
            chat_hist=chat_hist,
            sampling_paras=sampling_paras,
            tools=tools,
            tool_choice=tool_choice,
            do_json=do_json,
            **kwds,
        )

        for chunk in chunks:
            if isinstance(chunk, str):
                # Incremental text
                yield self._format_sse(chunk, event="text") if do_sse else chunk
            elif isinstance(chunk, dict):
                # Structured chunk: enrich successful ones with request context
                if "error" not in chunk:
                    chunk["message"] = user_message
                    if matched_sentence:
                        chunk["matched_sentence"] = matched_sentence
                # setdefault avoids a KeyError when no "metadata" entry exists
                if extra_metadata:
                    chunk.setdefault("metadata", {}).update(extra_metadata)
                yield self._format_sse(chunk, event="body") if do_sse else chunk
    except Exception as e:
        error_explainer(e)
        failure = {"error": 500, "reason": str(e)}
        # Honor do_sse on the error path too, so non-SSE consumers receive a
        # dict rather than an SSE-formatted string.
        yield self._format_sse(failure, event="body") if do_sse else failure
def _format_sse(self, data: Any, event=None) -> str:
"""Format data for Server-Sent Events (SSE).
Args:
data (Any): The data to format.
event (str, optional): The event type.
Returns:
str: Formatted SSE message.
"""
msg = f"data: {json.dumps(data)}\n\n"
if event is not None:
msg = f"event: {event}\n{msg}"
return msg
def _model_n_model_type_validator(self, **kwds: Any) -> bool:
"""Validate if the requested model is compatible with the current model type.
Args:
**kwds: Keyword arguments that may contain a 'model_name'.
Returns:
bool: True if the model is compatible, False otherwise.
"""
if kwds.get("model_name"):
if kwds.get("model_name") in self.supported_models.get(self.model_provider):
return True
elif kwds.get("model_name") in [
item for sublist in self.supported_models.values() for item in sublist
]:
warnings.warn(
f"{kwds.get('model_name')} is a supported model but does not belong to {self.model_provider} provider. "
f"Please reinitialize the LLM class with the '{kwds.get('model_name')}' model and the correct provider. "
f"Hence defaulting the model to {self.model_name}",
UserWarning,
)
return False
else:
warnings.warn(
f"\nModel {kwds.get('model_name')} for provider {self.model_provider} is not supported by Orichain. Supported models for {self.model_provider} are: [{', '.join(self.supported_models.get(self.model_provider))}] \nUsing an unsupported model may lead to unexpected issues. Please verify that you are using the correct 'model_name'",
UserWarning,
)
return True
else:
return False
[docs]
class AsyncLLM(object):
"""Asynchronous Language Model class for interacting with various LLM providers.
This class provides a unified interface to interact with different language models
from providers such as OpenAI, AWS Bedrock, Google Gemini and Vertex AI, Anthropic, and Azure OpenAI.
"""
# Class-level defaults and lookup tables shared by all instances.
default_model = DEFAULT_MODEL
default_model_provider = DEFAULT_MODEL_PROVIDER
supported_models = SUPPORTED_MODELS
# Provider name -> asynchronous handler class exposed by that provider's module.
# (Previously written as a chained `model_handler = model_handler = {...}`,
# which bound the same name twice.)
model_handler = {
    "OpenAI": openai_llm.AsyncGenerate,
    "AWSBedrock": awsbedrock_llm.AsyncGenerate,
    "AnthropicBedrock": anthropicbedrock_llm.AsyncGenerate,
    "Anthropic": anthropic_llm.AsyncGenerate,
    "AzureOpenAI": azureopenai_llm.AsyncGenerate,
    "GoogleGemini": gcp_gemini_llm.AsyncGenerate,
    "GoogleVertexAI": gcp_vertex_llm.AsyncGenerate,
    "TogetherAI": togetherai_llm.AsyncGenerate,
}
[docs]
def __init__(self, **kwds: Any) -> None:
"""Initialize the Language Model class with the required parameters.
Args:
- model_name (str, optional): Name of the model to be used. Default: "gpt-4.1-mini"
- provider (str, optional): Name of the model provider. Default: "OpenAI". Allowed values:
- OpenAI
- AzureOpenAI
- AWSBedrock
- GoogleGemini
- GoogleVertexAI
- AnthropicBedrock
- Anthropic
- TogetherAI
**Authentication Arguments by provider:**
**OpenAI models:**
- api_key (str): OpenAI API key.
- timeout (Timeout, optional): Request timeout parameter like connect, read, write. Default: 60.0, 5.0, 10.0, 2.0
- max_retries (int, optional): Number of retries for the request. Default: 2
**AWS Bedrock models:**
- aws_access_key (str): AWS access key.
- aws_secret_key (str): AWS secret key.
- aws_region (str): AWS region name.
- prompt_caching (bool, optional): Whether to use prompt caching. Default: True
- config (botocore.config.Config, optional):
- connect_timeout (float or int, optional): The time in seconds till a timeout exception is thrown when attempting to make a connection. Default: 60
- read_timeout: (float or int, optional): The time in seconds till a timeout exception is thrown when attempting to read from a connection. Default: 60
- region_name (str, optional): region name Note: If specifing config you need to still pass region_name even if you have already passed in aws_region
- max_pool_connections: The maximum number of connections to keep in a connection pool. Defualt: 10
- retries (Dict, optional):
- total_max_attempts: Number of retries for the request. Default: 2
**Google Gemini models:**
- api_key (str): Gemini API key
- http_options (types.HttpOptions, optional): HTTP options to be used in each of the requests. Default is None
- debug_config (DebugConfig, optional): Configuration options that change client network behavior when testing. Default is None
**Google Vertex AI models:**
- api_key (str): Vertex AI API key
- credentials (google.auth.credentials.Credentials): The credentials to use for authentication when calling the Vertex AI APIs.
- project (str): The Google Cloud project ID to use for quota.
- location (str): The location to send API requests to (for example, us-central1).
- http_options (types.HttpOptions, optional): HTTP options to be used in each of the requests. Default is None
- debug_config (DebugConfig, optional): Configuration options that change client network behavior when testing. Default is None
**Anthropic models:**
- api_key (str): Anthropic API key.
- timeout (Timeout, optional): Request timeout parameter like connect, read, write. Default: 60.0, 5.0, 10.0, 2.0
- max_retries (int, optional): Number of retries for the request. Default: 2
- prompt_caching (bool, optional): Whether to use prompt caching. Default: True
**Azure OpenAI models:**
- api_key (str): Azure OpenAI API key.
- azure_endpoint (str): Azure OpenAI endpoint.
- api_version (str): Azure OpenAI API version.
- timeout (Timeout, optional): Request timeout parameter like connect, read, write. Default: 60.0, 5.0, 10.0, 2.0
- max_retries (int, optional): Number of retries for the request. Default: 2
**TogetherAI models:**
- api_key (str): TogetherAI API key.
- timeout (float or int, optional): Request timeout in seconds. Default: 60
- max_retries (int, optional): Number of retries for the request. Default: 2
Raises:
- ValueError: If an unsupported model is specified.
- KeyError: If required parameters are not provided.
- TypeError: If an invalid type is provided for a parameter.
Warns:
- UserWarning: If the model name is not provided, it defaults to the default model.
"""
# Set model name and model provider, defaulting if not provided
if not kwds.get("model_name"):
warnings.warn(
f"\nNo 'model_name' specified, hence defaulting to {self.default_model}",
UserWarning,
)
if not kwds.get("provider"):
warnings.warn(
f"\nNo 'provider' specified, hence defaulting to {self.default_model_provider}",
UserWarning,
)
self.model_name = kwds.get("model_name", self.default_model)
self.model_provider = kwds.get("provider", self.default_model_provider)
# Validating model name and model provider name
if self.model_provider not in self.model_handler:
raise ValueError(
f"\nUnsupported model provider: {self.model_provider}\nSupported providers are:"
f"\n- " + "\n- ".join(list(self.model_handler.keys()))
)
elif self.model_name not in self.supported_models.get(self.model_provider):
warnings.warn(
f"\nModel {self.model_name} for provider {self.model_provider} is not supported by Orichain. Supported models for {self.model_provider} are: [{', '.join(self.supported_models.get(self.model_provider))}] \nUsing an unsupported model may lead to unexpected issues. Please verify that you are using the correct 'model_name' and 'provider'",
UserWarning,
)
# Initialize the appropriate model handler
self.model = self.model_handler.get(self.model_provider)(**kwds)
[docs]
async def __call__(
    self,
    user_message: str,
    request: Optional[Request] = None,
    matched_sentence: Optional[List[str]] = None,
    system_prompt: Optional[str] = None,
    chat_hist: Optional[List[Dict[str, str]]] = None,
    sampling_paras: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    tool_choice: Optional[str] = None,
    extra_metadata: Optional[Dict] = None,
    do_json: bool = False,
    **kwds: Any,
) -> Dict:
    """Generate an asynchronous response from the language model.

    Args:
        - user_message (str): The user's input message.
        - system_prompt (str, optional): System prompt to guide the model's behavior.
        - chat_hist (List[Dict[str, str]], optional): Chat history for context.
        - sampling_paras (Dict, optional): Parameters for sampling (temperature, top_p, etc.).
        - model_name (str, optional): Specifies the model to use. If not provided, or if it is
          rejected by the model validator, the model set during class instantiation is used.
        - do_json (bool, optional): Whether to return a JSON response. Default: False.
        - tools (List[Dict], optional): List of tools to be used by the model. Example format
          [{"name": "tool name", "description": "tool description", "parameters": {"type": "object", "properties": {"arg_1": {"type": "string", "description": "An example argument for the tool."}}, "required": ["arg_1"]}}, .....]
        - tool_choice (str, optional): Defines tool usage:
            - "auto" (default) lets the model decide
            - "none" disables tools (not supported on AWSBedrock/AnthropicBedrock and TogetherAI)
            - "required" forces tool use (unsupported on AzureOpenAI < 2024-06-01)
            - provide a tool name to call it directly.
        - request (Request, optional): FastAPI Request object for cancellation detection.
        - matched_sentence (List[str], optional): A list of matched text chunks for context. Not used internally, but included in the response under the matched_sentence key.
        - extra_metadata (Dict, optional): Additional metadata merged into the response's "metadata" entry.

        **Generation Arguments by provider:**

        **AWS Bedrock models:**
        - additional_model_fields (Dict, optional): additionalModelRequestFields passed to the client in the request body.

        **Google Gemini & Vertex AI models:**
        - config (google.genai.types.GenerateContentConfig, optional): Optional model configuration parameters provided to the client.chats.create API.
        - response_mime_type (str, optional): Output response mimetype of the generated candidate text. Supported mimetype: "text/plain" (Default), "application/json" (if do_json=True)

        **Anthropic & AnthropicBedrock models:**
        - timeout (httpx.Timeout, optional): Request timeout parameter like connect, read, write. Default is 60.0, 5.0, 10.0, 2.0

    Returns:
        Dict: The model's response with tool calls and metadata.
        ``{"error": 400, "reason": "request aborted by user"}`` is returned when the request is
        already disconnected, and any exception raised while generating is reported as
        ``{"error": 500, "reason": <message>}`` instead of propagating.
    """
    try:
        # Ask the validator whether a per-call model override may be used.
        use_requested_model = await self._model_n_model_type_validator(**kwds)
        # Always remove 'model_name' from kwds: it is passed explicitly below,
        # and leaving it in would raise "multiple values for keyword argument"
        # whenever the override was rejected.
        requested_model = kwds.pop("model_name", None)
        model_name = requested_model if use_requested_model else self.model_name

        # Default empty dictionaries
        sampling_paras = sampling_paras or {}
        extra_metadata = extra_metadata or {}

        # Skip the provider call entirely if the client has already gone away
        if request and await request.is_disconnected():
            return {"error": 400, "reason": "request aborted by user"}

        # Generate the response
        result = await self.model(
            request=request,
            model_name=model_name,
            user_message=user_message,
            system_prompt=system_prompt,
            chat_hist=chat_hist,
            sampling_paras=sampling_paras,
            tools=tools,
            tool_choice=tool_choice,
            do_json=do_json,
            **kwds,
        )

        # Echo the user message and matched sentences on successful responses
        if "error" not in result:
            result["message"] = user_message
            if matched_sentence:
                result["matched_sentence"] = matched_sentence

        # Merge extra metadata; setdefault avoids a KeyError (which would mask
        # the real response) when the handler returned no "metadata" entry.
        if extra_metadata:
            result.setdefault("metadata", {}).update(extra_metadata)

        return result
    except Exception as e:
        error_explainer(e)
        return {"error": 500, "reason": str(e)}
[docs]
async def stream(
    self,
    user_message: str,
    request: Optional[Request] = None,
    matched_sentence: Optional[List[str]] = None,
    system_prompt: Optional[str] = None,
    chat_hist: List = None,
    sampling_paras: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    tool_choice: Optional[str] = None,
    extra_metadata: Optional[Dict] = None,
    do_json: bool = False,
    do_sse: bool = True,
    **kwds: Any,
) -> AsyncGenerator:
    """Stream responses from the language model.

    Args:
        user_message (str): The user's input message.
        system_prompt (str, optional): System prompt to guide the model's behavior.
        chat_hist (List[Dict[str, str]], optional): Chat history for context.
        sampling_paras (Dict, optional): Parameters for sampling (temperature, top_p, etc.).
        model_name (str, optional): Specifies the model to use, passed through ``**kwds``.
            If not provided, or if it belongs to a different provider, the model set
            during class instantiation is used.
        do_json (bool, optional): Whether to return JSON responses. Default: False.
        do_sse (bool, optional): Whether to format every yielded item (text chunks,
            the final payload and error payloads) as Server-Sent Events. Default: True.
        tools (List[Dict], optional): List of tools to be used by the model. Example format
            [{"name": "tool name", "description": "tool description", "parameters": {"type": "object", "properties": {"arg_1": {"type": "string", "description": "An example argument for the tool."}}, "required": ["arg_1"]}}, .....]
        tool_choice (str, optional): Defines tool usage:
            - "auto" (default) lets the model decide
            - "none" disables tools (not supported on AWSBedrock/AnthropicBedrock)
            - "required" forces tool use (unsupported on AzureOpenAI < 2024-06-01)
            - provide a tool name to call it directly.
        request (Request, optional): FastAPI Request object for cancellation detection.
        matched_sentence (List[str], optional): A list of matched text chunks for context.
            Not used internally, but included in the response under the
            matched_sentence key.
        extra_metadata (Dict, optional): Additional metadata to be included in the response.

    **Generation Arguments by provider:**

    **AWS Bedrock models:**
        - additional_model_fields (Dict, optional): additionalModelRequestFields passed
          to the client in the request body.

    **Google Gemini & Vertex AI models:**
        - config (google.genai.types.GenerateContentConfig, optional): Optional model
          configuration parameters provided to the client.chats.create API.
        - response_mime_type (str, optional): Output response mimetype of the generated
          candidate text. Supported mimetype: "text/plain" (Default),
          "application/json" (if do_json=True)

    Yields:
        Stream of responses from the language model, followed by a final dictionary
        containing the complete response, including tool calls and metadata. With
        ``do_sse=True`` every item is an SSE-formatted string; with ``do_sse=False``
        text chunks are plain strings and payloads (including errors) are dicts.
    """
    try:
        # Handle model switching if a different model is specified in kwds
        use_requested_model = await self._model_n_model_type_validator(**kwds)

        # model_name is always forwarded explicitly below, so it must never remain in
        # kwds; otherwise the call fails with "got multiple values for keyword
        # argument 'model_name'" (also when the validator rejects the model).
        requested_model = kwds.pop("model_name", None)
        if use_requested_model and requested_model is not None:
            model_name = requested_model
        else:
            model_name = self.model_name

        # Default empty dictionaries
        sampling_paras = sampling_paras or {}
        extra_metadata = extra_metadata or {}

        # Check if the request has been disconnected
        if request and await request.is_disconnected():
            aborted = {"error": 400, "reason": "request aborted by user"}
            if do_sse:
                yield await self._format_sse(aborted, event="body")
            else:
                yield aborted
            return

        # Stream responses from the model
        result = self.model.streaming(
            request=request,
            model_name=model_name,
            user_message=user_message,
            system_prompt=system_prompt,
            chat_hist=chat_hist,
            sampling_paras=sampling_paras,
            tools=tools,
            tool_choice=tool_choice,
            do_json=do_json,
            **kwds,
        )

        # Process each chunk in the stream
        async for chunk in result:
            if isinstance(chunk, str):
                if do_sse:
                    yield await self._format_sse(chunk, event="text")
                else:
                    yield chunk
            elif isinstance(chunk, dict):
                # Error payloads are forwarded untouched; successful payloads are
                # enriched with the request context.
                if "error" not in chunk:
                    chunk.update({"message": user_message})
                    if matched_sentence:
                        chunk.update({"matched_sentence": matched_sentence})
                    if extra_metadata:
                        # setdefault tolerates payloads without a "metadata" key
                        chunk.setdefault("metadata", {}).update(extra_metadata)
                if do_sse:
                    yield await self._format_sse(chunk, event="body")
                else:
                    yield chunk
    except Exception as e:
        error_explainer(e)
        failure = {"error": 500, "reason": str(e)}
        # Keep the error in the same wire format as every other yielded item
        if do_sse:
            yield await self._format_sse(failure, event="body")
        else:
            yield failure
async def _format_sse(self, data: Any, event=None) -> str:
    """Serialize a payload as a single Server-Sent Events (SSE) message.

    Args:
        data (Any): The payload; it is JSON-encoded into the ``data:`` field.
        event (str, optional): Event name. When not None, an ``event:`` line is
            placed before the ``data:`` line.

    Returns:
        str: The SSE message, terminated by a blank line.
    """
    data_field = f"data: {json.dumps(data)}\n\n"
    if event is None:
        # Unnamed event: only the data field is emitted
        return data_field
    return f"event: {event}\n{data_field}"
async def _model_n_model_type_validator(self, **kwds: Any) -> bool:
    """Decide whether a per-call ``model_name`` override should be used.

    Args:
        **kwds: Keyword arguments that may contain a 'model_name'.

    Returns:
        bool: True when the requested model is listed for the current provider, or
        is not listed for any provider (a warning is emitted in that case).
        False when no model name was supplied, or when the model is listed only
        under a different provider (a warning is emitted and the caller is
        expected to fall back to the instance's model).
    """
    requested = kwds.get("model_name")
    if not requested:
        # Nothing to validate: no override was requested
        return False

    provider = self.model_provider
    provider_models = self.supported_models.get(provider)
    if requested in provider_models:
        return True

    # Flatten every provider's model list to see whether the name is known at all
    all_known_models = [
        name for models in self.supported_models.values() for name in models
    ]
    if requested in all_known_models:
        warnings.warn(
            f"{requested} is a supported model but does not belong to {provider} provider. "
            f"Please reinitialize the AsyncLLM class with the '{requested}' model and the correct provider. "
            f"Hence defaulting the model to {self.model_name}",
            UserWarning,
        )
        return False

    # Unknown everywhere: let it through, but tell the user it is unverified
    warnings.warn(
        f"\nModel {requested} for provider {provider} is not supported by Orichain. Supported models for {provider} are: [{', '.join(provider_models)}] \nUsing an unsupported model may lead to unexpected issues. Please verify that you are using the correct 'model_name'",
        UserWarning,
    )
    return True