ArtFusion: Controllable Arbitrary Style Transfer using Dual Conditional Latent Diffusion Models 代码理解
1. 环境配置
项目地址:https://github.com/ChenDarYen/ArtFusion
git clone https://github.com/ChenDarYen/ArtFusion.git
conda env create -f environment.yaml
conda activate artfusion
下载模型
- vae: https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
  放到 ./checkpoints/vae/kl-f16.ckpt
- artfusion: https://1drv.ms/u/s!AuZJlZC8oVPfgWC2O77TUlhIfELG?e=RoSa8a
  放到 ./checkpoints/artfusion/
- 注意: artfusion下载过程容易中断,导致下载下来的模型大小不是3G, 注意检查
运行代码
运行
notebooks/style_transfer.ipynb
- 如果出现
numpy.core.multiarray failed to import
, 可能是下载的numpy版本不对(不知道为啥会下错), 重新安装:
pip uninstall numpy
conda install numpy=1.23.4 --override-channels -c defaults -c pytorch
pip install numpy==1.23.4
- 如果出现
2. 代码结构
notebooks/style_transfer.ipynb : 推断
main.py : 训练
3. 代码理解
3.1 推断部分(notebooks/style_transfer.ipynb)
1 | # 1. 参数设置 |
跳转: instantiate_from_config: ▼
1 | def instantiate_from_config(config): #传入的是config.model |
1 | 定义了三个图像处理函数, 后面再看. |
1 | style_image_paths = [ #风格图像路径 |
跳转: preprocess_image: ▼
1 | #3. 图片处理 |
1 | # 显示多张图片 |
跳转: display_samples: ▼
1 | # |
1 | # 4. 查看模型学习到的风格 |
跳转: DualCondLDM.vgg_scaling_layer: ▼
1 | class DualCondLDM(LatentDiffusion): #继承自LatentDiffusion |
跳转DualCondLDM父类: ldm.models.diffusion.ddpm.LatentDiffusion: ▼
1 | class LatentDiffusion(pl.LightningModule): |
跳转: make_beta_schedule: ▼
1 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): |
跳转:model_class=ldm.models.diffusion.dual_cond_ddpm.DualConditionDiffusionWrapper ▼
1 | class DualConditionDiffusionWrapper(pl.LightningModule): |
跳转: self.diffusion_model = instantiate_from_config -> ldm.modules.diffusionmodules.model.StyleUNetModel: ▼
1 | # 在unet基础上扩展一些新功能, 特别是与内容相关的处理 |
跳转: ldm.modules.ema.LitEma: ▼
1 | import torch |
跳转: ldm.modules.losses.lpips.vgg16: ▼
1 | import torch |
跳转: ldm.modules.losses.lpips.ScalingLayer: ▼
1 | #对输入进行标准化处理 |
跳转: samples = model.sample_log -> ldm.models.diffusion.ddpm.LatentDiffusion.sample_log: ▼
1 | class LatentDiffusion(pl.LightningModule): |
跳转: ddim_sampler = DDIMSampler -> ldm.models.diffusion.ddim.DDIMSampler: ▼
1 | class DDIMSampler(object): |
跳转: samples, intermediates = ddim_sampler.sample -> ldm.models.diffusion.ddim.DDIMSampler.sample: ▼
1 | class DDIMSampler(object): |
跳转: make_ddim_timesteps -> ldm.modules.diffusionmodules.util.make_ddim_timesteps: ▼
1 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): |
跳转: make_ddim_sampling_parameters -> ldm.modules.diffusionmodules.util: ▼
1 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): |
跳转: e_t = self.model.apply_model -> ldm.models.diffusion.dual_cond_ddpm.DualCondLDM: ▼
1 | class DualCondLDM(LatentDiffusion): |
跳转: x_recon = self.model-> ldm.models.diffusion.dual_cond_ddpm.DualConditionDiffusionWrapper ▼
1 | # 主要用于封装和处理双条件扩散模型 |
跳转: return self.diffusion_model -> ldm.modules.diffusionmodules.model.StyleUNetModel ▼
1 | class StyleUNetModel(UNetModel): |
跳转: emb = self.time_embed -> TimestepEmbedder.forward ▼
1 | class TimestepEmbedder(nn.Module): |
跳转: x_samples = model.decode_first_stage -> ldm.models.diffusion.dual_cond_ddpm.DualCondLDM ▼
1 | def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): |
跳转: model.ema_scope("Plotting") -> ldm.models.diffusion.dual_cond_ddpm.DualCondLDM ▼
1 | class DualCondLDM(LatentDiffusion): |
跳转: self.model_ema.store -> ldm.modules.ema.LitEma ▼
1 | class LitEma(nn.Module): |
1 | def get_content_style_features(content_image_path, style_image_path, h=H, w=W): |
跳转: encode_first_stage -> ldm.models.diffusion.dual_cond_ddpm.DualCondLDM ▼
1 | class DualCondLDM(LatentDiffusion): |
跳转: first_stage_model.encode -> ldm.models.autoencoder.AutoencoderKL ▼
1 | class AutoencoderKL(pl.LightningModule): |
跳转: DiagonalGaussianDistribution -> ldm.modules.distributions.distributions ▼
1 | class DiagonalGaussianDistribution(object): |
跳转: model.get_first_stage_encoding -> ldm.models.diffusion.dual_cond_ddpm.DualCondLDM ▼
1 | class DualCondLDM(LatentDiffusion): |
跳转: model.get_content_features -> ldm.models.diffusion.dual_cond_ddpm.DualCondLDM ▼
1 | class DualCondLDM(LatentDiffusion): |
- 后面的block与前面的block类似,不再赘述
3.2 训练部分(train.py)
训练之前需要下载数据集(根据readme.md中的说明, wiki-art需要将images文件夹里的子文件夹放出来):
datasets/
|-- ms-coco/
| |-- train2017/
| |-- val2017/
| |-- test2017/
|-- wiki-art/
|-- 类别1/
|-- 类别2/
|-- ...
|-- WikiArt.csv
运行训练代码
1 | python main.py --name experiment_name --base ./configs/kl16_content12.yaml --basedir ./checkpoints -t True --gpus 0, |
跳转: data = instantiate_from_config(config.data) -> DataModuleFromConfig ▼