ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0]
时间: 2024-02-13 16:03:37 浏览: 109
这是一行Python代码,其作用是读取指定文件夹中特定名称的PFM格式的深度估计图像文件,并将其存储在变量ref_depth_est中。具体解释如下:
- os.path.join()函数用于连接文件夹路径和文件名,生成完整的文件路径。此处将文件夹路径和'depth_est/{:0>8}.pfm'.format(ref_view)拼接起来,其中'depth_est/{:0>8}.pfm'是文件名的格式,{:0>8}表示输出的字符串长度为8,不够的地方用0补齐,ref_view是一个变量,代表特定的文件名。
- read_pfm()函数是一个自定义函数,用于读取PFM格式的图像文件,并返回一个包含图像数据的numpy数组和图像的缩放因子。此处使用该函数读取指定的文件,并且取返回值的第一个元素,即图像数据,赋值给变量ref_depth_est。
相关问题
IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] def get_args_parser(): parser = argparse.ArgumentParser() # dataset parser.add_argument('--checkpoint_dir', default='tmp', type=str, help='where to save the training log and models') parser.add_argument('--stage', default='sceneflow', type=str, help='training stage on different datasets') parser.add_argument('--val_dataset', default=['kitti15'], type=str, nargs='+') parser.add_argument('--max_disp', default=400, type=int, help='exclude very large disparity in the loss function') parser.add_argument('--img_height', default=288, type=int) parser.add_argument('--img_width', default=512, type=int) parser.add_argument('--padding_factor', default=16, type=int) # training parser.add_argument('--batch_size', default=64, type=int) parser.add_argument('--num_workers', default=8, type=int) parser.add_argument('--lr', default=1e-3, type=float) parser.add_argument('--weight_decay', default=1e-4, type=float) parser.add_argument('--seed', default=326, type=int) # resume pretrained model or resume training parser.add_argument('--resume', default=None, type=str, help='resume from pretrained model or resume from unexpectedly terminated training') parser.add_argument('--strict_resume', action='store_true', help='strict resume while loading pretrained weights') parser.add_argument('--no_resume_optimizer', action='store_true') parser.add_argument('--resume_exclude_upsampler', action='store_true') # model: learnable parameters parser.add_argument('--task', default='stereo', choices=['flow', 'stereo', 'depth'], type=str) parser.add_argument('--num_scales', default=1, type=int, help='feature scales: 1/8 or 1/8 + 1/4') parser.add_argument('--feature_channels', default=128, type=int) parser.add_argument('--upsample_factor', default=8, type=int) parser.add_argument('--num_head', default=1, type=int) parser.add_argument('--ffn_dim_expansion', default=4, type=int) parser.add_argument('--num_transformer_layers', default=6, type=int) parser.add_argument('--reg_refine', action='store_true', help='optional task-specific local regression refinement') # model: parameter-free parser.add_argument('--attn_type', default='self_swin2d_cross_1d', type=str, help='attention function') parser.add_argument('--attn_splits_list', default=[2], type=int, nargs='+', help='number of splits in attention') parser.add_argument('--corr_radius_list', default=[-1], type=int, nargs='+', help='correlation radius for matching, -1 indicates global matching') parser.add_argument('--prop_radius_list', default=[-1], type=int, nargs='+', help='self-attention radius for propagation, -1 indicates global attention') parser.add_argument('--num_reg_refine', default=1, type=int, help='number of additional local regression refinement') # evaluation parser.add_argument('--eval', action='store_true') parser.add_argument('--inference_size', default=None, type=int, nargs='+') parser.add_argument('--count_time', action='store_true') parser.add_argument('--save_vis_disp', action='store_true') parser.add_argument('--save_dir', default=None, type=str) parser.add_argument('--middlebury_resolution', default='F', choices=['Q', 'H', 'F']) # submission parser.add_argument('--submission', action='store_true') parser.add_argument('--eth_submission_mode', default='train', type=str, choices=['train', 'test']) parser.add_argument('--middlebury_submission_mode', default='training', type=str, choices=['training', 'test']) parser.add_argument('--output_path', default='output', type=str) # log parser.add_argument('--summary_freq', default=100, type=int, help='Summary frequency to tensorboard (iterations)') parser.add_argument('--save_ckpt_freq', default=1000, type=int, help='Save checkpoint frequency (steps)') parser.add_argument('--val_freq', default=1000, type=int, help='validation frequency in terms of training steps') parser.add_argument('--save_latest_ckpt_freq', default=1000, type=int) parser.add_argument('--num_steps', default=100000, type=int) # distributed training parser.add_argument('--distributed', action='store_true') parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--launcher', default='none', type=str) parser.add_argument('--gpu_ids', default=0, type=int, nargs='+') # inference parser.add_argument('--inference_dir', default=None, type=str) parser.add_argument('--inference_dir_left', default=None, type=str) parser.add_argument('--inference_dir_right', default=None, type=str) parser.add_argument('--pred_bidir_disp', action='store_true', help='predict both left and right disparities') parser.add_argument('--pred_right_disp', action='store_true', help='predict right disparity') parser.add_argument('--save_pfm_disp', action='store_true', help='save predicted disparity as .pfm format') parser.add_argument('--debug', action='store_true') return parser def main(args): print_info = not args.eval and not args.submission and args.inference_dir is None and \ args.inference_dir_left is None and args.inference_dir_right is None if print_info and args.local_rank == 0: print(args) misc.save_args(args) misc.check_path(args.checkpoint_dir) misc.save_command(args.checkpoint_dir) misc.check_path(args.output_path) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) np.random.seed(args.seed) torch.backends.cudnn.benchmark = True if args.launcher == 'none': args.distributed = False device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: args.distributed = True # adjust batch size for each gpu assert args.batch_size % torch.cuda.device_count() == 0 args.batch_size = args.batch_size // torch.cuda.device_count() dist_params = dict(backend='nccl') init_dist(args.launcher, **dist_params) # re-set gpu_ids with distributed training mode _, world_size = get_dist_info() args.gpu_ids = range(world_size) device = torch.device('cuda:{}'.format(args.local_rank)) setup_for_distributed(args.local_rank == 0) # model model = UniMatch(feature_channels=args.feature_channels, num_scales=args.num_scales, upsample_factor=args.upsample_factor, num_head=args.num_head, ffn_dim_expansion=args.ffn_dim_expansion, num_transformer_layers=args.num_transformer_layers, reg_refine=args.reg_refine, task=args.task).to(device) if print_info: print(model) if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model.to(device), device_ids=[args.local_rank], output_device=args.local_rank) model_without_ddp = model.module else: if torch.cuda.device_count() > 1: print('Use %d GPUs' % torch.cuda.device_count()) model = torch.nn.DataParallel(model) model_without_ddp = model.module else: model_without_ddp = model num_params = sum(p.numel() for p in model.parameters()) if print_info: print('=> Number of trainable parameters: %d' % num_params) if not args.eval and not args.submission and args.inference_dir is None: save_name = '%d_parameters' % num_params open(os.path.join(args.checkpoint_dir, save_name), 'a').close() optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) start_epoch = 0 start_step = 0 if args.resume: print("=> Load checkpoint: %s" % args.resume) loc = 'cuda:{}'.format(args.local_rank) if torch.cuda.is_available() else 'cpu' checkpoint = torch.load(args.resume, map_location=loc) model_without_ddp.load_state_dict(checkpoint['model'], strict=args.strict_resume) if 'optimizer' in checkpoint and 'step' in checkpoint and 'epoch' in checkpoint and not \ args.no_resume_optimizer: print('Load optimizer') optimizer.load_state_dict(checkpoint['optimizer']) start_step = checkpoint['step'] start_epoch = checkpoint['epoch'] if print_info: print('start_epoch: %d, start_step: %d' % (start_epoch, start_step)) if args.submission: if 'kitti15' in args.val_dataset or 'kitti12' in args.val_dataset: create_kitti_submission(model_without_ddp, output_path=args.output_path, padding_factor=args.padding_factor, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, inference_size=args.inference_size, ) if 'eth3d' in args.val_dataset: create_eth3d_submission(model_without_ddp, output_path=args.output_path, padding_factor=args.padding_factor, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, inference_size=args.inference_size, submission_mode=args.eth_submission_mode, save_vis_disp=args.save_vis_disp, ) if 'middlebury' in args.val_dataset: create_middlebury_submission(model_without_ddp, output_path=args.output_path, padding_factor=args.padding_factor, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, inference_size=args.inference_size, submission_mode=args.middlebury_submission_mode, save_vis_disp=args.save_vis_disp, ) return if args.eval: val_results = {} if 'things' in args.val_dataset: results_dict = validate_things(model_without_ddp, max_disp=args.max_disp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, ) if args.local_rank == 0: val_results.update(results_dict) if 'kitti15' in args.val_dataset or 'kitti12' in args.val_dataset: results_dict = validate_kitti15(model_without_ddp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, count_time=args.count_time, debug=args.debug, ) if args.local_rank == 0: val_results.update(results_dict) if 'eth3d' in args.val_dataset: results_dict = validate_eth3d(model_without_ddp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, ) if args.local_rank == 0: val_results.update(results_dict) if 'middlebury' in args.val_dataset: results_dict = validate_middlebury(model_without_ddp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, resolution=args.middlebury_resolution, ) if args.local_rank == 0: val_results.update(results_dict) return if args.inference_dir or (args.inference_dir_left and args.inference_dir_right): inference_stereo(model_without_ddp, inference_dir=args.inference_dir, inference_dir_left=args.inference_dir_left, inference_dir_right=args.inference_dir_right, output_path=args.output_path, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, pred_bidir_disp=args.pred_bidir_disp, pred_right_disp=args.pred_right_disp, save_pfm_disp=args.save_pfm_disp, ) return train_data = build_dataset(args) print('=> {} training samples found in the training set'.format(len(train_data))) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( train_data, num_replicas=torch.cuda.device_count(), rank=args.local_rank ) else: train_sampler = None train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=train_sampler is None, num_workers=args.num_workers, pin_memory=True, drop_last=True, sampler=train_sampler, ) last_epoch = start_step if args.resume and not args.no_resume_optimizer else -1 lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, args.lr, args.num_steps + 10, pct_start=0.05, cycle_momentum=False, anneal_strategy='cos', last_epoch=last_epoch, ) if args.local_rank == 0: summary_writer = SummaryWriter(args.checkpoint_dir) total_steps = start_step epoch = start_epoch print('=> Start training...') while total_steps < args.num_steps: model.train() # mannually change random seed for shuffling every epoch if args.distributed: train_sampler.set_epoch(epoch) if args.local_rank == 0: summary_writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], total_steps + 1) for i, sample in enumerate(train_loader): left = sample['left'].to(device) # [B, 3, H, W] right = sample['right'].to(device) gt_disp = sample['disp'].to(device) # [B, H, W] mask = (gt_disp > 0) & (gt_disp < args.max_disp) if not mask.any(): continue pred_disps = model(left, right, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, task='stereo', )['flow_preds'] disp_loss = 0 all_loss = [] # loss weights loss_weights = [0.9 ** (len(pred_disps) - 1 - power) for power in range(len(pred_disps))] for k in range(len(pred_disps)): pred_disp = pred_disps[k] weight = loss_weights[k] curr_loss = F.smooth_l1_loss(pred_disp[mask], gt_disp[mask], reduction='mean') disp_loss += weight * curr_loss all_loss.append(curr_loss) total_loss = disp_loss # more efficient zero_grad for param in model.parameters(): param.grad = None total_loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() lr_scheduler.step() total_steps += 1 if total_steps % args.summary_freq == 0 and args.local_rank == 0: img_summary = dict() img_summary['left'] = left img_summary['right'] = right img_summary['gt_disp'] = gt_disp img_summary['pred_disp'] = pred_disps[-1] pred_disp = pred_disps[-1] img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp) save_images(summary_writer, 'train', img_summary, total_steps) epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean') print('step: %06d \t epe: %.3f' % (total_steps, epe.item())) summary_writer.add_scalar('train/epe', epe.item(), total_steps) summary_writer.add_scalar('train/disp_loss', disp_loss.item(), total_steps) summary_writer.add_scalar('train/total_loss', total_loss.item(), total_steps) # save all losses for s in range(len(all_loss)): save_name = 'train/loss' + str(len(all_loss) - s - 1) save_value = all_loss[s] summary_writer.add_scalar(save_name, save_value, total_steps) d1 = d1_metric(pred_disp, gt_disp, mask) summary_writer.add_scalar('train/d1', d1.item(), total_steps) # always save the latest model for resuming training if args.local_rank == 0 and total_steps % args.save_latest_ckpt_freq == 0: # Save lastest checkpoint after each epoch checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth') save_dict = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'step': total_steps, 'epoch': epoch, } torch.save(save_dict, checkpoint_path) # save checkpoint of specific epoch if args.local_rank == 0 and total_steps % args.save_ckpt_freq == 0: print('Save checkpoint at step: %d' % total_steps) checkpoint_path = os.path.join(args.checkpoint_dir, 'step_%06d.pth' % total_steps) save_dict = { 'model': model_without_ddp.state_dict(), } torch.save(save_dict, checkpoint_path) # validation if total_steps % args.val_freq == 0: val_results = {} if 'things' in args.val_dataset: results_dict = validate_things(model_without_ddp, max_disp=args.max_disp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, ) if args.local_rank == 0: val_results.update(results_dict) if 'kitti15' in args.val_dataset or 'kitti12' in args.val_dataset: results_dict = validate_kitti15(model_without_ddp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, count_time=args.count_time, ) if args.local_rank == 0: val_results.update(results_dict) if 'eth3d' in args.val_dataset: results_dict = validate_eth3d(model_without_ddp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, ) if args.local_rank == 0: val_results.update(results_dict) if 'middlebury' in args.val_dataset: results_dict = validate_middlebury(model_without_ddp, padding_factor=args.padding_factor, inference_size=args.inference_size, attn_type=args.attn_type, attn_splits_list=args.attn_splits_list, corr_radius_list=args.corr_radius_list, prop_radius_list=args.prop_radius_list, num_reg_refine=args.num_reg_refine, resolution=args.middlebury_resolution, ) if args.local_rank == 0: val_results.update(results_dict) if args.local_rank == 0: # save to tensorboard for key in val_results: tag = key.split('_')[0] tag = tag + '/' + key summary_writer.add_scalar(tag, val_results[key], total_steps) # save validation results to file val_file = os.path.join(args.checkpoint_dir, 'val_results.txt') with open(val_file, 'a') as f: f.write('step: %06d\n' % total_steps) # order of metrics metrics = ['things_epe', 'things_d1', 'kitti15_epe', 'kitti15_d1', 'kitti15_3px', 'eth3d_epe', 'eth3d_1px', 'middlebury_epe', 'middlebury_2px', ] eval_metrics = [] for metric in metrics: if metric in val_results.keys(): eval_metrics.append(metric) metrics_values = [val_results[metric] for metric in eval_metrics] num_metrics = len(eval_metrics) f.write(("| {:>20} " * num_metrics + '\n').format(*eval_metrics)) f.write(("| {:20.4f} " * num_metrics).format(*metrics_values)) f.write('\n\n') model.train() if total_steps >= args.num_steps: print('Training done') return epoch += 1 if __name__ == '__main__': parser = get_args_parser() args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) main(args)分析代码
<think>我们被要求分析一段代码。这段代码是一个深度学习训练脚本,主要用于立体匹配(stereo matching)任务。代码中包含了模型定义、训练循环、验证、推理以及一些辅助功能(如分布式训练、参数解析等)。我们将逐步分析代码结构并解释各部分的功能。
1. 导入模块:代码开头导入了一些必要的库,包括argparse、torch、numpy等,以及一些自定义模块(如misc、model等)。
2. 定义参数解析器(get_args_parser):这个函数创建了一个argparse.ArgumentParser对象,用于解析命令行参数。参数涵盖了数据集设置、训练超参数、模型结构、分布式训练、验证和推理选项等。
3. 主函数(main):
a. 初始化:根据是否分布式训练设置设备(device)和进程组。如果是分布式训练,则初始化进程组并设置设备。
b. 模型构建:创建UniMatch模型(一个用于立体匹配的神经网络),并将其移动到相应的设备上。如果使用多GPU,则用DistributedDataParallel或DataParallel包装。
c. 优化器:使用AdamW优化器。
d. 恢复训练:如果提供了--resume参数,则从检查点恢复模型和优化器状态。
e. 任务分发:根据参数执行不同任务:
- 提交(submission):在指定数据集上生成提交结果(如KITTI、ETH3D、Middlebury)。
- 评估(eval):在验证集上评估模型性能。
- 推理(inference):在给定图像目录上进行立体匹配推理。
- 训练:如果没有指定上述任务,则进入训练循环。
f. 训练循环:构建数据集和数据加载器,使用OneCycleLR学习率调度器,然后进行多轮训练。每个训练批次中,计算损失并反向传播,定期记录日志和保存检查点。同时,定期在验证集上评估模型。
4. 训练步骤:在训练循环中,模型对左右图像进行前向传播,得到预测的视差图(disparity)。然后计算预测视差与真实视差之间的平滑L1损失(Smooth L1 Loss)。损失函数考虑了多个尺度的预测(由深到浅,权重递减)。反向传播后,进行梯度裁剪并更新参数。
5. 验证:在验证阶段,代码在多个数据集(如Things、KITTI15、ETH3D、Middlebury)上评估模型,计算不同的评价指标(如EPE、D1等)。
6. 推理:对于给定的左右图像目录,进行视差图预测,并保存结果。
7. 辅助函数:包括保存图像、保存检查点、记录日志等。
现在,我们针对代码中的关键部分进行详细解释。
由于代码较长,我们不会逐行解释,而是关注主要流程和关键函数。
注意:在训练过程中,使用了OneCycle学习率调度器,这是一种先升后降的学习率策略。
另外,模型的前向传播函数(在UniMatch类中,但代码中未给出定义)返回一个字典,其中包含多个尺度的视差预测(flow_preds)。损失函数计算时,对每个尺度的预测都计算了损失,并按权重(0.9的幂次)加权求和。
在验证和推理阶段,使用了一些自定义的验证函数(如validate_things, validate_kitti15等),这些函数在代码中没有给出,但可以推测它们是在相应数据集上计算评价指标。
最后,代码在__main__中调用主函数。
由于代码中涉及很多自定义模块(如UniMatch模型,以及各种验证函数),我们无法看到其具体实现,但我们可以根据参数和调用方式理解其功能。
现在,我们根据要求,在最后提出相关问题。
注意:由于代码较长且部分模块未给出,我们只能基于现有代码分析。
相关问题:
1. 请解释一下UniMatch模型的结构,特别是注意力机制部分(如attn_type为'self_swin2d_cross_1d'时)是如何工作的?
2. 在训练过程中,损失函数使用了多个尺度的预测结果,并且每个尺度的权重是0.9的幂次递减。这样做的目的是什么?有没有其他权重设置方式?
3. 分布式训练部分,代码中使用了DistributedDataParallel和DataParallel,请解释这两者的区别以及适用场景。
4. 在验证阶段,代码在多个数据集上评估模型。请问这些数据集(Things, KITTI15, ETH3D, Middlebury)在立体匹配任务中各有什么特点?
注意:由于代码中部分函数(如validate_things, create_kitti_submission等)没有给出具体实现,我们无法深入分析这些函数,因此问题主要围绕代码结构和已有实现。
根据要求,我们使用指定的格式提出问题。</think>这段代码是一个完整的深度学习训练脚本,用于立体匹配(Stereo Matching)任务。主要实现了一个基于UniMatch模型的立体视觉匹配系统,支持训练、验证、推理和提交功能。下面我将详细分析代码结构和关键组件:
### 代码结构分析
1. **参数解析器 (get_args_parser)**
2. **主函数 (main)**:
- 初始化设置(分布式训练/随机种子)
- 模型构建 (UniMatch)
- 训练流程
- 验证/推理/提交功能
3. **训练循环**:
- 数据加载
- 前向传播
- 损失计算
- 反向传播
- 学习率调度
- 模型保存
### 关键组件解析
#### 1. 模型架构 (UniMatch)
```python
model = UniMatch(
feature_channels=args.feature_channels,
num_scales=args.num_scales,
upsample_factor=args.upsample_factor,
num_head=args.num_head,
ffn_dim_expansion=args.ffn_dim_expansion,
num_transformer_layers=args.num_transformer_layers,
reg_refine=args.reg_refine,
task=args.task
).to(device)
```
- **核心特性**:
- 多尺度特征提取 (`num_scales`)
- Transformer模块 (`num_head`, `ffn_dim_expansion`)
- 可选的局部优化 (`reg_refine`)
- 支持立体匹配/光流/深度估计多任务 (`task`)
#### 2. 训练流程
```python
# 数据加载
train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, ...)
# 学习率调度
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, args.lr, args.num_steps + 10,
pct_start=0.05, cycle_momentum=False
)
# 训练循环
while total_steps < args.num_steps:
pred_disps = model(left, right, ...) # 前向传播
loss = F.smooth_l1_loss(pred_disp[mask], gt_disp[mask]) # 损失计算
loss.backward() # 反向传播
optimizer.step() # 参数更新
lr_scheduler.step() # 学习率更新
```
#### 3. 损失函数设计
```python
disp_loss = 0
loss_weights = [0.9 ** (len(pred_disps) - 1 - power) for power in range(len(pred_disps))]
for k in range(len(pred_disps)):
pred_disp = pred_disps[k]
weight = loss_weights[k]
curr_loss = F.smooth_l1_loss(pred_disp[mask], gt_disp[mask])
disp_loss += weight * curr_loss
```
- **特点**:
- 使用平滑L1损失(Smooth L1 Loss)
- 多尺度损失加权(浅层权重0.9^N,深层权重0.9^0)
- 视差范围过滤(`mask = (gt_disp > 0) & (gt_disp < args.max_disp)`)
#### 4. 分布式训练支持
```python
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(
model.to(device),
device_ids=[args.local_rank],
output_device=args.local_rank
)
```
- 支持多种并行模式:
- 单机多卡(DataParallel)
- 多机分布式(DistributedDataParallel)
- 通过`--launcher`参数控制(none/pytorch/slurm等)
#### 5. 验证与评估
```python
# 数据集特定的验证函数
validate_things(...)
validate_kitti15(...)
validate_eth3d(...)
validate_middlebury(...)
# 评估指标:
# - EPE (End-Point Error)
# - D1 (像素误差>3px的比例)
# - 1px/2px/3px误差阈值
```
#### 6. 推理功能
```python
inference_stereo(
model_without_ddp,
inference_dir=args.inference_dir,
output_path=args.output_path,
...
)
```
- 支持单目录输入(`--inference_dir`)或左右图分离输入(`--inference_dir_left/right`)
- 输出选项:PFM格式、双向视差、右视差图等
### 设计亮点
1. **模块化验证**:为不同数据集(KITTI/ETH3D/Middlebury)提供专用验证函数
2. **灵活的训练控制**:
- 分阶段训练(`--stage`参数)
- 学习率one-cycle策略
- 梯度裁剪(`clip_grad_norm_`)
3. **工业级部署支持**:
- 多种提交格式生成
- 时间统计(`--count_time`)
- 可视化保存(`--save_vis_disp`)
4. **高效分布式**:
- 自动batch size分配
- 分布式采样器
- 多级并行控制
### 潜在改进建议
1. 增加混合精度训练(AMP)支持
2. 添加TensorBoard可视化
3. 实现早停机制(Early Stopping)
4. 支持ONNX/TensorRT导出
优化代码 /* 解析头部 */ case SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_HEAD: if (data == DEF_SERIAL_PROTOCOL0_HEAD) { pfm_ctr_blk->Count = 0; pfm_ctr_blk->Status = SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_LEN; } break; /* 解析长度 */ case SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_LEN: pfm_ctr_blk->temp[pfm_ctr_blk->Count++] = data; if ((pfm_ctr_blk->Count & DEF_SERIAL_PROTOCOL0_HEAD) == 2) { pfm_ctr_blk->Count = 0; pfm_pkge->Length = pfm_ctr_blk->temp[1]<<8 | pfm_ctr_blk->temp[0]; /* 长度效验 */ if (pfm_pkge->Length > pfm_pkge->BufLen-1) { DEF_SERIAL_PROTOCOL0_CORE_ERROR("len error! = %d", pfm_pkge->Length); SERIAL_Protocol0Clear(pframe_buffer); } else { pfm_ctr_blk->Status = SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_CMD; } } break;
这段代码的判断语句 `(pfm_ctr_blk->Count & DEF_SERIAL_PROTOCOL0_HEAD) == 2` 可以使用位运算符 `&` 和常量 `DEF_SERIAL_PROTOCOL0_HEAD - 1` 来替代,从而提高代码的可读性和可维护性。此外,对于一些常量,也可以使用宏定义来提高代码的可维护性。下面是优化后的代码:
```c
// 宏定义
#define HEAD_MASK (DEF_SERIAL_PROTOCOL0_HEAD - 1)
/* 解析头部 */
case SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_HEAD:
if (data == DEF_SERIAL_PROTOCOL0_HEAD) {
pfm_ctr_blk->Count = 0;
pfm_ctr_blk->Status = SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_LEN;
}
break;
/* 解析长度 */
case SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_LEN:
pfm_ctr_blk->temp[pfm_ctr_blk->Count++] = data;
if ((pfm_ctr_blk->Count & HEAD_MASK) == 2) { // 使用位运算符 & 和常量 HEAD_MASK 替代原来的比较操作
pfm_ctr_blk->Count = 0;
pfm_pkge->Length = (pfm_ctr_blk->temp[1] << 8) | pfm_ctr_blk->temp[0];
/* 长度效验 */
if (pfm_pkge->Length > pfm_pkge->BufLen - 1) {
DEF_SERIAL_PROTOCOL0_CORE_ERROR("len error! = %d", pfm_pkge->Length);
SERIAL_Protocol0Clear(pframe_buffer);
} else {
pfm_ctr_blk->Status = SERIAL_PROTOCOL0_RECEIVE_TYPE_SRATUS_CMD;
}
}
break;
```
在优化后的代码中,我们使用了一个常量 `HEAD_MASK`,将其赋值为 `DEF_SERIAL_PROTOCOL0_HEAD - 1`,用来判断 `Count` 是否为 2。使用位运算符 `&` 来判断 `Count` 的值是否为 2,从而代替了原来的比较操作。这样做的好处是,使用位运算符可以提高代码的效率,并且使代码更加简洁和易读。另外,我们还将一些常量进行了宏定义,提高了代码的可维护性。
阅读全文
相关推荐














