import copy
import os
import pickle

import torch
|
|
class EnhancerDataset(torch.utils.data.Dataset):
    """Map-style dataset of enhancer sequences paired with class labels.

    Loads ``Deep{MEL2|FlyBrain}_data.pkl`` from ``data_dir`` and, for the
    requested split, keeps the arg-max index along the last axis of both the
    sequence array (key ``'{split}_data'``) and the label array (key
    ``'y_{split}'``). The stored arrays are presumably one-hot encoded
    (nucleotides / classes) -- TODO confirm against the data-preparation script.

    Attributes:
        seqs: int64 tensor of token indices (input array with its last axis
            reduced by ``argmax``).
        clss: int64 tensor of class indices (label array with its last axis
            reduced by ``argmax``).
        num_cls: size of the last axis of the label array.
        alphabet_size: fixed at 4; not derived from the data.
    """

    def __init__(self, mel_enhancer=True, split='train',
                 data_dir='./dataset/enhancer_data'):
        """Load one split of the DeepMEL2 or DeepFlyBrain enhancer data.

        Args:
            mel_enhancer: if True read ``DeepMEL2_data.pkl``, otherwise
                ``DeepFlyBrain_data.pkl``.
            split: selects the ``'{split}_data'`` and ``'y_{split}'`` entries
                of the pickled mapping (e.g. ``'train'``).
            data_dir: directory containing the pickle files; defaults to the
                location that used to be hard-coded here.

        Raises:
            FileNotFoundError: if the pickle file does not exist.
            KeyError: if the loaded mapping has no entries for ``split``.
        """
        name = 'MEL2' if mel_enhancer else 'FlyBrain'
        path = os.path.join(data_dir, f'Deep{name}_data.pkl')
        # NOTE(security): pickle.load can execute arbitrary code -- only point
        # data_dir at trusted files.
        with open(path, 'rb') as f:  # close the handle deterministically
            all_data = pickle.load(f)

        seq_onehot = all_data[f'{split}_data']
        label_onehot = all_data[f'y_{split}']
        # torch.argmax allocates a new tensor, so wrapping the arrays without a
        # deep copy is safe and avoids duplicating the full one-hot data.
        self.seqs = torch.argmax(torch.as_tensor(seq_onehot), dim=-1)
        self.clss = torch.argmax(torch.as_tensor(label_onehot), dim=-1)
        self.num_cls = label_onehot.shape[-1]
        self.alphabet_size = 4

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return the ``(sequence_indices, class_index)`` pair at ``idx``."""
        return self.seqs[idx], self.clss[idx]
|
|