import json
import os
from typing import List

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from torch.cuda import get_device_properties

# Set the cache location before importing transformers, which resolves
# TRANSFORMERS_CACHE at import time.
os.environ['TRANSFORMERS_CACHE'] = ".cache"

from transformers import AutoModel, AutoTokenizer

bits = 4
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
cache_dir = './models'
model_name = 'chatglm-6b-int4'
min_memory = 5.5  # GiB of GPU memory required before CUDA is used
tokenizer = None
model = None
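
# kernel_path points at the quantization kernels bundled with the
# chatglm-6b-int4-slim snapshot; it is defined here but never referenced below.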

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event('startup')
def init():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, cache_dir=cache_dir)

    # Run on the GPU only when it has more than `min_memory` GiB of VRAM;
    # otherwise quantize for CPU inference.
    if torch.cuda.is_available() and get_device_properties(0).total_memory / 1024 ** 3 > min_memory:
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        model = model.float().quantize(bits=bits)
        if torch.cuda.is_available():
            print("Total Memory: ", get_device_properties(0).total_memory / 1024 ** 3)
        else:
            print("No GPU available")
        print("Using CPU")
    model = model.eval()
    if os.environ.get("ngrok_token") is not None:
        ngrok_connect()


class Message(BaseModel):
    role: str
    content: str


class Body(BaseModel):
    messages: List[Message]
    model: str
    stream: bool
    max_tokens: int
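
# Example request body (illustrative; field names match the models above):
#
#   {
#       "model": "chatglm-6b-int4",
#       "stream": true,
#       "max_tokens": 2048,
#       "messages": [{"role": "user", "content": "Hello"}]
#   }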


@app.get("/")
def read_root():
    return {"Hello": "World!"}


@app.post("/chat/completions")
async def completions(body: Body, request: Request):
    # Only streaming requests against the configured model are supported.
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")

    # The final message must be the user's question.
    question = body.messages[-1]
    if question.role == 'user':
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")

    # Fold the conversation into (question, answer) pairs, the history format
    # ChatGLM's stream_chat expects.
    user_question = ''
    history = []
    for message in body.messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role == 'system' or message.role == 'assistant':
            assistant_answer = message.content
            history.append((user_question, assistant_answer))

    async def event_generator():
        # stream_chat yields (response, history) tuples; forward the growing
        # response text as SSE events, stopping if the client disconnects.
        for response in model.stream_chat(tokenizer, question, history, max_length=max(2048, body.max_tokens)):
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())
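
# Illustrative client sketch (not part of the server): consumes the SSE stream
# with the `requests` library. The URL assumes uvicorn's default port; the
# function name is hypothetical.
#
#   import requests
#
#   def stream_completion(prompt):
#       payload = {
#           "model": "chatglm-6b-int4", "stream": True, "max_tokens": 2048,
#           "messages": [{"role": "user", "content": prompt}],
#       }
#       with requests.post("http://localhost:8000/chat/completions",
#                          json=payload, stream=True) as resp:
#           for line in resp.iter_lines():
#               if line:
#                   print(line.decode())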


def ngrok_connect():
    from pyngrok import ngrok, conf
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    # Apply the auth token whose presence init() checks for; without this the
    # token read from the environment was never used.
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(8000)
    print(http_tunnel.public_url)


if __name__ == "__main__":
    uvicorn.run("main:app", reload=True, app_dir=".")
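
# Example request (illustrative; assumes the server is running locally):
#
#   curl -N http://localhost:8000/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "chatglm-6b-int4", "stream": true, "max_tokens": 2048,
#          "messages": [{"role": "user", "content": "Hello"}]}'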