| import logging |
| import pathlib |
| import gradio as gr |
| import pandas as pd |
| from gt4sd.algorithms.generation.diffusion import ( |
| DiffusersGenerationAlgorithm, |
| DDPMGenerator, |
| DDIMGenerator, |
| ScoreSdeGenerator, |
| LDMTextToImageGenerator, |
| LDMGenerator, |
| StableDiffusionGenerator, |
| ) |
| from gt4sd.algorithms.registry import ApplicationsRegistry |
|
|
# Module-level logger. The NullHandler means this module does not trigger
# logging's "last resort" stderr output when the host application has not
# configured any handlers itself.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
|
|
|
|
| def run_inference(model_type: str, prompt: str): |
|
|
| if prompt == "": |
| config = eval(f"{model_type}()") |
| else: |
| config = eval(f'{model_type}(prompt="{prompt}")') |
| if config.modality != "token2image" and prompt != "": |
| raise ValueError( |
| f"{model_type} is an unconditional generative model, please remove prompt (not={prompt})" |
| ) |
| model = DiffusersGenerationAlgorithm(config) |
| image = list(model.sample(1))[0] |
|
|
| return image |
|
|
|
|
if __name__ == "__main__":
    # Collect every registered diffusion-based application (algorithm name
    # contains "Diff"), excluding GeoDiff applications.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"]
        for x in all_algos
        if "Diff" in x["algorithm_name"]
        and "GeoDiff" not in x["algorithm_application"]
    ]

    # Fall back to the first available model if Stable Diffusion is not
    # registered, so the dropdown default is always a valid choice.
    default_model = "StableDiffusionGenerator"
    if default_model not in algos:
        default_model = algos[0] if algos else None

    # Model card assets (examples, article, description) live next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    # Example rows for the UI. Missing cells (NaN) become "" so that
    # run_inference treats an empty prompt as "no prompt".
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")

    # Read markdown explicitly as UTF-8 rather than the platform locale encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="Diffusion-based image generators",
        inputs=[
            gr.Dropdown(algos, label="Diffusion model", value=default_model),
            gr.Textbox(label="Text prompt", placeholder="A blue tree", lines=1),
        ],
        outputs=gr.Image(type="pil"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|
|