Commit da21fe21 authored by Simon Perche

Initial commit

LICENSE
MIT License
Copyright (c) 2023, Simon Perche
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Authoring Terrains with Spatialised Style
![](docs/teaser.jpg)
This repository contains the official code and resources for the paper *[Authoring Terrains with Spatialised Style](https://simonperche.github.io/authoring_terrains_with_spatialised_style)*.
# Structure
The repository is divided into three folders:
- data: instructions for downloading the training data and scripts to convert it
- generator: generator code, based on StyleGAN2
- encoders: encoder code, based on Pixel2Style2Pixel
Please train a generator before training encoders. You can also use [our pre-trained StyleGAN2 models](https://drive.google.com/drive/folders/1fuZDELhNvHpZ9k8VaqvUfICX2HyGDf1A?usp=drive_link).
Each folder contains instructions for running the code.
A Blender addon has been developed for this paper. Please find the code and pretrained models [here](https://gitlab.liris.cnrs.fr/sperche/styledem-blender-addon).
# Citation
If you find our work useful, please consider citing:
```
@article{perche2023spatialisedstyle,
author = {Perche, Simon and Peytavie, Adrien and Benes, Bedrich and Galin, Eric and Guérin, Eric},
title = {Authoring Terrains with Spatialised Style},
journal = {Pacific Graphics},
year = {2023},
}
```
## License
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
# Authoring Terrains with Spatialised Style
This folder contains scripts to export the downloaded data as 16-bit PNG files.
Due to licensing, we cannot provide direct download links.
Please note that the data is not needed if you only use the pretrained models.
# Usage
Three models were trained for our paper: 5 meters/pixel, 30 meters/pixel and 180 meters/pixel at 1024x1024 resolution.
## IGN (5 meters per pixel)
You will find the data on the [IGN website](https://geoservices.ign.fr/rgealti#telechargement5m), which covers all regions of France as 1000x1000 tiles at 5 m/pixel.
The files are in .asc format. You can convert them to .png using:
```
python asc2png.py -f=asc_folder/ -o=output_folder --width=1024 --height=1024 --normalize
```
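The script stores elevations multiplied by 6 when `--normalize` is omitted (to keep some sub-metre precision in the 16-bit integers); with `--normalize`, as in the command above, values are stretched to the full [0, 65535] range per tile. A minimal sketch of reading a converted tile back with OpenCV (the filename is a placeholder):
```
import cv2

# Read a 16-bit PNG produced by asc2png.py.
img = cv2.imread('output_folder/tile.png', cv2.IMREAD_UNCHANGED)  # dtype: uint16

# Only valid when --normalize was NOT used: recover approximate elevations in metres.
elevation_m = img.astype('float32') / 6.0
print(img.shape, float(elevation_m.min()), float(elevation_m.max()))
```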
## SRTM
### 30 meters per pixel
You will find the data on the [NASA website](https://search.earthdata.nasa.gov/search?q=SRTM). Select `NASA Shuttle Radar Topography Mission Global 1 arc second V003`. The tiles are 3601x3601 samples (1°x1°) at roughly 30 m/pixel.
The files are in .hgt format. You can convert them to .png using:
```
python hgt2png.py -f=hgt_folder/ -o=output_folder --width=1024 --height=1024 --normalize --crop
```
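For reference, an .hgt tile can also be read directly in Python. This is a minimal sketch of the same logic used in `hgt2png.py` (the filename is a placeholder):
```
import math
import os

import numpy as np

# SRTM .hgt tiles store big-endian 16-bit integers on a square grid,
# so the side length follows from the file size (e.g. 3601 for 1 arc-second tiles).
path = 'N45E006.hgt'  # placeholder filename
size = os.path.getsize(path)
dim = int(math.sqrt(size / 2))
heights = np.fromfile(path, np.dtype('>i2'), dim * dim).reshape((dim, dim))
print(heights.shape, heights.min(), heights.max())
```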
### 180 meters per pixel
Please follow the same steps as for 30 m/pixel but choose `NASA Shuttle Radar Topography Mission Global 3 arc second`. This produces 90 m/pixel images; to get 180 m/pixel, downsample them by a factor of 2, for example as sketched below.
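One possible way to do the downsampling with OpenCV (filenames are placeholders; the resampling filter is a choice, not prescribed by the paper):
```
import cv2

# Halve the resolution of a 90 m/pixel tile to obtain roughly 180 m/pixel.
img = cv2.imread('srtm_90m_tile.png', cv2.IMREAD_UNCHANGED)
h, w = img.shape[:2]
img_180m = cv2.resize(img, (w // 2, h // 2), interpolation=cv2.INTER_AREA)
cv2.imwrite('srtm_180m_tile.png', img_180m)
```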
# Citation
If you find our work useful, please consider citing:
```
@article{perche2023spatialisedstyle,
author = {Perche, Simon and Peytavie, Adrien and Benes, Bedrich and Galin, Eric and Guérin, Eric},
title = {Authoring Terrains with Spatialised Style},
journal = {Pacific Graphics},
year = {2023},
}
```
## License
This project is licensed under the MIT License. See the [LICENSE](../LICENSE) file for details.
import os
import argparse
import shutil

import py7zr
import cv2
import numpy as np


def main():
    parser = argparse.ArgumentParser(
        'Convert asc files to 16-bit png images. To keep float precision in integers, all values are multiplied by 6 before saving as 16-bit png.')
    parser.add_argument('-f', '--folder', type=str, help='Folder of .asc files', default='')
    parser.add_argument('--zip', type=str, help='Zipfile containing .asc files', default='')
    parser.add_argument('-o', '--output', type=str, help='Output folder of .png files', default='out/')
    parser.add_argument('--width', type=int, help='Width of the output image')
    parser.add_argument('--height', type=int, help='Height of the output image')
    parser.add_argument('--normalize', action='store_true', help='If present, normalize images before saving.')
    args = parser.parse_args()

    assert bool(args.folder) ^ bool(args.zip), 'Must set either --folder or --zip (not both)'

    i = 0
    step = 50
    if args.folder:
        for filename in os.listdir(args.folder):
            file = os.path.join(args.folder, filename)
            if os.path.isfile(file):
                asc2png(file, args.output, args.width if args.width else None,
                        args.height if args.height else None, args.normalize)
                i += 1
                if i % step == 0:
                    print(i)
    elif args.zip:
        with py7zr.SevenZipFile(args.zip) as z:
            names = [n for n in z.getnames() if n.endswith('.asc')]
            z.extract(targets=names)

        for n in names:
            asc2png(n, args.output, args.width if args.width else None,
                    args.height if args.height else None, args.normalize)
            i += 1
            if i % step == 0:
                print(i)

        # remove the folder extracted from the archive once all files are converted
        shutil.rmtree(names[0].split('/')[0])


def asc2png(filepath, out_folder, width, height, normalize):
    if out_folder and not os.path.isdir(out_folder):
        os.mkdir(out_folder)

    folder = '/'.join(filepath.split('/')[:-1])
    filename = filepath.split('/')[-1].split('.')[0]

    # .asc files start with a 6-line header (ncols, nrows, corners, cellsize, nodata)
    img = np.loadtxt(os.path.join(folder, filename) + '.asc', skiprows=6)
    img[img < 0] = 0  # clamp nodata / below-sea-level values

    if width and height:
        img = cv2.resize(img, (width, height))

    if normalize:
        # stretch to the full 16-bit range
        img = ((img - np.min(img)) / (np.max(img) - np.min(img))) * 65535
    else:
        # keep some float precision: elevations are stored as metres * 6
        img = img * 6

    img_int = img.astype(np.uint16)
    out = os.path.join(out_folder, f'{filename}.png')
    cv2.imwrite(out, img_int)


if __name__ == '__main__':
    main()
import math
import os
import argparse
import shutil

import cv2
import numpy as np
import py7zr


def main():
    parser = argparse.ArgumentParser('Convert hgt files to 16-bit png images.')
    parser.add_argument('-f', '--folder', type=str, help='Folder of .hgt files', default='')
    parser.add_argument('--zip', type=str, help='Zipfile containing .hgt files', default='')
    parser.add_argument('-o', '--output', type=str, help='Output folder of .png files', default='out/')
    parser.add_argument('--width', type=int, help='Width of the output image')
    parser.add_argument('--height', type=int, help='Height of the output image')
    parser.add_argument('--normalize', action='store_true', help='If present, normalize images before saving.')
    parser.add_argument('--crop', action='store_true', help='If present, crop instead of resizing.')
    args = parser.parse_args()

    assert bool(args.folder) ^ bool(args.zip), 'Must set either --folder or --zip (not both)'

    i = 0
    step = 50
    if args.folder:
        for filename in os.listdir(args.folder):
            file = os.path.join(args.folder, filename)
            if os.path.isfile(file):
                hgt2png(file, args.output, args.width if args.width else None,
                        args.height if args.height else None, args.normalize, args.crop)
                i += 1
                if i % step == 0:
                    print(i)
    elif args.zip:
        with py7zr.SevenZipFile(args.zip) as z:
            names = [n for n in z.getnames() if n.endswith('.hgt')]
            z.extract(targets=names)

        for n in names:
            hgt2png(n, args.output, args.width if args.width else None,
                    args.height if args.height else None, args.normalize, args.crop)
            i += 1
            if i % step == 0:
                print(i)

        # remove the folder extracted from the archive once all files are converted
        shutil.rmtree(names[0].split('/')[0])


def hgt2png(filepath, out_folder, width, height, normalize, crop):
    if out_folder and not os.path.isdir(out_folder):
        os.mkdir(out_folder)

    folder = '/'.join(filepath.split(os.sep)[:-1])
    filename = filepath.split(os.sep)[-1].split('.')[0]
    file = os.path.join(folder, f'{filename}.hgt')

    # .hgt tiles are square grids of big-endian 16-bit integers,
    # so the side length can be recovered from the file size
    siz = os.path.getsize(file)
    dim = int(math.sqrt(siz / 2))
    assert dim * dim * 2 == siz, 'Invalid file size'
    img = np.fromfile(file, np.dtype('>i2'), dim * dim).reshape((dim, dim))

    if normalize:
        # stretch the whole tile to the full 16-bit range
        img = ((img - np.min(img)) / (np.max(img) - np.min(img))) * 65535
        img = img.astype(np.uint16)

    if crop:
        assert bool(width) and bool(height), 'Cropping requires --width and --height'
        n_x = img.shape[0] // height
        n_y = img.shape[1] // width
        for i in range(n_x):
            for j in range(n_y):
                crop_img = img[i * height:i * height + height, j * width:j * width + width]
                if normalize:
                    # re-normalize each crop independently
                    crop_img = ((crop_img - np.min(crop_img)) / (np.max(crop_img) - np.min(crop_img))) * 65535
                out = os.path.join(out_folder, f'{filename}_{i:04d}_{j:04d}.png')
                cv2.imwrite(out, crop_img.astype(np.uint16))
    else:
        if width and height:
            img = cv2.resize(img, (width, height))
        out = os.path.join(out_folder, f'{filename}.png')
        cv2.imwrite(out, img.astype(np.uint16))


if __name__ == '__main__':
    main()
docs/teaser.jpg (teaser image, 80 KiB)
# Auto detect text files and perform LF normalization
* text=auto
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.idea/*
train_datasets/
pretrained_models/*.p*
experiment/*
MIT License
Copyright (c) 2020 Elad Richardson, Yuval Alaluf
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Authoring Terrains with Spatialised Style
This is the official repository for the *Authoring Terrains with Spatialised Style* paper.
This code is based on *[Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://github.com/eladrich/pixel2style2pixel)*, adapted to the official StyleGAN2 implementation and extended with our contributions.
# Installation
Set up a new environment with the `environment/terrains_style.yaml` requirements.
# Usage
## Train
The training must be done in two phases. First, train the GradualStyleEncoder:
```
python -m scripts.train --dataset_type=dataset_type_defined_in_configs_folder --exp_dir=experiment/your_folder --workers=1 --batch_size=4 --test_batch_size=4 --val_interval=1000 --save_interval=1000 --encoder_type=GradualStyleEncoder --start_from_latent_avg --lpips_lambda=0.8 --l2_lambda=1 --id_lambda=0 --stylegan_weights=pretrained_models/stylegan2_network.pkl --output_size=1024 --label_nc=1 --input_nc=1 --overwrite_exp
```
Then, train our encoder:
```
python -m scripts.train --dataset_type=dataset_type_defined_in_configs_folder --exp_dir=experiment/your_folder --workers=1 --batch_size=4 --test_batch_size=4 --val_interval=1000 --image_interval=100 --save_interval=1000 --encoder_type=CNNSkipInAllF --start_from_latent_avg --lpips_lambda=0.8 --l2_lambda=1 --id_lambda=0 --output_size=1024 --resize_outputs --label_nc=1 --input_nc=1 --checkpoint_path=previously_trained_psp_model.pt
```
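The `--dataset_type` value must match a key of the `DATASETS` dictionary in `configs/data_configs.py`, whose paths are declared in `configs/paths_config.py`. A minimal sketch of registering a hypothetical `ign_5m` dataset (the key name and paths are placeholders):
```
# configs/paths_config.py
dataset_paths = {
    'ign_train': '/your/ign/dataset/train',
    'ign_test': '/your/ign/dataset/test',
}

# configs/data_configs.py
from configs import transforms_config
from configs.paths_config import dataset_paths

DATASETS = {
    'ign_5m': {
        'transforms': transforms_config.DEMTransforms,
        'train_source_root': dataset_paths['ign_train'],
        'train_target_root': dataset_paths['ign_train'],
        'test_source_root': dataset_paths['ign_test'],
        'test_target_root': dataset_paths['ign_test'],
    },
}
```
You would then pass `--dataset_type=ign_5m` to the training scripts.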
## Inference
```
python -m scripts.inference --dataset_type=dataset_type_defined_in_configs_folder --checkpoint_path=your_model/model.pt --test_batch_size=1 --exp_dir=out_folder/ --data_path=your_data_folder/ --couple_outputs
```
## Super-resolution
```
python -m scripts.super_resolution_blend --exp_dir=/output/path --checkpoint_path=/path/checkpoint --low_res_image=/low/res/image/path --sr_grid=2 --make_collage [--keep_tmp_folders]
```
## Blender addon
A Blender addon is available [here](https://gitlab.liris.cnrs.fr/sperche/styledem-blender-addon) for easy use.
# Citation
If you find our work useful, please consider citing:
```
@article{perche2023spatialisedstyle,
author = {Perche, Simon and Peytavie, Adrien and Benes, Bedrich and Galin, Eric and Guérin, Eric},
title = {Authoring Terrains with Spatialised Style},
journal = {Pacific Graphics},
year = {2023},
}
```
## License
This project is licensed under the MIT License. See the [LICENSE](../LICENSE) file for details.
from configs import transforms_config
from configs.paths_config import dataset_paths

DATASETS = {
    'srtm': {
        'transforms': transforms_config.DEMTransforms,
        'train_source_root': dataset_paths['srtm_train'],
        'train_target_root': dataset_paths['srtm_train'],
        'test_source_root': dataset_paths['srtm_test'],
        'test_target_root': dataset_paths['srtm_test'],
    },
}
dataset_paths = {
    'srtm_train': '/your/dataset/path/train',
    'srtm_test': '/your/dataset/path/test',
}

model_paths = {
    'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
    'ir_se50': 'pretrained_models/model_ir_se50.pth',
    'circular_face': 'pretrained_models/CurricularFace_Backbone.pth',
    'mtcnn_pnet': 'pretrained_models/mtcnn/pnet.npy',
    'mtcnn_rnet': 'pretrained_models/mtcnn/rnet.npy',
    'mtcnn_onet': 'pretrained_models/mtcnn/onet.npy',
    'shape_predictor': 'shape_predictor_68_face_landmarks.dat',
    'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth.tar',
    'stylegan_dem_5m_128': 'pretrained_models/sg2_dem_128.pkl',
    'stylegan_dem_5m_1024': 'pretrained_models/sg2_dem_1024_batch32.pt',
    'stylegan_dem_30m_256': 'pretrained_models/sg2_dem_30m_246.pt',
}
from abc import abstractmethod

import torchvision.transforms as transforms

from dataset import augmentations


class TransformsConfig(object):

    def __init__(self, opts):
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        pass


class EncodeTransforms(TransformsConfig):

    def __init__(self, opts):
        super(EncodeTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': None,
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        }
        return transforms_dict


class FrontalizationTransforms(TransformsConfig):

    def __init__(self, opts):
        super(FrontalizationTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        }
        return transforms_dict


class SketchToImageTransforms(TransformsConfig):

    def __init__(self, opts):
        super(SketchToImageTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()]),
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()]),
        }
        return transforms_dict


class SegToImageTransforms(TransformsConfig):

    def __init__(self, opts):
        super(SegToImageTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': transforms.Compose([
                transforms.Resize((256, 256)),
                augmentations.ToOneHot(self.opts.label_nc),
                transforms.ToTensor()]),
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                augmentations.ToOneHot(self.opts.label_nc),
                transforms.ToTensor()])
        }
        return transforms_dict


class SuperResTransforms(TransformsConfig):

    def __init__(self, opts):
        super(SuperResTransforms, self).__init__(opts)

    def get_transforms(self):
        if self.opts.resize_factors is None:
            self.opts.resize_factors = '1,2,4,8,16,32'
        factors = [int(f) for f in self.opts.resize_factors.split(",")]
        print("Performing down-sampling with factors: {}".format(factors))
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_source': transforms.Compose([
                transforms.Resize((256, 256)),
                augmentations.BilinearResize(factors=factors),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                augmentations.BilinearResize(factors=factors),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        }
        return transforms_dict


class DEMTransforms(TransformsConfig):

    def __init__(self, opts):
        super(DEMTransforms, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()]),
            'transform_source': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()]),
            'transform_test': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()]),
            'transform_inference': transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor()])
        }
        return transforms_dict


class DEMTransformsAdaptative(TransformsConfig):

    def __init__(self, opts):
        super(DEMTransformsAdaptative, self).__init__(opts)
        self.opts = opts

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.Resize((self.opts.loss_resolution, self.opts.loss_resolution)),
                transforms.ToTensor()]),
            'transform_source': transforms.Compose([
                transforms.Resize((self.opts.input_size, self.opts.input_size)),
                transforms.ToTensor()]),
            'transform_test': transforms.Compose([
                transforms.Resize((self.opts.loss_resolution, self.opts.loss_resolution)),
                transforms.ToTensor()]),
            'transform_inference': transforms.Compose([
                transforms.Resize((self.opts.input_size, self.opts.input_size)),
                transforms.ToTensor()])
        }
        return transforms_dict


class DEMTransformsNoResize(TransformsConfig):

    def __init__(self, opts):
        super(DEMTransformsNoResize, self).__init__(opts)

    def get_transforms(self):
        transforms_dict = {
            'transform_gt_train': transforms.Compose([
                transforms.ToTensor()]),
            'transform_source': transforms.Compose([
                transforms.ToTensor()]),
            'transform_test': transforms.Compose([
                transforms.ToTensor()]),
            'transform_inference': transforms.Compose([
                transforms.ToTensor()])
        }
        return transforms_dict
import torch
import torch.nn as nn

from criteria.lpips.networks import get_network, LinLayers
from criteria.lpips.utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """

    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
        assert version in ['0.1'], 'v0.1 is only supported now'
        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type).to("cuda")

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]
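A minimal usage sketch of the criterion above (illustrative, not part of the original file): as written, the module moves its sub-networks to CUDA at construction and expects 3-channel inputs normalised to [-1, 1], so a GPU is assumed here.
```
import torch
from criteria.lpips.lpips import LPIPS

lpips_loss = LPIPS(net_type='alex')                     # downloads AlexNet and the LPIPS linear weights
x = torch.rand(1, 3, 256, 256, device='cuda') * 2 - 1   # dummy image batch in [-1, 1]
y = torch.rand(1, 3, 256, 256, device='cuda') * 2 - 1
distance = lpips_loss(x, y)                              # scalar tensor; lower means more similar
```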
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from criteria.lpips.utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict