import contextlib
import functools
import os
from typing import Callable, List, Tuple

import torch
import torch.backends
from diffusers.hooks import HookRegistry, ModelHook

from finetrainers import logging, parallel, patches
from finetrainers.args import BaseArgsType
from finetrainers.logging import get_logger
from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry
from finetrainers.state import State

logger = get_logger()

_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook"


class Trainer:
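    """Base trainer that sets up the parallel backend, per-module attention providers, logging, trackers, and config options."""
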
    def __init__(self, args: BaseArgsType):
        self.args = args
        self.state = State()

        self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training)
        self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference)

        self._init_distributed()
        self._init_config_options()

        # Perform any patches that might be necessary for training to work as expected
        patches.perform_patches_for_training(self.args, self.state.parallel_backend)

    @contextlib.contextmanager
    def attention_provider_ctx(self, training: bool = True):
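        """Route attention dispatch to the providers configured for each registered module.

        While the context is active, every ``torch.nn.Module`` attribute of the trainer carries a hook that
        switches the globally active attention provider to that module's configured provider just before its
        forward pass runs. The default provider and context-parallel settings are restored on exit.
        """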
        name_providers_active = (
            self._module_name_providers_training if training else self._module_name_providers_inference
        )
        name_providers_dict = dict(name_providers_active)
        default_provider = _AttentionProviderRegistry._active_provider

        all_registered_module_names = [
            attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module)
        ]
        for module_name in all_registered_module_names:
            if module_name in name_providers_dict:
                continue
            name_providers_dict[module_name] = default_provider

        module_providers_dict = {}
        for module_name, provider in name_providers_dict.items():
            module = getattr(self, module_name, None)
            if module is not None:
                module_providers_dict[module] = (module_name, provider)

        # We don't want to immediately unset the attention provider to default after forward because if the
        # model is being trained, the backward pass must be invoked with the same attention provider.
        # So, we lazily switch attention providers only when the forward pass of a new module is called.
        def callback(m: torch.nn.Module):
            module_name, provider = module_providers_dict[m]
            # HACK: for CP on transformer. Need to support other modules too and improve overall experience for external usage
            if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled:
                if not _AttentionProviderRegistry.supports_context_parallel(provider):
                    raise ValueError(
                        f"Attention provider {provider} does not support context parallel. Please use a different provider."
                    )
                _AttentionProviderRegistry._set_context_parallel(
                    mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather"
                )
            _AttentionProviderRegistry._active_provider = provider

        # HACK: for VAE
        if "vae" in name_providers_dict:
            _apply_forward_hooks_hack(self.vae, name_providers_dict["vae"])

        for module in module_providers_dict.keys():
            registry = HookRegistry.check_if_exists_or_initialize(module)
            hook = LatestActiveModuleHook(callback)
            registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)

        yield

        _AttentionProviderRegistry._active_provider = default_provider
        _AttentionProviderRegistry._set_context_parallel(reset=True)
        for module in module_providers_dict.keys():
            registry: HookRegistry = module._diffusers_hook
            registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)

    def _init_distributed(self) -> None:
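        """Instantiate the parallel backend with the configured PP/DP/CP/TP degrees and, if a seed is given, enable determinism."""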
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))

        # TODO(aryan): handle other backends
        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
        self.state.parallel_backend = backend_cls(
            world_size=world_size,
            pp_degree=self.args.pp_degree,
            dp_degree=self.args.dp_degree,
            dp_shards=self.args.dp_shards,
            cp_degree=self.args.cp_degree,
            tp_degree=self.args.tp_degree,
            backend="nccl",
            timeout=self.args.init_timeout,
            logging_dir=self.args.logging_dir,
            output_dir=self.args.output_dir,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        if self.args.seed is not None:
            self.state.parallel_backend.enable_determinism(self.args.seed)

    def _init_logging(self) -> None:
        logging._set_parallel_backend(self.state.parallel_backend)
        logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
        logger.info("Initialized FineTrainers")

    def _init_trackers(self) -> None:
        # TODO(aryan): handle multiple trackers
        trackers = [self.args.report_to]
        experiment_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.parallel_backend.initialize_trackers(
            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
        )

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision(self.args.float32_matmul_precision)

    def tracker(self):
        return self.state.parallel_backend.tracker


class LatestActiveModuleHook(ModelHook):
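    """Diffusers ``ModelHook`` that invokes ``callback(module)`` immediately before the hooked module's forward pass."""
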
    def __init__(self, callback: Callable[[torch.nn.Module], None] = None):
        super().__init__()
        self.callback = callback

    def pre_forward(self, module, *args, **kwargs):
        self.callback(module)
        return args, kwargs


def _parse_attention_providers(attn_providers: List[str] = None) -> List[Tuple[str, AttentionProvider]]:
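    """Parse ``"module_name:provider_name"`` strings into ``(module_name, AttentionProvider)`` tuples.

    For example, an entry of the form ``"transformer:<provider_name>"`` maps the trainer's ``transformer``
    attribute to the ``AttentionProvider`` whose value is ``<provider_name>``.
    """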
    parsed_providers = []
    if attn_providers:
        for provider_str in attn_providers:
            parts = provider_str.split(":")
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'."
                )
            parts[1] = AttentionProvider(parts[1])
            parsed_providers.append(tuple(parts))
    return parsed_providers


# TODO(aryan): instead of this, we could probably just apply the hook to vae.children() as we know their forward methods will be invoked
def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider):
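    """Wrap the module's encode/decode entry points so they run with ``provider`` active and context parallel reset."""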
    if hasattr(module, "_finetrainers_wrapped_methods"):
        return

    def create_wrapper(old_method):
        def wrapper(*args, **kwargs):
            _AttentionProviderRegistry._set_context_parallel(reset=True)  # HACK: needs improvement
            old_provider = _AttentionProviderRegistry._active_provider
            _AttentionProviderRegistry._active_provider = provider
            output = old_method(*args, **kwargs)
            _AttentionProviderRegistry._active_provider = old_provider
            return output

        return wrapper

    methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"]
    finetrainers_wrapped_methods = []
    for method_name in methods:
        if not hasattr(module, method_name):
            continue
        method = getattr(module, method_name)
        wrapper = create_wrapper(method)
        setattr(module, method_name, wrapper)
        finetrainers_wrapped_methods.append(method_name)
    module._finetrainers_wrapped_methods = finetrainers_wrapped_methods
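

# A minimal usage sketch (illustration only, not part of this module). The ``args`` object and the
# ``trainer.transformer`` attribute are assumptions here; entries in ``args.attn_provider_training`` /
# ``args.attn_provider_inference`` must be ``"module_name:provider_name"`` strings with valid
# ``AttentionProvider`` values.
#
#     trainer = Trainer(args)
#     with trainer.attention_provider_ctx(training=True):
#         # The forward (and backward) of any ``torch.nn.Module`` attribute on the trainer, such as
#         # ``trainer.transformer``, now dispatches attention through its configured provider.
#         output = trainer.transformer(**batch)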