| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass |
| | from typing import Dict, Optional |
| |
|
| | import torch |
| |
|
| | from .df_conditioner import BaseVideoCondition, GeneralConditioner |
| | from .df_config_base_conditioner import ( |
| | FPSConfig, |
| | ImageSizeConfig, |
| | LatentConditionConfig, |
| | LatentConditionSigmaConfig, |
| | NumFramesConfig, |
| | PaddingMaskConfig, |
| | TextConfig, |
| | ) |
| | from .lazy_config_init import LazyCall as L |
| | from .lazy_config_init import LazyDict |
| |
|
| |
|
@dataclass
class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
    """Condition container for the video latent diffusion decoder.

    Extends ``BaseVideoCondition`` with two optional latent-conditioning
    tensors; both default to ``None`` so the base conditions remain usable
    on their own.
    """

    # Latent tensor used as extra conditioning input; None when not provided.
    # NOTE(review): shape/meaning not visible here — presumably encoder latents
    # matching LatentConditionConfig; confirm against the embedder.
    latent_condition: Optional[torch.Tensor] = None
    # Noise level (sigma) associated with latent_condition; None when not
    # provided. Assumed to pair with latent_condition — TODO confirm.
    latent_condition_sigma: Optional[torch.Tensor] = None
| |
|
| |
|
class VideoDiffusionDecoderConditioner(GeneralConditioner):
    """Conditioner that packages embedder outputs for the diffusion decoder.

    Delegates all conditioning work to ``GeneralConditioner._forward`` and
    re-wraps the resulting mapping as a
    ``VideoLatentDiffusionDecoderCondition`` dataclass.
    """

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoLatentDiffusionDecoderCondition:
        """Compute conditions for ``batch`` and return them as a typed record.

        Args:
            batch: Raw input batch consumed by the parent conditioner.
            override_dropout_rate: Optional per-condition dropout overrides,
                passed straight through to the parent.

        Returns:
            A ``VideoLatentDiffusionDecoderCondition`` built from the parent
            conditioner's output mapping.
        """
        condition_fields = super()._forward(batch, override_dropout_rate)
        return VideoLatentDiffusionDecoderCondition(**condition_fields)
| |
|
| |
|
# Default lazy config for the decoder conditioner: L(...) defers construction
# of VideoDiffusionDecoderConditioner, capturing one config object per
# condition field (text, fps, num_frames, image_size, padding_mask, plus the
# two latent-condition entries). NOTE(review): the keyword names are assumed
# to match the embedder keys GeneralConditioner expects — confirm upstream.
VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)(
    text=TextConfig(),
    fps=FPSConfig(),
    num_frames=NumFramesConfig(),
    image_size=ImageSizeConfig(),
    padding_mask=PaddingMaskConfig(),
    latent_condition=LatentConditionConfig(),
    latent_condition_sigma=LatentConditionSigmaConfig(),
)
| |
|