Latent Diffusion部署&生成Unconditional Model
SSD罐头
2024年02月01日 23:29

最近有生成无条件模型的需求,最开始使用了DDPM的算法,但是发现256x256分辨率的图像生成显存占用非常离谱,400timesteps和256x256分辨率的情况下(不知道还有没有其他参数会不会影响显存占用),需要46G显存,而且生成速度非常慢。800张的数据集训练至Loss“稳定”(但是timesteps太小了生成质量不稳定),使用A40的情况下需要60小时。于是打算切换到Stable Diffusion使用的算法——Latent Diffusion上面。

但Latent Diffusion的源码是2021年的,有些依赖更新后,按库中的environment.yaml文件配置conda环境后无法正常使用,但幸好issue有人已经解决了,所以在此梳理一下流程即可。


训练环境配置   吐槽:B站啥时候把插入代码块的功能砍了???

克隆latent-diffusion的库,进入latent-diffusion的文件夹  cd latent-diffusion    如果机器在国内,修改environment.yaml文件  在https://github.com前面添加https://mirror.ghproxy.com/  即如下所示:  https://mirror.ghproxy.com/https://github.com/  使国内可以正常下载github内容(感谢提供代理的同学,如果代理不能用的话想法版换一个即可)   给pip配置清华源  pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple    执行安装ldm环境的命令:  conda env create -f environment.yaml    如果需要重新装,先卸载:  conda env remove --name ldm    退出ldm环境  conda deactivate    按照官方的版本要求安装之后,需要执行下列命令(或者直接写到environment.yaml文件中也可以):  pip install packaging==21.3  pip install 'torchmetrics<0.8'  以解决运行训练命令报错,具体错误为:  packaging.version.InvalidVersion: Invalid version: '0.10.1,<0.11'  参考:[packaging.version.InvalidVersion: Invalid version: '0.10.1,<0.11' · Issue #207 · CompVis/latent-diffusion · GitHub]    解决该错误后执行还会报错:  File "/latent-diffusion/ldm/models/diffusion/ddpm.py", line 1030, in p_losses  logvar_t = self.logvar[t].to(self.device)  RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)    需要修改ddpm.py文件,在上面提到ldm/models/diffusion/ddpm.py  的1030行前面添加一行:  self.logvar = self.logvar.to(self.device)  来手动指定设备。  参考:[RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) · Issue #228 · CompVis/latent-diffusion · GitHub] 

准备数据集 

Latent diffusion的训练需要训练集和验证集两部分,我们使用随机算法来分割两部分的图片,生成图片列表文本文档。  一般来说小数据集中,为了保证模型生成的准确性,可以给大一些的验证集数量。例如5000张左右图片,可以给30%的图片作为验证集。  在本文的样例数据中,约有3000张图片。本文取15%的图片作为验证集。  分类代码如下: 

(暂略)

生成后,数据集的文件夹中会多出来两个txt文件,即为训练集和验证集的图片。把数据集放在data文件夹中即可。

Autoencoder训练 

修改first_stage配置文件   Ldm模型训练需要先训练适配数据集的autoencoder,否则直接跑ldm是没办法用的。 

在models/first_stage_models/kl-f8/config.yaml文件中,修改data.params.batch_size至适合你显卡的显存,按使用过的配置来说,256大小的输入图片,batch_size=1大约占用5G显存,以此类推。 生成的ckpt文件大小约1.1G。 

data.param.train.params.size.crop_size我在设置中注释掉了,因为不知道其意义。validation中的也注释掉了。 其他有关训练的参数都没有改动。 

把config.yaml文件中的data.params.train.target以及data.params.validation.target修改为你自己的数据集的目标,例如: data:  target: main.DataModuleFromConfig  params:   batch_size: 2   wrap: true   train:    target: ldm.data.lsun.LSUNSoilTrain    params:     size: 256      #crop_size: 256   validation:    target: ldm.data.lsun.LSUNSoilValidation    params:     size: 256      #crop_size: 256

修改lsun.py文件 按照上述设置的target,把“ldm/data/lsun.py”中的class LSUNChurchesTrain及Validation复制一份,修改class名为上述文件的名称,修改其中的txt_file及data_root至你的数据集及txt文件所在的位置即可。

运行autoencoder训练

命令行如下,如果多卡增加gpu数量即可:  python main.py --base models/first_stage_models/kl-f8/config.yaml -t --gpus 0,    可以使用python main.py -h查看支持的扩展指令。 

生成的checkpoints及其他的内容在“logs/$dateT$time_config/”文件夹中,$date和$time为你运行命令开始时的时间自动生成的文件夹。模型在checkpoints文件夹中,images为训练过程生成的图片configs为训练的配置文件,testtube中的metrics.csv中有训练过程的记录,可以用来生成可视化文件。

Ldm训练 

库中在models/ldm/lsun_churches256/config.yaml提供了样例文件。我们可以借助现成的设置文件来修改:  修改其中的data.params.train.target以及data.params.validation.target为我们想要的target。如修改为ldm.data.lsun.LSUNSoilTrain,ldm.data.lsun.LSUNSoilValidation lsun.py文件在上面训练autoencoder的时候已经 修改完毕,此处直接使用即可。也即autoencoder和ldm训练使用了相同的数据集。

在config.yaml中,在model.params.first_stage_config.params中添加一行:  ckpt_path: "models/ldm/lsun_churches256/autoencoder_soil.ckpt"  即为我们上个步骤训练的autoencoder模型,将训练出来的最佳模型复制到上述的位置,改名即可。   训练的命令行如下:  python main.py --base models/ldm/lsun_churches256/config.yaml -t --gpus 0, 

使用Unconditonal model生成图片 

生成过程如果直接使用pip默认安装的版本也会出错,报错如下:  (忘记记录了,反正也是一个依赖问题,但换环境尝试了没有复现)   安装好依赖版本后,把训练好的模型放在和ldm的配置文件相同文件夹中,改好名字,执行如下命令生成:  python scripts/sample_diffusion.py -r models/ldm/lsun_churches256/soil.ckpt -n 20 --batch_size 2 -c 30    即以2的batch_size,30的timesteps,生成20张训练好的模型的图片。 具体的扩展命令,可以直接查看scripts/sample_diffusion.py文件指定。

扩展:训练conditonal model

(待续)