train.py
1
2
3
4
5
6
7
8
9
10加载vae
构建主模型
定义权重的数据类型
加载lpips模型
模型梯度准备
加载lora
学习率调度器,优化器
准备可训练参数
加载预训练模型
划分数据集
训练循环取batch的时候进入provider.py的get_item函数
provider.py
1
2
3get_item到 get_data_static 到 _read_latent_data_static 到 self.dataset.get_data读取对应的数据,可以读取整个视频latent/对应几帧的rgb/外参/内参,会读取六个轨迹的latent然后拼接到一起
读取的target有两帧?似乎重复了,在最后preprocess的时候取得最后一帧静态训练时模型的输入
- latent:一个轨迹的latent是 16, 16, 88, 160, 六条轨迹拼接成 96,16,88,160
- 外参:726,4,4
- 内参:726,4
监督部分
- RGB图:54,3,704,1280
- 外参:54,4,4
- 内参:54,4
组合数据:
- RGB图:54,3,704,1280
- latent:96,16,88,160
- 外参:726+54,4,4
- 内参:726+54,4
batch内数据:
- images_output : 1,54,3,704,1280
- intrinsics : 1,54,4
- cam_view : 1,54,4,4
- time_embeddings:1,121,3
- time_embeddings_target:1,1,3
- num_input_multi_views:6
- intrinsics_input:1,726,4
- c2ws_input:1,726,4,4
- flip_flag:1,54,全零
- target_index:1,614
- rgb_latents:1,96,16,88,160
- depth_output:1,54,704,1280
- images_input_embed:1,96,16,88,160
get_plucker_embedding_and_rays 输出数据:
- plucker_embedding:1,726,6,704,1280
- rays_os:1,726,3,88,160
- rays_ds:1,726,3,88,160
encode_plucker_vae 输出数据:
- plucker_embedding:1,96,32,704,1280
- rays_os:1,726,3,88,160
- rays_ds:1,726,3,88,160
网络中:
- images_input:6, 16, 16, 88, 160
- plucker_embedding:6, 16, 32, 88, 160
- rays_os:6, 121, 3, 88, 160
- rays_ds:6, 121, 3, 88, 160
- x:6,56320,512
前向传播需要的变量:
- data[‘rays_os’]
- data[‘rays_ds’]
- data[‘plucker_embedding’]
- data[‘images_input_embed’]
- data[‘time_embeddings’]
- data[‘cam_view’]
- data[‘intrinsics’]
- data[‘num_input_multi_views’]
显存分析
计算plucker之前:3710 / 3830
计算plucker之后:34926 / 36284
进入plucker vae之前:34926 / 36284
进入plucker vae之后:29794