---
license: mit
pipeline_tag: image-to-image
tags:
- style-transfer
- pytorch
---
# Fast Neural Style Transfer — Starry Night
This repository contains weights for a Fast Neural Style Transfer network based on Johnson et al. (2016), *Perceptual Losses for Real-Time Style Transfer and Super-Resolution*. It was trained on the COCO val2017 dataset and applies Vincent van Gogh's *The Starry Night* style to any input image in a single forward pass.
## Style Transfer Preview
| Content Image | Stylized Output |
| :---: | :---: |
| <img src="before.jpg" width="400"> | <img src="after.jpg" width="400"> |
---
## How to Use Programmatically
The script below runs inference with PyTorch and uses the official `huggingface_hub` library to fetch the weights. It downloads the weights file from the Hugging Face Hub and applies the same ImageNet normalization that was used during training.
### Dependencies
Ensure you have the required packages installed:
```bash
pip install torch torchvision pillow huggingface_hub
```
### Inference Script (`inference.py`)
Save the following code as `inference.py` and run it from a terminal with `python inference.py your_image.jpg`. The stylized result is written to `output_styled.jpg`.
```python
import sys
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from huggingface_hub import hf_hub_download
# ── CONFIG ───────────────────────────────────────────────────
REPO_ID = "Rohanify/Brawnz-StyleTransferSN"  # Hugging Face Hub repository
FILENAME = "pytorch_model.bin"  # weights file inside that repository
IMG_SIZE = 512  # inputs are resized to IMG_SIZE x IMG_SIZE before inference
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ─────────────────────────────────────────────────────────────
# ── NATIVE PYTORCH NETWORK DEFINITION ────────────────────────
def conv_bn_relu(in_c, out_c, k, stride=1, pad=0):
    """Build one reflection-padded conv -> InstanceNorm -> ReLU stage.

    Despite the ``bn`` in the name, the normalization layer is
    ``InstanceNorm2d``. The element order fixes the state-dict keys (the
    conv sits at index 1), so it must match the checkpoint loaded by
    ``load_state_dict``.
    """
    layers = [
        nn.ReflectionPad2d(pad),             # pad with mirrored pixels, not zeros
        nn.Conv2d(in_c, out_c, k, stride),   # no implicit padding: handled above
        nn.InstanceNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class ResBlock(nn.Module):
    """Residual block of two reflection-padded 3x3 convs with instance norm.

    The branch output is added back to the input, so the channel count ``c``
    and the spatial size are preserved. The submodule order inside
    ``self.block`` determines the checkpoint keys and must not change.
    """

    def __init__(self, c):
        super().__init__()

        def _padded_conv(channels):
            # Reflection padding of 1 keeps a 3x3 conv size-preserving.
            return [nn.ReflectionPad2d(1), nn.Conv2d(channels, channels, 3)]

        layers = [
            *_padded_conv(c),
            nn.InstanceNorm2d(c),
            nn.ReLU(inplace=True),
            *_padded_conv(c),
            nn.InstanceNorm2d(c),  # no ReLU after the second conv
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual
class TransformNet(nn.Module):
    """Johnson-style feed-forward image transformation network.

    Encoder: a 9x9 stride-1 stage followed by two 3x3 stride-2 stages
    (3 -> 32 -> 64 -> 128 channels). Five residual blocks at 128 channels.
    Decoder: two rounds of nearest-neighbour 2x upsampling plus a 3x3 stage
    (128 -> 64 -> 32), then a reflection-padded 9x9 conv back to 3 channels
    passed through Tanh.

    Everything lives in one flat ``self.net`` Sequential; its element order
    defines the state-dict keys of the pretrained checkpoint.
    """

    def __init__(self):
        super().__init__()
        encoder = [
            conv_bn_relu(3, 32, 9, pad=4),
            conv_bn_relu(32, 64, 3, stride=2, pad=1),
            conv_bn_relu(64, 128, 3, stride=2, pad=1),
        ]
        residuals = [ResBlock(128) for _ in range(5)]
        decoder = []
        for c_in, c_out in ((128, 64), (64, 32)):
            decoder.append(nn.Upsample(scale_factor=2, mode="nearest"))
            decoder.append(conv_bn_relu(c_in, c_out, 3, pad=1))
        head = [nn.ReflectionPad2d(4), nn.Conv2d(32, 3, 9), nn.Tanh()]
        self.net = nn.Sequential(*encoder, *residuals, *decoder, *head)

    def forward(self, x):
        return self.net(x)
# ── LOAD INPUT IMAGE ─────────────────────────────────────────
def main():
    """Stylize the image named on the command line and save the result.

    Usage: ``python inference.py INPUT [OUTPUT]``. OUTPUT defaults to
    ``output_styled.jpg``. Exits with status 1 and a usage message on
    stderr when INPUT is missing.
    """
    if len(sys.argv) < 2:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Usage: python inference.py path_to_input_image.jpg [output_path.jpg]")
    input_path = sys.argv[1]
    output_path = sys.argv[2] if len(sys.argv) > 2 else "output_styled.jpg"

    # Square resize plus ImageNet mean/std normalization, matching training.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # The context manager closes the underlying file once pixels are decoded.
    with Image.open(input_path) as img:
        x = transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)

    # ── SECURE FILE DOWNLOAD & STATE LOAD ────────────────────
    print("Downloading weights from Hugging Face Hub...")
    weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
    model = TransformNet().to(DEVICE)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code when loaded.
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Weights successfully loaded on: {DEVICE}")

    # ── RUN INFERENCE ────────────────────────────────────────
    print("Processing style transfer...")
    with torch.no_grad():
        out = model(x)
    # The network ends in Tanh (range [-1, 1]); remap to [0, 1] for saving.
    save_image(out[0] * 0.5 + 0.5, output_path)
    print(f"Success! Styled image saved to: {output_path}")


if __name__ == "__main__":
    main()
```
|