Spaces:
Build error
Build error
| from typing import Any, Dict, List, Tuple | |
| import torch | |
class ResolutionSampler:
    """Group incoming samples into fixed-size batches that share tensor dimensions.

    Each call to :meth:`consume` supplies one sample as one or more dictionaries.
    Exactly one of those dictionaries must contain a "leader" key (one of the keys
    of ``dim_keys``) whose value is a tensor. The sizes of that tensor along the
    dimensions listed in ``dim_keys[leader_key]`` select the bucket the sample is
    stored in. Once a bucket holds ``batch_size`` samples it becomes available
    through :meth:`get_batch`.

    Args:
        batch_size: Number of samples per emitted batch. Must be at least 1.
        dim_keys: Mapping from candidate leader key to the tensor dimension
            indices whose sizes define a bucket. Required despite the ``None``
            default.

    Raises:
        AssertionError: If ``dim_keys`` is not provided.
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None:
        if dim_keys is None:
            # Raised explicitly instead of via `assert` so the check survives
            # `python -O`, while keeping the exception type callers already see.
            raise AssertionError("dim_keys must be provided")
        if batch_size < 1:
            # A non-positive batch size could never be satisfied: buckets would
            # grow without bound and `is_ready()` would never become True.
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        self.batch_size = batch_size
        self.dim_keys = dim_keys

        # Leader key is discovered lazily from the first consumed sample.
        self._chosen_leader_key = None
        # dims -> samples collected so far; each sample is the tuple of dicts
        # that was passed to a single `consume` call.
        self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Tuple[Dict[Any, Any], ...]]] = {}
        # Completed buckets waiting to be handed out by `get_batch`.
        self._satisfied_buckets: List[List[Tuple[Dict[Any, Any], ...]]] = []

    def consume(self, *dict_items: Dict[Any, Any]) -> None:
        """Add one sample, given as one or more dictionaries, to its bucket.

        Raises:
            ValueError: If the leader key cannot be determined on the first call,
                or is absent from every dictionary on a later call.
        """
        if self._chosen_leader_key is None:
            self._determine_leader_item(*dict_items)
        self._update_buckets(*dict_items)

    def get_batch(self) -> List[Tuple[Dict[Any, Any], ...]]:
        """Remove and return the most recently completed batch.

        The batch is transposed: element ``i`` of the result is a tuple holding
        the ``i``-th dictionary of every sample in the batch.

        Raises:
            IndexError: If no completed batch is available (see :meth:`is_ready`).
        """
        if not self._satisfied_buckets:
            raise IndexError("No satisfied bucket available; check is_ready() before calling get_batch()")
        return list(zip(*self._satisfied_buckets.pop()))

    def is_ready(self) -> bool:
        """Return True when at least one completed batch can be fetched."""
        return bool(self._satisfied_buckets)

    def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None:
        """Find and remember the single leader key present in ``dict_items``.

        The instance is only updated after every check has passed, so a failed
        call leaves the sampler ready to retry with a later sample.

        Raises:
            ValueError: If the leader value is not a tensor, if a leader key is
                observed more than once, or if no leader key is present.
        """
        leader_key = None
        num_observed = 0
        for dict_item in dict_items:
            for key in self.dim_keys:
                if key not in dict_item:
                    continue
                if not torch.is_tensor(dict_item[key]):
                    raise ValueError(f"Leader key {key} must be a tensor")
                leader_key = key
                num_observed += 1
        if num_observed > 1:
            raise ValueError(
                f"Only one leader key is allowed in provided list of data dictionaries. Found {num_observed} leader keys"
            )
        if leader_key is None:
            raise ValueError("No leader key found in provided list of data dictionaries")
        self._chosen_leader_key = leader_key

    def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None:
        """Store the sample in the bucket matching its leader tensor dimensions.

        Raises:
            ValueError: If none of ``dict_items`` contains the leader key.
        """
        leader_key = self._chosen_leader_key
        for dict_item in dict_items:
            if leader_key in dict_item:
                leader_value = dict_item[leader_key]
                break
        else:
            raise ValueError(f"Leader key {leader_key} not found in provided list of data dictionaries")

        # Bucket identity: sizes of the leader tensor along the configured dims.
        dims = tuple(leader_value.size(dim) for dim in self.dim_keys[leader_key])
        bucket = self._unsatisfied_buckets.setdefault(dims, [])
        bucket.append(dict_items)
        if len(bucket) >= self.batch_size:
            self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims))