| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional, List |
| from diffusers.modular_pipelines import ( |
| ModularPipelineBlocks, |
| ComponentSpec, |
| InputParam, |
| OutputParam, |
| ModularPipeline, |
| PipelineState, |
| ) |
| from diffusers.guiders import ClassifierFreeGuidance |
| from transformers import UMT5EncoderModel, AutoTokenizer |
| from diffusers.image_processor import PipelineImageInput |
| import torch |
| from diffusers.modular_pipelines.wan.encoders import WanTextEncoderStep |
| from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor |
| from diffusers.video_processor import VideoProcessor |
| from diffusers.configuration_utils import FrozenDict |
|
|
|
|
class ChronoEditImageEncoderStep(ModularPipelineBlocks):
    """Encode the conditioning image with a CLIP vision tower for ChronoEdit.

    Produces `image_embeds` (the penultimate CLIP hidden state) on the
    pipeline state for use as a condition during denoising.
    """

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # CLIP preprocessing + vision encoder pair that produces the embeddings.
        specs = [
            ComponentSpec("image_processor", CLIPImageProcessor),
            ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
        ]
        return specs

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("image", type_hint=PipelineImageInput)]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        embeds_param = OutputParam(
            "image_embeds",
            type_hint=torch.Tensor,
            description="Image embeddings to use as conditions during the denoising process.",
        )
        return [embeds_param]

    @staticmethod
    def encode_image(components, image: PipelineImageInput, device: Optional[torch.device] = None):
        """Run the CLIP vision encoder on `image` and return its penultimate hidden state."""
        target_device = device if device is not None else components.image_encoder.device
        pixel_inputs = components.image_processor(images=image, return_tensors="pt").to(target_device)
        encoder_output = components.image_encoder(**pixel_inputs, output_hidden_states=True)
        # The second-to-last hidden layer is used as the conditioning features.
        return encoder_output.hidden_states[-2]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        embeds = self.encode_image(components, block_state.image, components._execution_device)
        block_state.image_embeds = embeds
        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class ChronoEditProcessImageStep(ModularPipelineBlocks):
    """Preprocess the input image into a pixel tensor and tile image embeddings per batch.

    Writes `processed_image` (resized/normalized pixel tensor) and, when
    `image_embeds` is provided, a batch-repeated bfloat16 copy of it back to
    the pipeline state.
    """

    model_name = "chronoedit"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", type_hint=PipelineImageInput),
            InputParam("image_embeds", type_hint=torch.Tensor, required=False),
            InputParam("batch_size", type_hint=int, required=False),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("processed_image", type_hint=PipelineImageInput),
            OutputParam("image_embeds", type_hint=torch.Tensor),
        ]

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            )
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Resize/normalize the image to (height, width).
        # NOTE(review): dtype is hard-coded to bfloat16 here and below — confirm it
        # matches the precision the other components are loaded in.
        block_state.processed_image = components.video_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width
        ).to(device, dtype=torch.bfloat16)

        if block_state.image_embeds is not None:
            # `batch_size` is optional (required=False); the previous code crashed with a
            # TypeError when it was absent. Default to 1 so single-sample calls work.
            batch_size = block_state.batch_size if block_state.batch_size is not None else 1
            block_state.image_embeds = block_state.image_embeds.repeat(batch_size, 1, 1).to(torch.bfloat16)

        self.set_block_state(state, block_state)

        return components, state
|
|
|
|
| |
class ChronoEditTextEncoderStep(WanTextEncoderStep):
    """Text-encoding step for ChronoEdit.

    Reuses the Wan text-encoder logic but declares its own component set,
    with a classifier-free-guidance guider defaulting to guidance_scale=1.0.
    """

    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        guider_spec = ComponentSpec(
            "guider",
            ClassifierFreeGuidance,
            config=FrozenDict({"guidance_scale": 1.0}),
            default_creation_method="from_config",
        )
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            guider_spec,
        ]

    @torch.no_grad()
    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        # Unconditional embeddings are only prepared when the guider runs more
        # than one condition (i.e. CFG is actually active).
        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device

        block_state.negative_prompt_embeds = None

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
        )
        block_state.prompt_embeds = prompt_embeds
        block_state.negative_prompt_embeds = negative_prompt_embeds

        self.set_block_state(state, block_state)
        return components, state
|
|