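"""Tokenize/detokenize handlers for vLLM's OpenAI-compatible API server."""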
from typing import List, Optional, Union

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage,
                                         load_chat_template,
                                         parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
                                              ErrorResponse,
                                              TokenizeChatRequest,
                                              TokenizeRequest,
                                              TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
from vllm.utils import random_uuid


class OpenAIServingTokenization(OpenAIServing):
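    """Serves tokenize and detokenize requests for the OpenAI API server."""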

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        # If this is None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
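        """Tokenize a plain prompt or a chat conversation.

        Chat requests are first rendered to a prompt string with the
        configured chat template; the response carries the token IDs,
        their count, and the model's maximum context length.
        """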
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.engine.get_tokenizer(lora_request)

        if isinstance(request, TokenizeChatRequest):
            model_config = self.model_config

            conversation: List[ConversationMessage] = []

            for message in request.messages:
                result = parse_chat_message_content(message, model_config,
                                                    tokenizer)
                conversation.extend(result.messages)

            prompt = tokenizer.apply_chat_template(
                add_generation_prompt=request.add_generation_prompt,
                conversation=conversation,
                tokenize=False,
                chat_template=self.chat_template)
            assert isinstance(prompt, str)
        else:
            prompt = request.prompt

        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization

        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt,
            add_special_tokens=request.add_special_tokens,
        )

        input_ids = prompt_input["prompt_token_ids"]

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
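        """Convert a list of token IDs back into the corresponding text."""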
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.engine.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        # Feeding token IDs through the shared prompt-input helper decodes
        # them back into the prompt text.
        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)
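
# Usage sketch: these handlers back the POST /tokenize and /detokenize routes
# of vLLM's OpenAI-compatible server. The host, port, model name, and token
# values below are illustrative assumptions, not fixed by this module.
#
#   curl -s http://localhost:8000/tokenize \
#       -H "Content-Type: application/json" \
#       -d '{"model": "my-model", "prompt": "Hello world!"}'
#   # -> {"count": ..., "max_model_len": ..., "tokens": [...]}
#
#   curl -s http://localhost:8000/detokenize \
#       -H "Content-Type: application/json" \
#       -d '{"model": "my-model", "tokens": [9906, 1917]}'
#   # -> {"prompt": "..."}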