import ast
import json
import logging
import math
import os
import random
import h5py
from dataclasses import dataclass
from models.CLAP.training.params import parse_args
import braceexpand
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from functools import partial
import soundfile as sf
import io
from pathlib import Path
import wget

from models.CLAP.open_clip.utils import get_tar_path_from_dataset_name, dataset_split
from models.CLAP.open_clip.utils import load_p, load_class_label
import tempfile
import copy

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

try:
    import torchaudio
except ImportError:
    torchaudio = None

from models.CLAP.open_clip import tokenize


def tokenizer(text):
    return tokenize(text).squeeze(0)


from transformers import RobertaTokenizer

# NOTE: this rebinding replaces the CLIP `tokenize` imported above with a
# RoBERTa tokenizer, and the `tokenizer` defined below shadows the one above;
# everything from here on uses the RoBERTa versions.
tokenize = RobertaTokenizer.from_pretrained("roberta-base")


def tokenizer(text):
    result = tokenize(
        text,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )
    return {k: v.squeeze(0) for k, v in result.items()}
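
# Example (sketch): the RoBERTa-based tokenizer returns fixed-length tensors
# with the batch dimension squeezed out:
#   out = tokenizer("a dog barking")
#   out["input_ids"].shape       # torch.Size([77])
#   out["attention_mask"].shape  # torch.Size([77])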

_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)


def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1.0, a_max=1.0)
    return (x * 32767.0).astype(np.int16)
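
# Example (sketch): audio is stored as int16 and mapped back to float32, so a
# round trip quantizes the waveform to 16-bit precision. Assuming `wav` is a
# float32 array in [-1, 1]:
#   wav_q = int16_to_float32(float32_to_int16(wav))
#   np.abs(wav - wav_q).max()  # on the order of 1 / 32767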


class ToyDataset(Dataset):
    def __init__(self, index_path, ipc, config, eval_mode=False):
        """Toy dataset for testing the AudioSet input with text labels.
        Parameters
        ----------
        index_path: str
            the path to the h5 file of each audio
        ipc: str
            the path to the npy file of the number of samples in each class
        config: dict
            the audio cfg file
        eval_mode: bool
            whether the dataset is a testing dataset
        """
        self.audio_cfg = config["audio_cfg"]
        self.text_cfg = config["text_cfg"]
        self.fp = h5py.File(index_path, "r")
        self.ipc = np.load(ipc, allow_pickle=True)
        self.total_size = len(self.fp["audio_name"])
        self.classes_num = self.audio_cfg["class_num"]
        self.eval_mode = eval_mode

        if not eval_mode:
            self.generate_queue()
        else:
            # keep only the samples that have at least one positive label
            self.queue = []
            for i in range(self.total_size):
                target = self.fp["target"][i]
                if np.sum(target) > 0:
                    self.queue.append(i)
            self.total_size = len(self.queue)
        logging.info("total dataset size: %d" % (self.total_size))
        logging.info("class num: %d" % (self.classes_num))

    def time_shifting(self, x):
        frame_num = len(x)
        shift_len = random.randint(0, frame_num - 1)
        new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
        return new_sample

    def generate_queue(self):
        # class-balanced sampling: draw one random sample per class, in
        # shuffled class order, until the queue reaches the dataset size
        self.queue = []
        while len(self.queue) < self.total_size:
            class_set = [*range(self.classes_num)]
            random.shuffle(class_set)
            self.queue += [
                self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
            ]
        self.queue = self.queue[: self.total_size]

        logging.info("queue regenerated: %s" % (self.queue[-5:]))

    def crop_wav(self, x):
        crop_size = self.audio_cfg["crop_size"]
        crop_pos = random.randint(0, len(x) - crop_size - 1)
        return x[crop_pos : crop_pos + crop_size]

    def prompt_text(self, target):
        events = _AUDIOSET_MAP[np.where(target > 0)]
        event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
        text = tokenize(event_text)[0]
        return text

    def __getitem__(self, index):
        """Load the waveform, text, and target of an audio clip.

        Parameters
        ----------
        index: int
            the index number
        Return
        ------
        output: dict {
            "hdf5_path": str,
            "index_in_hdf5": int,
            "audio_name": str,
            "waveform": list (audio_length,),
            "target": list (class_num, ),
            "text": torch.tensor (context_length,)
        }
            the output dictionary
        """
        s_index = self.queue[index]

        audio_name = self.fp["audio_name"][s_index].decode()
        hdf5_path = (
            self.fp["hdf5_path"][s_index]
            .decode()
            .replace(
                "../workspace",
                "/home/la/kechen/Research/ke_zsasp/workspace",
            )
        )
        r_idx = self.fp["index_in_hdf5"][s_index]
        target = self.fp["target"][s_index].astype(np.float32)
        text = self.prompt_text(target)
        with h5py.File(hdf5_path, "r") as f:
            waveform = int16_to_float32(f["waveform"][r_idx])[
                : self.audio_cfg["clip_samples"]
            ]
        assert (
            len(waveform) == self.audio_cfg["clip_samples"]
        ), "The sample length does not match"

        # build a 4-channel mel stack; when the clip is treated as "not
        # longer", all but the first channel are zeroed out
        mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
        mel_spec = (
            torch.cat(
                [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
            )
            .cpu()
            .numpy()
        )
        longer = random.choice([True, False])
        if not longer:
            mel_spec[1:, :, :] = 0.0
        data_dict = {
            "hdf5_path": hdf5_path,
            "index_in_hdf5": r_idx,
            "audio_name": audio_name,
            "waveform": waveform,
            "class_label": target,
            "text": text,
            "longer": longer,
            "mel_fusion": mel_spec,
        }
        return data_dict

    def __len__(self):
        return self.total_size
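
# Usage (sketch, with hypothetical paths and config):
#   dataset = ToyDataset("index.h5", "ipc.npy", model_cfg, eval_mode=False)
#   sample = dataset[0]
#   sample["waveform"].shape    # (clip_samples,)
#   sample["mel_fusion"].shape  # (4, T, 64)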


class CsvDataset(Dataset):
    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
        logging.debug(f"Loading csv data from {input_filename}.")
        df = pd.read_csv(input_filename, sep=sep)

        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        logging.debug("Done loading data.")

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        images = self.transforms(Image.open(str(self.images[idx])))
        texts = tokenize([str(self.captions[idx])])[0]
        return images, texts


@dataclass
class DataInfo:
    dataloader: DataLoader
    sampler: DistributedSampler


def preprocess_txt(text):
    return tokenize([str(text)])[0]


def get_dataset_size(shards, sizefilepath_=None, is_local=True):
    if isinstance(shards, list):
        size_list = []
        for s in shards:
            size_list.append(
                get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
            )
    else:
        if not is_local:
            for n in dataset_split.keys():
                if n in shards.split("/"):
                    break
            for s in dataset_split[n]:
                if s in shards.split("/"):
                    break
            sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
        shards_list = list(braceexpand.braceexpand(shards))
        dir_path = os.path.dirname(shards)
        if sizefilepath_ is not None:
            sizes = json.load(open(sizefilepath_, "r"))
            total_size = sum(
                [
                    int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
                    for shard in shards_list
                ]
            )
        else:
            sizes_filename = os.path.join(dir_path, "sizes.json")
            len_filename = os.path.join(dir_path, "__len__")
            if os.path.exists(sizes_filename):
                sizes = json.load(open(sizes_filename, "r"))
                total_size = sum(
                    [int(sizes[os.path.basename(shard)]) for shard in shards_list]
                )
            elif os.path.exists(len_filename):
                # read a literal sample count from the __len__ file
                total_size = ast.literal_eval(open(len_filename, "r").read())
            else:
                raise Exception(
                    "Cannot find sizes file for dataset. Please specify the path to the file."
                )

    if isinstance(shards, list):
        return sum(size_list), len(shards)
    else:
        # `shards_list` only exists in the single-string branch above
        num_shards = len(shards_list)
        return total_size, num_shards


def get_imagenet(args, preprocess_fns, split):
    assert split in ["train", "val", "v2"]
    is_train = split == "train"
    preprocess_train, preprocess_val = preprocess_fns

    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset

        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
    else:
        if is_train:
            data_path = args.imagenet_train
            preprocess_fn = preprocess_train
        else:
            data_path = args.imagenet_val
            preprocess_fn = preprocess_val
        assert data_path

        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)

    if is_train:
        # subsample k=50 images per class for the training split
        idxs = np.zeros(len(dataset.targets))
        target_array = np.array(dataset.targets)
        k = 50
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:k] = 1
            np.random.shuffle(arr)
            idxs[m] = arr

        idxs = idxs.astype("int")
        sampler = SubsetRandomSampler(np.where(idxs)[0])
    else:
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
    )

    return DataInfo(dataloader, sampler)


def count_samples(dataloader):
    os.environ["WDS_EPOCH"] = "0"
    n_elements, n_batches = 0, 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        assert len(images) == len(texts)
    return n_elements, n_batches


def filter_no_caption(sample):
    return "txt" in sample


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000


def sample_prop(sizefile, inputs, proportion, is_local=True):
    """
    Sample a proportion of the data.
    """
    file_path_dict = {
        os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
        for i in range(len(inputs))
    }
    sampled_filepath_dict = {}
    sampled_size_dict = {}
    if not is_local:
        if os.path.exists("sizes.json"):
            os.remove("sizes.json")
        wget.download(sizefile, "sizes.json")
        sizefile = "sizes.json"
    with open(sizefile, "r", encoding="UTF-8") as f:
        load_dict = json.load(f)
    L = int(len(file_path_dict) * proportion)
    # random.sample requires a sequence (a dict view is rejected on Python 3.11+)
    subkeys = random.sample(list(file_path_dict.keys()), L)
    for k in subkeys:
        sampled_size_dict[k] = load_dict[k]
        sampled_filepath_dict[k] = file_path_dict[k]
    return (
        sum(sampled_size_dict.values()),
        L,
        [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
        sampled_size_dict,
    )


def get_mel(audio_data, audio_cfg):
    # mel spectrogram
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg["sample_rate"],
        n_fft=audio_cfg["window_size"],
        win_length=audio_cfg["window_size"],
        hop_length=audio_cfg["hop_size"],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=64,
        f_min=audio_cfg["fmin"],
        f_max=audio_cfg["fmax"],
    ).to(audio_data.device)
    mel = mel(audio_data)
    # convert the power spectrogram to a log-mel spectrogram
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel.T  # (T, n_mels)
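
# Shape sketch (assuming sample_rate=48000 and hop_size=480 in audio_cfg): a
# 10 s clip has 480000 samples, and with center=True the frame count is
# 480000 // 480 + 1 = 1001, so get_mel returns a (1001, 64) tensor.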


def get_audio_features(
    sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
):
    """
    Calculate and add audio features to the sample.
    sample: a dict containing all the data of current sample.
    audio_data: a tensor of shape (T) containing audio data.
    max_len: the maximum length of audio data.
    data_truncating: the method of truncating data.
    data_filling: the method of filling data.
    audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
    """
    with torch.no_grad():
        if len(audio_data) > max_len:
            if data_truncating == "rand_trunc":
                longer = torch.tensor([True])
            elif data_truncating == "fusion":
                # fusion: keep one global (shrunk) view plus three local chunks
                mel = get_mel(audio_data, audio_cfg)
                chunk_frames = (
                    max_len // audio_cfg["hop_size"] + 1
                )  # the +1 matches the frame count of a centered spectrogram
                total_frames = mel.shape[0]
                if chunk_frames == total_frames:
                    # corner case: the audio is longer than max_len but shorter
                    # than max_len + hop_size, so just use the whole mel
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([False])
                else:
                    # split the valid start positions into three parts
                    ranges = np.array_split(
                        list(range(0, total_frames - chunk_frames + 1)), 3
                    )
                    # if the audio is too short, fall back to the first chunk
                    # for the middle and back parts
                    if len(ranges[1]) == 0:
                        ranges[1] = [0]
                    if len(ranges[2]) == 0:
                        ranges[2] = [0]
                    # randomly choose a start index for each part
                    idx_front = np.random.choice(ranges[0])
                    idx_middle = np.random.choice(ranges[1])
                    idx_back = np.random.choice(ranges[2])
                    # select the mel chunks
                    mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
                    mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
                    mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]

                    # shrink the full mel to chunk size for the global view
                    mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(
                        mel[None]
                    )[0]

                    # stack the three local chunks and the global view
                    mel_fusion = torch.stack(
                        [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink],
                        dim=0,
                    )
                    sample["mel_fusion"] = mel_fusion
                    longer = torch.tensor([True])
            else:
                raise NotImplementedError(
                    f"data_truncating {data_truncating} not implemented"
                )
            # randomly crop the waveform to max_len
            overflow = len(audio_data) - max_len
            idx = np.random.randint(0, overflow + 1)
            audio_data = audio_data[idx : idx + max_len]

        else:
            if len(audio_data) < max_len:
                if data_filling == "repeatpad":
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat)
                    # zero-pad the remainder up to max_len
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "pad":
                    audio_data = F.pad(
                        audio_data,
                        (0, max_len - len(audio_data)),
                        mode="constant",
                        value=0,
                    )
                elif data_filling == "repeat":
                    n_repeat = int(max_len / len(audio_data))
                    audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                else:
                    raise NotImplementedError(
                        f"data_filling {data_filling} not implemented"
                    )
            if data_truncating == "fusion":
                mel = get_mel(audio_data, audio_cfg)
                mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                sample["mel_fusion"] = mel_fusion
            longer = torch.tensor([False])

    sample["longer"] = longer
    sample["waveform"] = audio_data

    return sample
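
# Usage (sketch, assuming a 48 kHz config where max_len=480000 is 10 s):
#   sample = get_audio_features(
#       {}, wav, 480000,
#       data_truncating="fusion", data_filling="repeatpad",
#       audio_cfg=model_cfg["audio_cfg"],
#   )
#   sample["waveform"].shape    # torch.Size([480000])
#   sample["mel_fusion"].shape  # (4, chunk_frames, 64)
#   sample["longer"]            # tensor([True]) if local chunks were cropped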


def preprocess(
    sample,
    audio_ext,
    text_ext,
    max_len,
    audio_cfg,
    class_index_dict=None,
    data_filling="pad",
    data_truncating="rand_trunc",
    text_augment_selection=None,
):
    """
    Preprocess a single sample for the wds dataloader.
    """
    audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
    # round-trip through int16 to match the quantization of the stored data
    audio_data = int16_to_float32(float32_to_int16(audio_data))
    audio_data = torch.tensor(audio_data).float()

    sample = get_audio_features(
        sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
    )
    del sample[audio_ext]

    try:
        json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
    except Exception:
        # report which shard produced the bad JSON, then re-raise so the
        # failure is not silently deferred to a NameError below
        print("sample[__url__]:", sample["__url__"])
        raise

    # select the text field according to the augmentation strategy
    if text_augment_selection is None or text_augment_selection == "none":
        texts = json_dict_raw["text"]
    elif text_augment_selection == "all":
        if "text_augment_all" in json_dict_raw.keys():
            texts = json_dict_raw["text_augment_all"]
        else:
            texts = json_dict_raw["text"]
    elif text_augment_selection == "augment_only":
        if "text_augment_all" in json_dict_raw.keys():
            if json_dict_raw["text_augment_t5"] is None:
                texts = json_dict_raw["text"]
            else:
                texts = json_dict_raw["text_augment_t5"]
        else:
            texts = json_dict_raw["text"]
    else:
        raise NotImplementedError(
            f"text_augment_selection {text_augment_selection} not implemented"
        )
    sample["full_text"] = texts

    if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
        texts = random.choice(texts)
    sample["raw_text"] = texts
    sample["text"] = tokenizer(texts)
    if class_index_dict is not None:
        # build a multi-hot class-label vector from the raw tags
        sample["class_label"] = np.zeros(len(class_index_dict.keys()))
        for x in json_dict_raw["tag"]:
            sample["class_label"][class_index_dict[x]] = 1
        sample["class_label"] = torch.tensor(sample["class_label"]).float()
    del sample[text_ext]
    sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
    sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
    sample["audio_orig_sr"] = orig_sr
    return sample


def collate_fn(batch):
    """
    Collate function for the wds dataloader.
    batch: a list of dicts, each dict is a sample.
    """
    batch_dict = {}
    for k in batch[0].keys():
        if isinstance(batch[0][k], dict):
            # dict of tensors, e.g. the tokenizer output
            batch_dict[k] = {}
            for kk in batch[0][k].keys():
                tmp = []
                for i in range(len(batch)):
                    tmp.append(batch[i][k][kk])
                batch_dict[k][kk] = torch.vstack(tmp)
        elif isinstance(batch[0][k], torch.Tensor):
            batch_dict[k] = torch.stack([sample[k] for sample in batch])
        elif isinstance(batch[0][k], np.ndarray):
            batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
        else:
            batch_dict[k] = [sample[k] for sample in batch]
    return batch_dict
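
# Behavior sketch: tensors are stacked along a new batch dimension, nested
# dicts (e.g. the tokenizer output) are stacked per key via vstack, and
# everything else (strings, numbers) is collected into a plain list:
#   collate_fn([{"text": {"input_ids": a}}, {"text": {"input_ids": b}}])
#   # -> {"text": {"input_ids": torch.vstack([a, b])}}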


def get_wds_dataset(
    args,
    model_cfg,
    is_train,
    audio_ext="flac",
    text_ext="json",
    max_len=480000,
    proportion=1.0,
    sizefilepath_=None,
    is_local=None,
):
    """
    Get a dataset for the wds dataloader.
    """
    if is_local is None and args.remotedata is not None:
        is_local = not args.remotedata

    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None

    if sizefilepath_ is not None:
        sizefilepath = sizefilepath_
    else:
        sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")

    if proportion != 1.0:
        num_samples, num_shards, input_shards, _ = sample_prop(
            sizefilepath, input_shards, proportion, is_local=is_local
        )
    else:
        num_samples, num_shards = get_dataset_size(
            input_shards, sizefilepath_=sizefilepath_, is_local=is_local
        )

    if not num_samples:
        if is_train:
            num_samples = args.train_num_samples
            if not num_samples:
                raise RuntimeError(
                    "Currently, number of dataset samples must be specified for training dataset. "
                    "Please specify via `--train-num-samples` if no dataset length info present."
                )
        else:
            # eval will just exhaust the iterator if the size is not specified
            num_samples = args.val_num_samples or 0

    pipeline = [wds.SimpleShardList(input_shards)]
    # at this point we have an iterator over all the shards
    if is_train or args.parallel_eval:
        pipeline.extend(
            [
                wds.detshuffle(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                ),
                wds.split_by_node,
                wds.split_by_worker,
                # at this point we have an iterator over the shards assigned
                # to each worker on each node
                wds.tarfile_to_samples(handler=log_and_continue),
                wds.shuffle(
                    bufsize=_SAMPLE_SHUFFLE_SIZE,
                    initial=_SAMPLE_SHUFFLE_INITIAL,
                    rng=random.Random(args.seed),
                ),
            ]
        )
    else:
        pipeline.extend(
            [
                wds.split_by_worker,
                # at this point we have an iterator over the shards assigned
                # to each worker
                wds.tarfile_to_samples(handler=log_and_continue),
            ]
        )
    pipeline.append(
        wds.map(
            partial(
                preprocess,
                audio_ext=audio_ext,
                text_ext=text_ext,
                max_len=max_len,
                audio_cfg=model_cfg["audio_cfg"],
                class_index_dict=copy.deepcopy(args.class_index_dict),
                data_filling=args.data_filling,
                data_truncating=args.data_truncating,
                text_augment_selection=args.text_augment_selection,
            )
        ),
    )

    pipeline.append(
        wds.batched(
            args.batch_size,
            partial=not (is_train or args.parallel_eval),
            collation_fn=collate_fn,
        )
    )

    dataset = wds.DataPipeline(*pipeline)
    if is_train or args.parallel_eval:
        # roll over and repeat a few samples to get the same number of full
        # batches on each node; parallel evaluation is therefore slightly
        # imprecise, since the last few samples are repeated
        global_batch_size = args.batch_size * args.world_size
        num_batches = math.ceil(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = math.ceil(num_batches / num_workers)  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        dataset = dataset.with_epoch(num_worker_batches)  # each worker iterates this many batches
    else:
        # the last batches are partial; eval runs on a single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)

    kwargs = {}
    if args.horovod:
        kwargs["multiprocessing_context"] = "forkserver"

    dataloader = wds.WebLoader(
        dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
    )

    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader, None)
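
# Sizing sketch (example numbers, not defaults): with batch_size=32,
# world_size=4, workers=8 and num_samples=100000: global_batch_size=128,
# num_batches=ceil(100000/128)=782, num_worker_batches=ceil(782/8)=98, so the
# epoch is padded to 98*8=784 batches, i.e. 100352 samples, with the last few
# samples of the shuffled stream repeating.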


def wds_batch_list2dict(
    batch,
    keys=[
        "__url__",
        "__key__",
        "waveform",
        "text",
        "raw_text",
        "audio_name",
        "text_name",
        "audio_orig_sr",
    ],
):
    """
    Return a dictionary of the batch, with keys as the names of the fields.
    """
    assert len(keys) == len(
        batch
    ), "batch must have same number of keys as keys argument"
    return {keys[i]: batch[i] for i in range(len(batch))}


def get_csv_dataset(args, preprocess_fn, is_train):
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator,
    )
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)


def get_toy_dataset(args, model_cfg, is_train):
    index_path = args.train_data if is_train else args.val_data
    ipc_path = args.train_ipc if is_train else args.val_ipc
    assert index_path and ipc_path
    eval_mode = not is_train
    dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)

    num_samples = len(dataset)
    sampler = (
        DistributedSampler(dataset, shuffle=False)
        if args.distributed and is_train
        else None
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)

def get_dataset_fn(data_path, dataset_type):
    if dataset_type == "webdataset":
        return get_wds_dataset
    elif dataset_type == "csv":
        return get_csv_dataset
    elif dataset_type == "auto":
        ext = data_path.split(".")[-1]
        if ext in ["csv", "tsv"]:
            return get_csv_dataset
        elif ext in ["tar"]:
            return get_wds_dataset
        else:
            raise ValueError(
                f"Tried to figure out dataset type, but failed for extension {ext}."
            )
    elif dataset_type == "toy":
        return get_toy_dataset
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

def get_data(args, model_cfg):
    data = {}

    args.class_index_dict = load_class_label(args.class_label_path)

    if args.datasetinfos is None:
        args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
    if args.dataset_type == "webdataset":
        args.train_data = get_tar_path_from_dataset_name(
            args.datasetnames,
            args.datasetinfos,
            islocal=not args.remotedata,
            proportion=args.dataset_proportion,
            dataset_path=args.datasetpath,
            full_dataset=args.full_train_dataset,
        )

        if args.full_train_dataset is None:
            args.full_train_dataset = []
        if args.exclude_eval_dataset is None:
            args.exclude_eval_dataset = []
        excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset

        val_dataset_names = (
            [n for n in args.datasetnames if n not in excluded_eval_datasets]
            if excluded_eval_datasets
            else args.datasetnames
        )
        args.val_dataset_names = val_dataset_names
        args.val_data = get_tar_path_from_dataset_name(
            val_dataset_names,
            ["valid", "test", "eval"],
            islocal=not args.remotedata,
            proportion=1,
            dataset_path=args.datasetpath,
            full_dataset=None,
        )

    if args.train_data:
        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
            args, model_cfg, is_train=True
        )

    if args.val_data:
        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
            args, model_cfg, is_train=False
        )

    return data
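
# Usage (sketch): `args` typically comes from parse_args() and `model_cfg`
# from the model config, e.g.
#   data = get_data(args, model_cfg)
#   train_loader = data["train"].dataloader
#   train_loader.num_batches  # batches per (padded) epoch for webdataset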