# Mini-ImageNet / src / transforms / image_transform.py
# ImAMJayKIM's picture
# Upload 96 files
# c1596ac verified
from torchvision import transforms
### Captioning Transform ###
def get_caption_transform():
    """Build the image preprocessing pipeline used for captioning.

    The pipeline applies, in order: a resize to 224x224, conversion to a
    tensor, and per-channel normalization with a fixed mean and std.

    Returns:
        A ``transforms.Compose`` object chaining the three steps.
    """
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    # Fixed per-channel statistics; one entry per colour channel.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [resize, to_tensor, normalize]
    return transforms.Compose(steps)
### Classification Train Transform ###
def get_classification_train_transform():
    """Build the image preprocessing pipeline for classification training.

    The pipeline applies, in order: a resize to 224x224, conversion to a
    tensor, and per-channel normalization with a fixed mean and std. No
    random augmentation step is included in this pipeline.

    Returns:
        A ``transforms.Compose`` object chaining the three steps.
    """
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    # Fixed per-channel statistics; one entry per colour channel.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [resize, to_tensor, normalize]
    return transforms.Compose(steps)
### Classification Augmentation Transform ###
def get_classification_aug_transform():
    """Build the augmented image preprocessing pipeline for classification.

    The pipeline applies, in order: a resize to 224x224, a random
    horizontal flip (with torchvision's default settings), conversion to a
    tensor, and per-channel normalization with a fixed mean and std.

    Returns:
        A ``transforms.Compose`` object chaining the four steps.
    """
    resize = transforms.Resize((224, 224))
    # The flip is applied before tensor conversion, as in the step order below.
    random_flip = transforms.RandomHorizontalFlip()
    to_tensor = transforms.ToTensor()
    # Fixed per-channel statistics; one entry per colour channel.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [resize, random_flip, to_tensor, normalize]
    return transforms.Compose(steps)
### Classification Validation Transform ###
def get_classification_valid_transform():
    """Build the image preprocessing pipeline for classification validation.

    The pipeline applies, in order: a resize to 224x224, conversion to a
    tensor, and per-channel normalization with a fixed mean and std. Every
    step is deterministic; no random augmentation is included.

    Returns:
        A ``transforms.Compose`` object chaining the three steps.
    """
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    # Fixed per-channel statistics; one entry per colour channel.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [resize, to_tensor, normalize]
    return transforms.Compose(steps)