Rohanify commited on
Commit
fbff8ee
·
verified ·
1 Parent(s): 100f59f

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -62
model.py DELETED
@@ -1,62 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import PretrainedConfig, PreTrainedModel
4
-
5
- # 1. Custom configuration class required by Hugging Face auto_map
6
- class TransformConfig(PretrainedConfig):
7
- model_type = "brawnz_style_transfer"
8
- def __init__(self, **kwargs):
9
- super().__init__(**kwargs)
10
-
11
- # 2. Network helper blocks matching your training script
12
- def conv_bn_relu(in_c, out_c, k, stride=1, pad=0):
13
- return nn.Sequential(
14
- nn.ReflectionPad2d(pad),
15
- nn.Conv2d(in_c, out_c, k, stride),
16
- nn.InstanceNorm2d(out_c),
17
- nn.ReLU(inplace=True),
18
- )
19
-
20
- class ResBlock(nn.Module):
21
- def __init__(self, c):
22
- super().__init__()
23
- self.block = nn.Sequential(
24
- nn.ReflectionPad2d(1),
25
- nn.Conv2d(c, c, 3),
26
- nn.InstanceNorm2d(c),
27
- nn.ReLU(inplace=True),
28
- nn.ReflectionPad2d(1),
29
- nn.Conv2d(c, c, 3),
30
- nn.InstanceNorm2d(c),
31
- )
32
-
33
- def forward(self, x):
34
- return x + self.block(x)
35
-
36
- # 3. Main Network Class hooked into Hugging Face structures
37
- class TransformNet(PreTrainedModel):
38
- config_class = TransformConfig
39
-
40
- def __init__(self, config=None):
41
- # Fallback to default configuration if none is passed
42
- if config is None:
43
- config = TransformConfig()
44
- super().__init__(config)
45
-
46
- self.net = nn.Sequential(
47
- conv_bn_relu(3, 32, 9, pad=4),
48
- conv_bn_relu(32, 64, 3, stride=2, pad=1),
49
- conv_bn_relu(64, 128, 3, stride=2, pad=1),
50
- ResBlock(128), ResBlock(128), ResBlock(128),
51
- ResBlock(128), ResBlock(128),
52
- nn.Upsample(scale_factor=2, mode="nearest"),
53
- conv_bn_relu(128, 64, 3, pad=1),
54
- nn.Upsample(scale_factor=2, mode="nearest"),
55
- conv_bn_relu(64, 32, 3, pad=1),
56
- nn.ReflectionPad2d(4),
57
- nn.Conv2d(32, 3, 9),
58
- nn.Tanh(),
59
- )
60
-
61
- def forward(self, x):
62
- return self.net(x)