| """ |
| API Client for Smart Auto-Complete |
| Handles communication with OpenAI and Anthropic APIs |
| """ |
|
|
| import logging |
| import time |
| from typing import Dict, List, Optional, Union |
|
|
| import anthropic |
| import openai |
|
|
| from .utils import validate_api_key |
|
|
# Logger scoped to this module's import path, so its level and handlers can be
# configured independently of the rest of the application.
logger = logging.getLogger(name=__name__)
|
|
|
|
class APIClient:
    """
    Unified API client for multiple AI providers.

    Supports OpenAI GPT and Anthropic Claude models behind a single
    ``get_completion`` entry point. Request failures are logged and reported
    as ``None`` / ``False`` rather than raised to the caller.
    """

    # Minimum spacing, in seconds, between successful requests. A call made
    # sooner is rejected by ``_check_rate_limit`` (it is not queued or delayed).
    MIN_REQUEST_INTERVAL = 1.0

    # Name prefixes of OpenAI reasoning models. These expect
    # ``max_completion_tokens`` instead of ``max_tokens`` and do not accept a
    # custom temperature or presence/frequency penalties.
    _REASONING_MODEL_PREFIXES = ("o1", "o3", "o4")

    # Models used when settings are absent or do not name a model.
    _DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
    _DEFAULT_ANTHROPIC_MODEL = "claude-3-haiku-20240307"

    def __init__(self, settings=None):
        """
        Initialize the API client with settings.

        Args:
            settings: Application settings object, or None. When present, the
                attributes ``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY`` and
                ``DEFAULT_PROVIDER`` are read if they exist, and
                ``get_model_for_provider(provider)`` is called to pick models.
        """
        self.settings = settings
        self.openai_client = None
        self.anthropic_client = None
        self.current_provider = None
        self.request_count = 0
        self.last_request_time = 0

        self._initialize_clients()

    def _is_reasoning_model(self, model: str) -> bool:
        """Return True if ``model`` names an OpenAI reasoning (o-series) model."""
        return model.startswith(self._REASONING_MODEL_PREFIXES)

    def _get_token_param_name(self, model: str) -> str:
        """
        Get the correct token parameter name based on the model.

        Args:
            model: The model name

        Returns:
            'max_completion_tokens' for reasoning models, otherwise 'max_tokens'
        """
        if self._is_reasoning_model(model):
            return "max_completion_tokens"
        return "max_tokens"

    def _create_client(self, key_name: str, provider: str, factory):
        """
        Build one provider client from ``settings.<key_name>``.

        Returns None when the key is missing, fails ``validate_api_key``, or the
        client constructor raises. Errors are logged, so one provider's failure
        never prevents the other from being initialized.
        """
        api_key = getattr(self.settings, key_name, None)
        try:
            if api_key and validate_api_key(api_key, provider):
                return factory(api_key=api_key)
        except Exception as e:
            logger.error("Error initializing %s client: %s", provider, e)
        return None

    def _choose_default_provider(self) -> Optional[str]:
        """
        Pick the initial provider.

        ``settings.DEFAULT_PROVIDER`` is honored only if its client exists.
        Otherwise the first available provider is used (OpenAI before
        Anthropic). Returns None when no client is available.
        """
        available = self.get_available_providers()
        preferred = getattr(self.settings, "DEFAULT_PROVIDER", None)
        if preferred in available:
            return preferred
        if preferred:
            logger.warning(
                "Default provider %r is not available; falling back", preferred
            )
        return available[0] if available else None

    def _initialize_clients(self):
        """Initialize API clients based on available keys and select a provider."""
        self.openai_client = self._create_client(
            "OPENAI_API_KEY", "openai", openai.OpenAI
        )
        if self.openai_client:
            logger.info("OpenAI client initialized successfully")

        self.anthropic_client = self._create_client(
            "ANTHROPIC_API_KEY", "anthropic", anthropic.Anthropic
        )
        if self.anthropic_client:
            logger.info("Anthropic client initialized successfully")

        self.current_provider = self._choose_default_provider()
        if self.current_provider is None:
            logger.warning("No valid API clients initialized")

    def _get_provider_handler(self, provider: Optional[str]):
        """Return the completion method for ``provider``, or None if its client is missing."""
        if provider == "openai" and self.openai_client:
            return self._get_openai_completion
        if provider == "anthropic" and self.anthropic_client:
            return self._get_anthropic_completion
        return None

    def _resolve_handler(self, provider: Optional[str]):
        """
        Return the completion method for ``provider``.

        Falls back to the first available provider when ``provider`` is
        unavailable. Returns None when no client exists at all.
        """
        handler = self._get_provider_handler(provider)
        if handler is None:
            available = self.get_available_providers()
            if available:
                if provider:
                    logger.debug(
                        "Provider %r unavailable, falling back to %r",
                        provider,
                        available[0],
                    )
                handler = self._get_provider_handler(available[0])
        return handler

    def _get_model(self, provider: str, default: str) -> str:
        """Return the model configured for ``provider``, or ``default`` if none is set."""
        model = self.settings.get_model_for_provider(provider) if self.settings else None
        return model or default

    def get_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 150,
        provider: Optional[str] = None,
    ) -> Optional[str]:
        """
        Get a completion from the specified provider.

        If the requested provider (or the current one) has no client, the first
        available provider is used instead.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            temperature: Sampling temperature (0.0 to 1.0); ignored for OpenAI
                reasoning models
            max_tokens: Maximum tokens in response
            provider: Specific provider to use ('openai' or 'anthropic')

        Returns:
            Generated completion text, or None if throttled, no client is
            available, or the request failed
        """
        try:
            if not self._check_rate_limit():
                logger.warning("Rate limit exceeded, skipping request")
                return None

            handler = self._resolve_handler(provider or self.current_provider)
            if handler is None:
                logger.error("No API clients available")
                return None
            return handler(messages, temperature, max_tokens)

        except Exception as e:
            logger.error("Error getting completion: %s", e)
            return None

    def _get_openai_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """Get completion from OpenAI API; returns None on any failure."""
        try:
            model = self._get_model("openai", self._DEFAULT_OPENAI_MODEL)
            logger.debug("Using OpenAI model: %s", model)

            token_param = self._get_token_param_name(model)
            logger.debug("Using token parameter: %s = %s", token_param, max_tokens)

            request_params = {
                "model": model,
                "messages": messages,
                token_param: max_tokens,
                "n": 1,
                "stop": None,
            }

            # Reasoning models reject custom sampling parameters, so they are
            # only sent to regular chat models.
            if self._is_reasoning_model(model):
                logger.debug("Using default temperature for reasoning model %s", model)
            else:
                request_params["temperature"] = temperature
                request_params["presence_penalty"] = 0.1
                request_params["frequency_penalty"] = 0.1
                logger.debug("Using custom temperature: %s", temperature)

            response = self.openai_client.chat.completions.create(**request_params)
            self._update_request_stats()

            if not response.choices:
                logger.warning("No choices returned from OpenAI API")
                return None

            # ``content`` can be None (e.g. refusals or tool-call-only replies).
            content = response.choices[0].message.content
            if content is None:
                logger.warning("OpenAI returned a choice without text content")
                return None
            return content.strip()

        except openai.RateLimitError:
            logger.warning("OpenAI rate limit exceeded")
            return None
        except openai.APIError as e:
            logger.error("OpenAI API error: %s", e)
            return None
        except Exception as e:
            logger.error("Unexpected error with OpenAI: %s", e)
            return None

    def _get_anthropic_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """Get completion from Anthropic API; returns None on any failure."""
        try:
            # Anthropic takes system instructions as a separate parameter.
            # All system messages are kept, joined in order, instead of only
            # the last one.
            system_parts = [m["content"] for m in messages if m["role"] == "system"]
            conversation = [m for m in messages if m["role"] != "system"]

            model = self._get_model("anthropic", self._DEFAULT_ANTHROPIC_MODEL)
            logger.debug("Using Anthropic model: %s", model)

            request_params = {
                "model": model,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "messages": conversation,
            }
            # Omit ``system`` entirely rather than sending an empty string.
            if system_parts:
                request_params["system"] = "\n\n".join(system_parts)

            response = self.anthropic_client.messages.create(**request_params)
            self._update_request_stats()

            # Only text blocks carry ``.text``; skip any other block types.
            texts = [
                block.text
                for block in (response.content or [])
                if getattr(block, "text", None) is not None
            ]
            if not texts:
                logger.warning("No content returned from Anthropic API")
                return None
            return "".join(texts).strip()

        except anthropic.RateLimitError:
            logger.warning("Anthropic rate limit exceeded")
            return None
        except anthropic.APIError as e:
            logger.error("Anthropic API error: %s", e)
            return None
        except Exception as e:
            logger.error("Unexpected error with Anthropic: %s", e)
            return None

    def _check_rate_limit(self) -> bool:
        """
        Check if we're within rate limits.

        Returns False if the last successful request was less than
        ``MIN_REQUEST_INTERVAL`` seconds ago. Failed requests do not update
        ``last_request_time``, so they do not count toward the limit.
        """
        return time.time() - self.last_request_time >= self.MIN_REQUEST_INTERVAL

    def _update_request_stats(self):
        """Record a successful request: bump the counter and timestamp it."""
        self.request_count += 1
        self.last_request_time = time.time()

    def get_available_providers(self) -> List[str]:
        """Get list of providers with an initialized client (OpenAI first)."""
        providers = []
        if self.openai_client:
            providers.append("openai")
        if self.anthropic_client:
            providers.append("anthropic")
        return providers

    def switch_provider(self, provider: str) -> bool:
        """
        Switch to a different provider.

        Args:
            provider: Provider name ('openai' or 'anthropic')

        Returns:
            True if switch was successful, False if that provider has no client
        """
        if provider == "openai" and self.openai_client:
            self.current_provider = "openai"
            logger.info("Switched to OpenAI provider")
            return True
        if provider == "anthropic" and self.anthropic_client:
            self.current_provider = "anthropic"
            logger.info("Switched to Anthropic provider")
            return True
        logger.warning("Cannot switch to provider: %s", provider)
        return False

    def get_stats(self) -> Dict[str, Union[int, float, str, List[str], None]]:
        """Get API usage statistics."""
        return {
            "request_count": self.request_count,
            "current_provider": self.current_provider,
            "available_providers": self.get_available_providers(),
            "last_request_time": self.last_request_time,
        }

    def test_connection(self, provider: Optional[str] = None) -> bool:
        """
        Test connection to the API provider.

        An explicitly named provider is tested only if its client exists; there
        is no silent fallback to the other provider. The test bypasses the
        client-side throttle, so testing providers back-to-back does not
        produce false failures.

        Args:
            provider: Specific provider to test, or None for current provider

        Returns:
            True if a non-empty completion came back, False otherwise
        """
        test_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say 'Hello' in one word."},
        ]

        if provider:
            handler = self._get_provider_handler(provider)
        else:
            handler = self._resolve_handler(self.current_provider)

        if handler is None:
            logger.warning(
                "Cannot test connection: provider %r is not available",
                provider or self.current_provider,
            )
            return False

        try:
            # NOTE(review): with only 10 tokens, reasoning models may spend the
            # whole budget on hidden reasoning and return no text — confirm.
            result = handler(test_messages, 0.1, 10)
        except Exception as e:
            logger.error("Connection test failed: %s", e)
            return False

        return bool(result and result.strip())
|