193 lines
6.4 KiB
Python
193 lines
6.4 KiB
Python
import os
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from mem0.configs.llms.base import BaseLlmConfig
|
|
|
|
|
|
class AWSBedrockConfig(BaseLlmConfig):
|
|
"""
|
|
Configuration class for AWS Bedrock LLM integration.
|
|
|
|
Supports all available Bedrock models with automatic provider detection.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model: Optional[str] = None,
|
|
temperature: float = 0.1,
|
|
max_tokens: int = 2000,
|
|
top_p: float = 0.9,
|
|
top_k: int = 1,
|
|
aws_access_key_id: Optional[str] = None,
|
|
aws_secret_access_key: Optional[str] = None,
|
|
aws_region: str = "",
|
|
aws_session_token: Optional[str] = None,
|
|
aws_profile: Optional[str] = None,
|
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
|
**kwargs,
|
|
):
|
|
"""
|
|
Initialize AWS Bedrock configuration.
|
|
|
|
Args:
|
|
model: Bedrock model identifier (e.g., "amazon.nova-3-mini-20241119-v1:0")
|
|
temperature: Controls randomness (0.0 to 2.0)
|
|
max_tokens: Maximum tokens to generate
|
|
top_p: Nucleus sampling parameter (0.0 to 1.0)
|
|
top_k: Top-k sampling parameter (1 to 40)
|
|
aws_access_key_id: AWS access key (optional, uses env vars if not provided)
|
|
aws_secret_access_key: AWS secret key (optional, uses env vars if not provided)
|
|
aws_region: AWS region for Bedrock service
|
|
aws_session_token: AWS session token for temporary credentials
|
|
aws_profile: AWS profile name for credentials
|
|
model_kwargs: Additional model-specific parameters
|
|
**kwargs: Additional arguments passed to base class
|
|
"""
|
|
super().__init__(
|
|
model=model or "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
|
temperature=temperature,
|
|
max_tokens=max_tokens,
|
|
top_p=top_p,
|
|
top_k=top_k,
|
|
**kwargs,
|
|
)
|
|
|
|
self.aws_access_key_id = aws_access_key_id
|
|
self.aws_secret_access_key = aws_secret_access_key
|
|
self.aws_region = aws_region or os.getenv("AWS_REGION", "us-west-2")
|
|
self.aws_session_token = aws_session_token
|
|
self.aws_profile = aws_profile
|
|
self.model_kwargs = model_kwargs or {}
|
|
|
|
@property
|
|
def provider(self) -> str:
|
|
"""Get the provider from the model identifier."""
|
|
if not self.model or "." not in self.model:
|
|
return "unknown"
|
|
return self.model.split(".")[0]
|
|
|
|
@property
|
|
def model_name(self) -> str:
|
|
"""Get the model name without provider prefix."""
|
|
if not self.model or "." not in self.model:
|
|
return self.model
|
|
return ".".join(self.model.split(".")[1:])
|
|
|
|
def get_model_config(self) -> Dict[str, Any]:
|
|
"""Get model-specific configuration parameters."""
|
|
base_config = {
|
|
"temperature": self.temperature,
|
|
"max_tokens": self.max_tokens,
|
|
"top_p": self.top_p,
|
|
"top_k": self.top_k,
|
|
}
|
|
|
|
# Add custom model kwargs
|
|
base_config.update(self.model_kwargs)
|
|
|
|
return base_config
|
|
|
|
def get_aws_config(self) -> Dict[str, Any]:
|
|
"""Get AWS configuration parameters."""
|
|
config = {
|
|
"region_name": self.aws_region,
|
|
}
|
|
|
|
if self.aws_access_key_id:
|
|
config["aws_access_key_id"] = self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
|
|
|
|
if self.aws_secret_access_key:
|
|
config["aws_secret_access_key"] = self.aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
|
|
|
|
if self.aws_session_token:
|
|
config["aws_session_token"] = self.aws_session_token or os.getenv("AWS_SESSION_TOKEN")
|
|
|
|
if self.aws_profile:
|
|
config["profile_name"] = self.aws_profile or os.getenv("AWS_PROFILE")
|
|
|
|
return config
|
|
|
|
def validate_model_format(self) -> bool:
|
|
"""
|
|
Validate that the model identifier follows Bedrock naming convention.
|
|
|
|
Returns:
|
|
True if valid, False otherwise
|
|
"""
|
|
if not self.model:
|
|
return False
|
|
|
|
# Check if model follows provider.model-name format
|
|
if "." not in self.model:
|
|
return False
|
|
|
|
provider, model_name = self.model.split(".", 1)
|
|
|
|
# Validate provider
|
|
valid_providers = [
|
|
"ai21", "amazon", "anthropic", "cohere", "meta", "mistral",
|
|
"stability", "writer", "deepseek", "gpt-oss", "perplexity",
|
|
"snowflake", "titan", "command", "j2", "llama"
|
|
]
|
|
|
|
if provider not in valid_providers:
|
|
return False
|
|
|
|
# Validate model name is not empty
|
|
if not model_name:
|
|
return False
|
|
|
|
return True
|
|
|
|
def get_supported_regions(self) -> List[str]:
|
|
"""Get list of AWS regions that support Bedrock."""
|
|
return [
|
|
"us-east-1",
|
|
"us-west-2",
|
|
"us-east-2",
|
|
"eu-west-1",
|
|
"ap-southeast-1",
|
|
"ap-northeast-1",
|
|
]
|
|
|
|
def get_model_capabilities(self) -> Dict[str, Any]:
|
|
"""Get model capabilities based on provider."""
|
|
capabilities = {
|
|
"supports_tools": False,
|
|
"supports_vision": False,
|
|
"supports_streaming": False,
|
|
"supports_multimodal": False,
|
|
}
|
|
|
|
if self.provider == "anthropic":
|
|
capabilities.update({
|
|
"supports_tools": True,
|
|
"supports_vision": True,
|
|
"supports_streaming": True,
|
|
"supports_multimodal": True,
|
|
})
|
|
elif self.provider == "amazon":
|
|
capabilities.update({
|
|
"supports_tools": True,
|
|
"supports_vision": True,
|
|
"supports_streaming": True,
|
|
"supports_multimodal": True,
|
|
})
|
|
elif self.provider == "cohere":
|
|
capabilities.update({
|
|
"supports_tools": True,
|
|
"supports_streaming": True,
|
|
})
|
|
elif self.provider == "meta":
|
|
capabilities.update({
|
|
"supports_vision": True,
|
|
"supports_streaming": True,
|
|
})
|
|
elif self.provider == "mistral":
|
|
capabilities.update({
|
|
"supports_vision": True,
|
|
"supports_streaming": True,
|
|
})
|
|
|
|
return capabilities
|