junaid17 committed on
Commit
faeb2df
·
verified ·
1 Parent(s): 967a65f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
from threading import Thread

import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
13
+
14
+ # ============================================
15
+ # MODEL
16
+ # ============================================
17
+
18
MODEL_NAME = "junaid17/qwen-0.5b-16bit_merged"

# Half precision only pays off on a GPU. On CPU-only hosts fp16 kernels are
# slow (and missing for some ops in older PyTorch builds), so use fp32 there.
MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# NOTE(review): trust_remote_code=True lets Python code shipped in the model
# repository run at load time. Keep it only if this checkpoint needs it.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)

# Loaded once at import time and shared by every request.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=MODEL_DTYPE,
    device_map="auto",
)

# ============================================
# FASTAPI
# ============================================

app = FastAPI()
37
+
38
+ # ============================================
39
+ # REQUEST SCHEMA
40
+ # ============================================
41
+
42
class ChatRequest(BaseModel):
    """Request body for the ``/chat`` endpoint.

    Attributes:
        query: The user message to answer.
        max_new_tokens: Upper bound on the number of tokens to generate.
            Must be at least 1.
        temperature: Sampling temperature. Must be strictly positive,
            because the endpoint always generates with sampling enabled.
    """

    query: str
    # Reject out-of-range values at the request boundary (HTTP 422) instead
    # of letting them fail inside generation after streaming has started.
    max_new_tokens: int = Field(default=256, ge=1)
    temperature: float = Field(default=0.7, gt=0.0)
46
+
47
+ # ============================================
48
+ # STREAM CHAT
49
+ # ============================================
50
+
51
@app.post("/chat")
async def chat(request: ChatRequest):
    """Stream a sampled chat completion for ``request.query`` as plain text.

    The query is wrapped in a fixed system/user conversation, rendered with
    the tokenizer's chat template and passed to ``model.generate`` on a
    background thread. Decoded text chunks are forwarded to the client as
    the streamer produces them.

    Args:
        request: Query text plus the ``max_new_tokens`` and ``temperature``
            generation settings.

    Returns:
        A ``StreamingResponse`` with media type ``text/plain`` that yields
        the generated text (prompt and special tokens excluded).
    """
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": request.query},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    def run_generation():
        """Run ``model.generate`` and always release the stream consumer."""
        try:
            model.generate(**generation_kwargs)
        except Exception:
            logging.getLogger(__name__).exception("Text generation failed")
            # generate() died before signalling end-of-stream. Without this
            # the response iterator would block on the queue forever.
            streamer.end()

    # generate() blocks until completion, so it runs on its own thread while
    # the response below drains the streamer.
    Thread(target=run_generation).start()

    def generate_tokens():
        yield from streamer

    return StreamingResponse(
        generate_tokens(),
        media_type="text/plain",
    )