| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.responses import JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| from typing import List, Optional, Dict, Any |
| import numpy as np |
| import base64 |
| import logging |
| import sys |
| import traceback |
| import io |
| from PIL import Image |
| import json |
|
|
| |
# Configure logging BEFORE anything logs. A module-level ``logging.warning``
# call on an unconfigured root logger implicitly runs ``basicConfig()`` with a
# default stderr handler, which would turn the call below into a no-op and
# silently drop DEBUG output and the custom format.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("faceforge_api")

# Optional core library. When it cannot be imported, HAS_CORE is False and the
# module falls back to its Mock* implementations.
try:
    import faceforge_core
    from faceforge_core.latent_explorer import LatentSpaceExplorer
    from faceforge_core.attribute_directions import LatentDirectionFinder
    from faceforge_core.custom_loss import attribute_preserving_loss
    HAS_CORE = True
except ImportError as e:
    # Log through the named logger so the configuration above is honoured.
    logger.warning(f"Failed to import faceforge_core modules: {e}")
    logger.warning("Using mock implementations instead")
    HAS_CORE = False
|
|
| |
|
|
class PointIn(BaseModel):
    # A single labelled point; not referenced by the endpoints shown here.
    # (Comments rather than a docstring: pydantic would publish a class
    # docstring as the schema description.)

    # Text label of the point.
    text: str
    # Optional precomputed encoding for the text.
    encoding: Optional[List[float]] = Field(default=None)
    # Optional 2-D placement of the point.
    xy_pos: Optional[List[float]] = Field(default=None)
|
|
class GenerateRequest(BaseModel):
    # Request body for POST /generate. (Comments rather than a docstring:
    # pydantic would publish a class docstring as the schema description.)

    # One text prompt per explorer point.
    prompts: List[str]
    # Optional per-prompt positions matched by index; prompts past the end of
    # this list are added without a position.
    positions: Optional[List[List[float]]] = Field(default=None)
    # Sampling mode forwarded to the explorer's sample_encoding.
    mode: str = Field(default="distance")
    # Position to sample at; the /generate handler uses [0.0, 0.0] if omitted.
    player_pos: Optional[List[float]] = Field(default=None)
|
|
class ManipulateRequest(BaseModel):
    # Request body for POST /manipulate: result = encoding + alpha * direction.
    # (Comments rather than a docstring: pydantic would publish a class
    # docstring as the schema description.)

    # Latent vector to move.
    encoding: List[float] = Field(...)
    # Direction to move it along.
    direction: List[float] = Field(...)
    # Step size applied to the direction.
    alpha: float = Field(...)
|
|
class AttributeDirectionRequest(BaseModel):
    # Request body for POST /attribute_direction. (Comments rather than a
    # docstring: pydantic would publish a class docstring as the schema
    # description.)

    # Batch of latent vectors to analyse.
    latents: List[List[float]]
    # When present, a classifier-based direction is computed; otherwise PCA.
    labels: Optional[List[int]] = Field(default=None)
    # Number of PCA components requested (only used when labels is absent).
    n_components: Optional[int] = Field(default=10)
|
|
| |
|
|
class MockLatentSpaceExplorer:
    """Random stand-in for ``LatentSpaceExplorer`` when faceforge_core is unavailable.

    It only records the text and position of each added point, and
    ``sample_encoding`` returns random data.
    """

    def __init__(self):
        self.points = []
        logger.warning("Using mock LatentSpaceExplorer")

    def add_point(self, text, encoding=None, xy_pos=None):
        """Record a point; ``encoding`` is accepted for API parity but not stored."""
        logger.debug(f"Mock add_point: {text}")
        record = dict(text=text, xy_pos=xy_pos)
        self.points.append(record)

    def sample_encoding(self, player_pos, mode="distance"):
        """Return a random ``(1, 4, 64, 64)`` array; the arguments are only logged."""
        logger.debug(f"Mock sample_encoding: {player_pos}, {mode}")
        latent_shape = (1, 4, 64, 64)
        return np.random.randn(*latent_shape)
|
|
class MockLatentDirectionFinder:
    """Random stand-in for ``LatentDirectionFinder`` when faceforge_core is unavailable.

    Results are random but shaped like real ones. The latent dimensionality
    is taken from the supplied latents when they form a non-empty 2-D array,
    so mock directions can be combined with encodings of that size.
    Otherwise it falls back to 512 (the previously hard-coded size).
    """

    _FALLBACK_DIM = 512

    def __init__(self, latents):
        self.latents = latents
        logger.warning("Using mock LatentDirectionFinder")

    def _latent_dim(self):
        """Return the column count of 2-D ``self.latents``, else the fallback size."""
        try:
            arr = np.asarray(self.latents, dtype=float)
        except (TypeError, ValueError):
            # Ragged or non-numeric input: no well-defined dimensionality.
            return self._FALLBACK_DIM
        if arr.ndim == 2 and arr.shape[1] > 0:
            return arr.shape[1]
        return self._FALLBACK_DIM

    def classifier_direction(self, labels):
        """Return a random direction vector of length ``_latent_dim()``; labels are ignored."""
        return np.random.randn(self._latent_dim())

    def pca_direction(self, n_components=10):
        """Return ``(components, explained)`` with shapes ``(n, dim)`` and ``(n,)``.

        Both arrays are random; ``explained`` is drawn from [0, 1) and is not
        normalised or sorted.
        """
        components = np.random.randn(n_components, self._latent_dim())
        explained = np.random.rand(n_components)
        return components, explained
|
|
| |
|
|
# Application object; the middleware and routes in this module register on it.
app = FastAPI(
    title="FaceForge API",
    version="1.0.0",
    description="API for latent space exploration and manipulation",
    root_path="",
)

# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination (browsers refuse a literal "*" with credentials, and
# Starlette may echo the caller's origin instead) -- confirm this is intended
# outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# One explorer instance shared by every request; the mock is used when the
# faceforge_core import failed (HAS_CORE is False).
if HAS_CORE:
    explorer = LatentSpaceExplorer()
else:
    explorer = MockLatentSpaceExplorer()
|
|
| |
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Turn any exception escaping the downstream handler into a JSON 500 response.

    The message is logged at ERROR and the traceback at DEBUG.
    NOTE(review): the exception text is returned to the client in "error".
    """
    try:
        response = await call_next(request)
    except Exception as exc:
        logger.error(f"Unhandled exception: {str(exc)}")
        logger.debug(traceback.format_exc())
        payload = {"detail": "Internal server error", "error": str(exc)}
        return JSONResponse(status_code=500, content=payload)
    return response
|
|
@app.get("/")
def read_root():
    """Return a fixed message so callers can confirm the API is serving requests."""
    logger.debug("API root endpoint called")
    status_message = "FaceForge API is running"
    return {"message": status_message}
|
|
def _png_base64(pixels):
    """Encode an image array as a base64 PNG string.

    ``pixels`` is anything ``PIL.Image.fromarray`` accepts, e.g. an
    ``(H, W, 3)`` ``uint8`` RGB array.
    """
    with io.BytesIO() as buffer:
        Image.fromarray(pixels).save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/generate")
async def generate_image(req: GenerateRequest):
    """Rebuild the shared explorer from the prompts, sample it, and return a PNG.

    The shared ``explorer`` is cleared and given one point per prompt
    (``positions[i]`` when supplied, else ``None``). It is then sampled at
    ``player_pos`` (``[0.0, 0.0]`` when omitted) using ``mode``.

    NOTE(review): the per-prompt encodings and the returned image are random
    placeholders; the sampled encoding is not yet used to render anything.

    Returns:
        ``{"status": "success", "image": <base64-encoded PNG>}``.

    Raises:
        HTTPException: 500 carrying the error text if any step fails.
    """
    try:
        # Serialising the request and building the full JSON schema is not
        # free; only do it when DEBUG records will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Generate image request: {json.dumps(req.dict(), default=str)}")
            logger.debug(f"Request schema: {GenerateRequest.schema_json()}")

        # This coroutine contains no ``await``, so the reset/populate/sample
        # sequence on the shared explorer cannot interleave with other
        # coroutines on the event loop. Keep it await-free (or add a lock)
        # if that changes.
        explorer.points = []
        for i, prompt in enumerate(req.prompts):
            logger.debug(f"Processing prompt {i}: {prompt}")
            # Placeholder encoding until a real text encoder is wired in.
            encoding = np.random.randn(512)
            xy_pos = req.positions[i] if req.positions and i < len(req.positions) else None
            logger.debug(f"Position for prompt {i}: {xy_pos}")
            explorer.add_point(prompt, encoding, xy_pos)

        player_pos = req.player_pos if req.player_pos is not None else [0.0, 0.0]
        logger.debug(f"Player position: {player_pos}")

        logger.debug(f"Sampling with mode: {req.mode}")
        # Result currently unused: the image below is a random placeholder.
        explorer.sample_encoding(tuple(player_pos), mode=req.mode)

        img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)

        logger.debug("Converting image to base64")
        img_b64 = _png_base64(img)

        response = {"status": "success", "image": img_b64}
        logger.debug(f"Response structure: {list(response.keys())}")
        logger.debug(f"Image base64 length: {len(img_b64)}")
        logger.debug("Image generated successfully")
        return response

    except Exception as e:
        logger.error(f"Error in generate_image: {str(e)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/manipulate")
def manipulate(req: ManipulateRequest):
    """Move a latent encoding along a direction: ``encoding + alpha * direction``.

    Returns:
        ``{"manipulated_encoding": [...]}``, the same length as ``encoding``.

    Raises:
        HTTPException: 422 when ``encoding`` and ``direction`` differ in length.
            Without this check, NumPy broadcasting would silently accept a
            length-1 vector and other mismatches would surface as an opaque
            500. 500 for any other unexpected failure.
    """
    try:
        # Avoid serialising the (possibly long) vectors unless DEBUG is on.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Manipulate request: {json.dumps(req.dict(), default=str)}")
        encoding = np.asarray(req.encoding, dtype=float)
        direction = np.asarray(req.direction, dtype=float)
        if encoding.shape != direction.shape:
            raise HTTPException(
                status_code=422,
                detail=(
                    "encoding and direction must have the same length "
                    f"(got {encoding.size} and {direction.size})"
                ),
            )
        manipulated = encoding + req.alpha * direction
        logger.debug("Manipulation successful")
        return {"manipulated_encoding": manipulated.tolist()}
    except HTTPException:
        # Deliberate client errors must not be rewrapped as 500s below.
        raise
    except Exception as e:
        logger.error(f"Error in manipulate: {str(e)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/attribute_direction")
def attribute_direction(req: AttributeDirectionRequest):
    """Find attribute direction(s) in latent space from a batch of latent vectors.

    With ``labels``, the finder's classifier-based direction is returned as
    ``{"direction": [...]}``. Without them, PCA is run and
    ``{"components": [...], "explained_variance": [...]}`` is returned.
    The mock finder is used when faceforge_core is unavailable.

    Raises:
        HTTPException: 500 carrying the error text if direction finding fails.
    """
    try:
        # The latents payload can be large; only serialise it for the log
        # when DEBUG records will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Attribute direction request: {json.dumps(req.dict(), default=str)}")
        latents = np.array(req.latents)

        finder = LatentDirectionFinder(latents) if HAS_CORE else MockLatentDirectionFinder(latents)

        if req.labels is not None:
            logger.debug("Using classifier-based direction finding")
            direction = finder.classifier_direction(req.labels)
            logger.debug("Direction found successfully")
            return {"direction": direction.tolist()}

        # NOTE(review): the request model allows n_components=None; how the
        # finder treats None is not visible here -- confirm.
        logger.debug(f"Using PCA with {req.n_components} components")
        components, explained = finder.pca_direction(n_components=req.n_components)
        logger.debug("PCA completed successfully")
        return {"components": components.tolist(), "explained_variance": explained.tolist()}
    except Exception as e:
        logger.error(f"Error in attribute_direction: {str(e)}")
        logger.debug(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e