| | package server |
| |
|
| | import ( |
| | "context" |
| | "crypto/subtle" |
| | "encoding/json" |
| | "fmt" |
| | "net/http" |
| | "regexp" |
| | "strings" |
| | "time" |
| |
|
| | "gcli2api/internal/codeassist" |
| | "gcli2api/internal/config" |
| | "gcli2api/internal/gemini" |
| |
|
| | |
| |
|
| | "github.com/sirupsen/logrus" |
| | "github.com/tiktoken-go/tokenizer" |
| | ) |
| |
|
// URL patterns for the Gemini-compatible generation endpoints. The single
// capture group extracts the model name from the request path.
var (
	modelPathUnary  = regexp.MustCompile(`^/v1beta/models/([^/]+):generateContent$`)
	modelPathStream = regexp.MustCompile(`^/v1beta/models/([^/]+):streamGenerateContent$`)
)
| |
|
| | |
// CodeAssist abstracts the upstream Code Assist client so the Server can be
// constructed with a fake implementation (see NewWithCAClient).
type CodeAssist interface {
	// GenerateContent performs a unary generation request against upstream.
	GenerateContent(ctx context.Context, model, project string, req gemini.GeminiRequest) (*gemini.GeminiAPIResponse, error)
	// GenerateContentStream starts a streaming generation request; response
	// chunks arrive on the first channel and errors on the second.
	GenerateContentStream(ctx context.Context, model, project string, req gemini.GeminiRequest) (<-chan gemini.GeminiAPIResponse, <-chan error)
}
| |
|
// Server is the HTTP front end that translates Gemini-style API requests to
// calls on the upstream CodeAssist client.
type Server struct {
	cfg      config.Config
	httpCli  *http.Client
	caClient CodeAssist

	// sem bounds the number of concurrently handled requests; its capacity
	// is cfg.MaxConcurrentRequests.
	sem chan struct{}
}
| |
|
| | func New(cfg config.Config, httpCli *http.Client) *Server { |
| | |
| | if cfg.RequestMaxRetries == 0 { |
| | cfg.RequestMaxRetries = 3 |
| | } |
| | if cfg.RequestBaseDelayMillis == 0 { |
| | cfg.RequestBaseDelayMillis = 1000 |
| | } |
| | if cfg.RequestMaxBodyBytes == 0 { |
| | cfg.RequestMaxBodyBytes = 16 * 1024 * 1024 |
| | } |
| | if cfg.MaxConcurrentRequests == 0 { |
| | cfg.MaxConcurrentRequests = 64 |
| | } |
| | return &Server{ |
| | cfg: cfg, |
| | httpCli: httpCli, |
| | caClient: codeassist.NewCaClient(httpCli, cfg.RequestMaxRetries, time.Duration(cfg.RequestBaseDelayMillis)*time.Millisecond), |
| | sem: make(chan struct{}, cfg.MaxConcurrentRequests), |
| | } |
| | } |
| |
|
| | |
| | func NewWithCAClient(cfg config.Config, ca CodeAssist) *Server { |
| | |
| | if cfg.RequestMaxRetries == 0 { |
| | cfg.RequestMaxRetries = 3 |
| | } |
| | if cfg.RequestBaseDelayMillis == 0 { |
| | cfg.RequestBaseDelayMillis = 1000 |
| | } |
| | if cfg.RequestMaxBodyBytes == 0 { |
| | cfg.RequestMaxBodyBytes = 16 * 1024 * 1024 |
| | } |
| | if cfg.MaxConcurrentRequests == 0 { |
| | cfg.MaxConcurrentRequests = 64 |
| | } |
| | return &Server{cfg: cfg, caClient: ca, sem: make(chan struct{}, cfg.MaxConcurrentRequests)} |
| | } |
| |
|
| | func (s *Server) Router() http.Handler { |
| | mux := http.NewServeMux() |
| | mux.HandleFunc("/health", s.handleHealth) |
| | mux.HandleFunc("/v1beta/models", s.handleListModels) |
| | mux.HandleFunc("/v1beta/models/", s.handleModel) |
| | |
| | return s.withRecover(s.withLogging(s.withConcurrencyLimit(mux))) |
| | } |
| |
|
| | func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { |
| | w.Header().Set("Content-Type", "application/json") |
| | w.WriteHeader(http.StatusOK) |
| | _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) |
| | } |
| |
|
| | func (s *Server) authorize(r *http.Request) bool { |
| | key := s.cfg.AuthKey |
| | if key == "" { |
| | return true |
| | } |
| | if ah := r.Header.Get("Authorization"); ah != "" { |
| | const p = "Bearer " |
| | if strings.HasPrefix(ah, p) { |
| | |
| | if 1 == subtle.ConstantTimeCompare([]byte(strings.TrimSpace(ah[len(p):])), []byte(key)) { |
| | return true |
| | } |
| | } |
| | } |
| | if h := r.Header.Get("x-goog-api-key"); h != "" { |
| | if 1 == subtle.ConstantTimeCompare([]byte(h), []byte(key)) { |
| | return true |
| | } |
| | } |
| | |
| | if qk := r.URL.Query().Get("key"); qk != "" { |
| | if 1 == subtle.ConstantTimeCompare([]byte(qk), []byte(key)) { |
| | return true |
| | } |
| | } |
| | return false |
| | } |
| |
|
| | func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) { |
| | if r.Method != http.MethodGet { |
| | http.Error(w, "method not allowed", http.StatusMethodNotAllowed) |
| | return |
| | } |
| | if !s.authorize(r) { |
| | http.Error(w, "unauthorized", http.StatusUnauthorized) |
| | return |
| | } |
| | w.Header().Set("Content-Type", "application/json") |
| | _ = json.NewEncoder(w).Encode(listModels()) |
| | } |
| |
|
| | func (s *Server) handleModel(w http.ResponseWriter, r *http.Request) { |
| | if !s.authorize(r) { |
| | http.Error(w, "unauthorized", http.StatusUnauthorized) |
| | return |
| | } |
| | if r.Method != http.MethodPost { |
| | http.Error(w, "method not allowed", http.StatusMethodNotAllowed) |
| | return |
| | } |
| | path := r.URL.Path |
| | if m := modelPathUnary.FindStringSubmatch(path); m != nil { |
| | model := m[1] |
| | s.handleGenerateContent(model, w, r) |
| | return |
| | } |
| | if m := modelPathStream.FindStringSubmatch(path); m != nil { |
| | model := m[1] |
| | s.handleStreamGenerateContent(model, w, r) |
| | return |
| | } |
| | http.NotFound(w, r) |
| | } |
| |
|
| | func (s *Server) validateModel(model string) bool { |
| | return gemini.IsSupportedModel(model) |
| | } |
| |
|
| | func (s *Server) decodeGeminiRequest(r *http.Request) (gemini.GeminiRequest, error) { |
| | var req gemini.GeminiRequest |
| | dec := json.NewDecoder(r.Body) |
| | if err := dec.Decode(&req); err != nil { |
| | return req, err |
| | } |
| | req = gemini.NormalizeGeminiRequest(req) |
| | return req, nil |
| | } |
| |
|
| | func (s *Server) handleGenerateContent(model string, w http.ResponseWriter, r *http.Request) { |
| | if !s.validateModel(model) { |
| | http.Error(w, "unknown model", http.StatusBadRequest) |
| | return |
| | } |
| | |
| | r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes) |
| | req, err := s.decodeGeminiRequest(r) |
| | if err != nil { |
| | http.Error(w, fmt.Sprintf("bad request: %v", err), http.StatusBadRequest) |
| | return |
| | } |
| | |
| | var thinking any |
| | if req.GenerationConfig != nil { |
| | thinking = req.GenerationConfig.ThinkingConfig |
| | } |
| | totalTokens := countRequestTokens(req) |
| | logrus.WithFields(logrus.Fields{ |
| | "model": model, |
| | "thinkingConfig": thinking, |
| | "totalTokens": totalTokens, |
| | }).Info("sending to upstream") |
| | ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute) |
| | defer cancel() |
| | resp, err := s.caClient.GenerateContent(ctx, model, "", req) |
| | if err != nil { |
| | http.Error(w, err.Error(), httpStatusFromError(err)) |
| | return |
| | } |
| | w.Header().Set("Content-Type", "application/json") |
| | _ = json.NewEncoder(w).Encode(resp) |
| | } |
| |
|
// handleStreamGenerateContent serves the :streamGenerateContent endpoint,
// relaying upstream response chunks to the client as server-sent events.
func (s *Server) handleStreamGenerateContent(model string, w http.ResponseWriter, r *http.Request) {
	if !s.validateModel(model) {
		http.Error(w, "unknown model", http.StatusBadRequest)
		return
	}

	// Reject oversized request bodies before decoding.
	r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes)
	req, err := s.decodeGeminiRequest(r)
	if err != nil {
		http.Error(w, fmt.Sprintf("bad request: %v", err), http.StatusBadRequest)
		return
	}

	// Streaming requires per-event flushing; bail out if the ResponseWriter
	// cannot flush (e.g. when wrapped by a non-flushing middleware).
	flusher, ok := w.(http.Flusher)
	if !ok {
		logrus.Warn("streaming unsupported")
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}

	// Standard SSE headers; X-Accel-Buffering disables reverse-proxy
	// buffering (nginx) so events reach the client immediately.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	// cancel (deferred) tears down the upstream stream when this handler
	// returns for any reason, including client disconnect.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()
	out, errs := s.caClient.GenerateContentStream(ctx, model, "", req)

	var thinking any
	if req.GenerationConfig != nil {
		thinking = req.GenerationConfig.ThinkingConfig
	}
	totalTokens := countRequestTokens(req)
	logrus.WithFields(logrus.Fields{
		"model":          model,
		"thinkingConfig": thinking,
		"totalTokens":    totalTokens,
	}).Info("sending to upstream")
	enc := json.NewEncoder(w)
	for {
		select {
		case g, ok := <-out:
			if !ok {
				// Upstream closed the data channel: stream finished cleanly.
				return
			}
			// Emit one SSE event: "data: " prefix, the JSON chunk (Encode
			// appends its own newline), and a blank-line terminator.
			if _, err := fmt.Fprint(w, "data: "); err != nil {
				logrus.Errorf("error writing data prefix: %v", err)
				return
			}
			if err := enc.Encode(g); err != nil {
				return
			}
			if _, err := fmt.Fprint(w, "\n"); err != nil {
				logrus.Errorf("error writing newline: %v", err)
				return
			}
			flusher.Flush()
		case e, ok := <-errs:
			// A closed channel or nil error is not fatal: disable this case
			// by nil-ing errs (receiving from a nil channel blocks forever)
			// and keep draining out.
			if !ok || e == nil {
				errs = nil
				continue
			}
			// Surface the upstream error to the client as an SSE "error"
			// event, then terminate the stream.
			if _, err := fmt.Fprint(w, "event: error\n"); err != nil {
				logrus.Errorf("error writing error event: %v", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: {\"error\":{\"message\":%q}}\n\n", e.Error()); err != nil {
				logrus.Errorf("error writing error data: %v", err)
				return
			}
			flusher.Flush()
			return
		case <-ctx.Done():
			// Client went away or request context expired; deferred cancel
			// shuts down the upstream stream.
			return
		}
	}
}
| |
|
| | |
| | |
| | |
| | func countRequestTokens(req gemini.GeminiRequest) int { |
| | enc, err := tokenizer.Get(tokenizer.O200kBase) |
| | if err != nil { |
| | return 0 |
| | } |
| | total := 0 |
| | |
| | |
| | for _, c := range req.Contents { |
| | for _, p := range c.Parts { |
| | if p.Text != "" { |
| | if n, err := enc.Count(p.Text); err == nil { |
| | total += n |
| | } |
| | } |
| | } |
| | } |
| | return total |
| | } |
| |
|
| | func httpStatusFromError(err error) int { |
| | |
| | s := err.Error() |
| | if strings.Contains(s, "status 401") { |
| | return http.StatusUnauthorized |
| | } |
| | if strings.Contains(s, "status 403") { |
| | return http.StatusForbidden |
| | } |
| | if strings.Contains(s, "status 429") { |
| | return http.StatusTooManyRequests |
| | } |
| | if strings.Contains(s, "status 5") { |
| | return http.StatusBadGateway |
| | } |
| | return http.StatusBadRequest |
| | } |
| |
|
| | func listModels() interface{} { |
| | type model struct { |
| | Name string `json:"name"` |
| | Version string `json:"version"` |
| | DisplayName string `json:"displayName"` |
| | Description string `json:"description"` |
| | SupportedGenerationMethods []string `json:"supportedGenerationMethods"` |
| | } |
| | out := struct { |
| | Models []model `json:"models"` |
| | }{Models: make([]model, 0, len(gemini.SupportedModels))} |
| | for _, m := range gemini.SupportedModels { |
| | out.Models = append(out.Models, model{ |
| | Name: "models/" + m.Name, |
| | Version: "001", |
| | DisplayName: m.DisplayName, |
| | Description: m.Description, |
| | SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"}, |
| | }) |
| | } |
| | return out |
| | } |