Spaces:
Build error
Build error
| import time | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| import torch | |
| from finetrainers.constants import FINETRAINERS_ENABLE_TIMING | |
| from finetrainers.logging import get_logger | |
# Module-level logger from the project's logging helper; Timer uses it to
# report misuse (double start/stop) and CPU fallback warnings.
logger = get_logger()
class TimerDevice(str, Enum):
    """Device on which a ``Timer`` measures elapsed time.

    Subclasses ``str`` so members compare equal to their plain string values
    (e.g. ``TimerDevice.CUDA == "cuda"``).
    """

    CPU = "cpu"
    CUDA = "cuda"


@dataclass
class TimerData:
    """Identity and raw CPU timestamps of a single ``Timer``.

    ``start_time`` / ``end_time`` are written by the timer's CPU path using
    ``time.time()``; the CUDA-event path does not touch them, so they stay at
    their defaults there.
    """

    name: str  # identifier used in the timer's log messages
    device: TimerDevice  # device the timer was asked to measure on
    start_time: float = 0.0  # seconds since the epoch at start (CPU path only)
    end_time: float = 0.0  # seconds since the epoch at end (CPU path only)
class Timer:
    """Measure elapsed time, in seconds, of a region of code on CPU or CUDA.

    Use it as a context manager (``with Timer(...) as timer: ...``), or call
    :meth:`start` and :meth:`end` explicitly, then read :meth:`elapsed_time`.

    Timing only happens when ``FINETRAINERS_ENABLE_TIMING`` is truthy at
    construction time. Otherwise :meth:`start` / :meth:`end` only track the
    running state and :meth:`elapsed_time` returns ``0.0``.

    Args:
        name: Identifier included in warning messages.
        device: ``TimerDevice.CUDA`` measures with CUDA events when
            ``torch.cuda.is_available()``. Any other value, or CUDA without an
            available device, uses ``time.time()``; a warning is logged unless
            the device is ``TimerDevice.CPU``.
        device_sync: If True, :meth:`end` calls ``torch.cuda.synchronize()``
            before recording the end CUDA event.
    """

    def __init__(self, name: str, device: TimerDevice, device_sync: bool = False):
        self.data = TimerData(name=name, device=device)
        self._device_sync = device_sync
        # CUDA events are created in _start_cuda(); None until then.
        self._start_event = None
        self._end_event = None
        self._active = False
        # Read once here; start/end/elapsed_time consult this cached value.
        self._enabled = FINETRAINERS_ENABLE_TIMING

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def start(self):
        """Begin timing; warn and do nothing if the timer is already running."""
        if self._active:
            logger.warning(f"Timer {self.data.name} is already running. Please stop it before starting again.")
            return
        self._active = True
        if not self._enabled:
            return
        if self._uses_cuda():
            self._start_cuda()
        else:
            self._start_cpu()
            self._warn_if_cpu_fallback()

    def end(self):
        """Stop timing; warn and do nothing if the timer is not running."""
        if not self._active:
            logger.warning(f"Timer {self.data.name} is not running. Please start it before stopping.")
            return
        self._active = False
        if not self._enabled:
            return
        if self._uses_cuda():
            self._end_cuda()
        else:
            self._end_cpu()
            self._warn_if_cpu_fallback()

    def elapsed_time(self) -> float:
        """Return the measured time in seconds.

        While the timer is running, this is the time since :meth:`start`.
        After :meth:`end`, it is the time between start and end.

        Returns ``0.0`` when timing is disabled, or when the CUDA path was
        never started (no events exist yet).
        """
        if not self._enabled:
            # Nothing was recorded; avoid returning "seconds since epoch" on
            # the CPU path or dereferencing missing CUDA events.
            return 0.0
        if self._uses_cuda():
            if self._start_event is None:
                return 0.0
            if self._active:
                end_event = torch.cuda.Event(enable_timing=True)
                end_event.record()
            else:
                end_event = self._end_event
            # Event.elapsed_time fails if the end event has not completed yet
            # (e.g. end() ran with device_sync=False), so wait for it first.
            end_event.synchronize()
            return self._start_event.elapsed_time(end_event) / 1000.0
        if self._active:
            return time.time() - self.data.start_time
        return self.data.end_time - self.data.start_time

    def _uses_cuda(self) -> bool:
        """True when CUDA timing was requested and CUDA is available."""
        return self.data.device == TimerDevice.CUDA and torch.cuda.is_available()

    def _warn_if_cpu_fallback(self) -> None:
        """Warn when CPU timing is used although another device was requested."""
        if self.data.device != TimerDevice.CPU:
            logger.warning(
                f"Timer device {self.data.device} is either not supported or incorrect device selected. Falling back to CPU."
            )

    def _start_cpu(self):
        self.data.start_time = time.time()

    def _start_cuda(self):
        # Synchronize before recording the start event (done regardless of device_sync).
        torch.cuda.synchronize()
        self._start_event = torch.cuda.Event(enable_timing=True)
        self._end_event = torch.cuda.Event(enable_timing=True)
        self._start_event.record()

    def _end_cpu(self):
        self.data.end_time = time.time()

    def _end_cuda(self):
        if self._device_sync:
            torch.cuda.synchronize()
        self._end_event.record()