Deep Learning Notes (2): Adding an Attention LSTM Layer to Caffe

Following the previous post's analysis of the LSTM layer, this article looks at how to add a new Attention LSTM layer to the Caffe framework. It reviews the attention-model ideas from the relevant papers, including Recurrent Models of Visual Attention and Action Recognition Using Visual Attention, and shares the implementation of the ALSTM layer. The author ran into difficulties when trying to feed coordinate information into the attention computation, and suspects the problem comes from the dimensionality being too low.


The previous article analyzed the LSTM layer's source code and flow diagram in detail; this one adds a new layer, the Attention LSTM layer, to Caffe.

Before getting to the code, let's first go over the attention-model formulas and flow diagrams from a few papers.

(1) Recurrent Models of Visual Attention



A. Glimpse Sensor: at time t, crop regions of several different sizes around the current location and combine them into the glimpse data ρ.
B. Glimpse Network: fuses the local image information with the location information.
C. Model Architecture: the hidden memory state h_{t-1} is combined with g_t to produce the new h_t, which in turn generates the attention, i.e., the location of interest.

For the detailed derivation, see the paper and its code; the core recurrence is sketched below.
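As a quick reference, the model's main loop can be written as the following recurrence (a condensed paraphrase of the paper's notation, not a verbatim quote; f_g, f_h, f_l, f_a are the glimpse, core, location, and action networks with parameters \theta):

  g_t = f_g(\rho(x_t, l_{t-1}); \theta_g)
  h_t = f_h(h_{t-1}, g_t; \theta_h)
  l_t \sim p(\,\cdot \mid f_l(h_t; \theta_l))
  a_t \sim p(\,\cdot \mid f_a(h_t; \theta_a))

At each step the glimpse g_t fuses the retina-like patch \rho with the previous location, the core network updates h_t, and the location network emits the next point of attention; because the hard crop is non-differentiable, the location policy is trained with REINFORCE.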


(2) Action Recognition Using Visual Attention


The overall idea is to partition the extracted features: each frame's feature map is split into 49 regions (7×7), and attention weights over those regions pick out where each frame is attended. Figure (b) in the paper has an error, which the author's released code also reflects. The main goal of this post is to implement such an Attention LSTM layer in Caffe; the attention equations are sketched below.
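A minimal sketch of that soft attention, in notation adapted from the paper (X_{t,i} is the feature vector of region i in frame t, h_{t-1} is the previous LSTM hidden state, and the w_i are learned weights):

  l_{t,i} = \frac{\exp(w_i^\top h_{t-1})}{\sum_{j=1}^{49} \exp(w_j^\top h_{t-1})}
  x_t = \sum_{i=1}^{49} l_{t,i} \, X_{t,i}

The LSTM consumes the weighted sum x_t instead of a raw frame feature, and the weights l_{t,i} are exactly the "mask" that the ALSTM layer below exposes as its second output.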


(3) Show, Attend and Tell: Neural Image Caption Generation with Visual Attention





The idea is the same; see the paper for the full derivation. For comparison, its soft-attention step is sketched below.
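In the soft-attention variant, each CNN annotation vector a_i is scored against the previous hidden state and the scores are normalized into a context vector (condensed from the paper; f_att is a small MLP):

  e_{t,i} = f_{att}(a_i, h_{t-1})
  \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_k \exp(e_{t,k})}
  \hat{z}_t = \sum_i \alpha_{t,i} \, a_i

This is the same region-weighting pattern as above, applied to caption generation instead of action recognition.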



Part 2: ALSTM Layer Code

Alstm.cpp

#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/sequence_layers.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

// The recurrence carries the same state as the stock Caffe LSTM:
// the hidden state h and the cell state c.
template <typename Dtype>
void ALSTMLayer<Dtype>::RecurrentInputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h_0";
  (*names)[1] = "c_0";
}

template <typename Dtype>
void ALSTMLayer<Dtype>::RecurrentOutputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h_" + this->int_to_str(this->T_);
  (*names)[1] = "c_T";
}

// Unlike the stock LSTM, the ALSTM exposes a second top alongside the usual
// "h": the per-timestep attention weights over the input regions ("mask").
template <typename Dtype>
void ALSTMLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
  names->resize(2);
  (*names)[0] = "h";
  (*names)[1] = "mask";
}

template <typename Dtype>
void ALSTMLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
  const int num_output = this->layer_param_.recurrent_param().num_output();
  CHECK_GT(num_output, 0) << "num_output must be positive";
  const FillerParameter& weight_filler =
      this->layer_param_.recurrent_param().weight_filler();
  const FillerParameter& bias_filler =
      this->layer_param_.recurrent_param().bias_filler();

  // Add generic LayerParameter's (without bottoms/tops) of layer types we'll
  // use to save redundant code.
  LayerParameter hidden_param;
  hidden_param.set_type("InnerProduct");
  hidden_param.mutable_inner_product_param()->set_num_output(num_output * 4);
  hidden_param.mutable_inner_product_param()->set_bias_term(false);
  hidden_param.mutable_inner_product_param()->set_axis(1);
  hidden_param.mutable_inner_product_param()->
      mutable_weight_filler()->CopyFrom(weight_filler);

  LayerParameter biased_hidden_param(hidden_param);
  biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
  biased_hidden_param.mutable_inner_product_param()->
      mutable_bias_filler()->CopyFrom(bias_filler);

  // Attention scoring: InnerProduct layers applied along axis 2, so each
  // spatial region of the input (e.g., the 7x7 = 49 feature vectors of a
  // frame) is projected independently to a 256-d embedding.
  LayerParameter attention_param;
  attention_param.set_type("InnerProduct");
  attention_param.mutable_inner_product_param()->set_num_output(256);
  attention_param.mutable_inner_product_param()->set_bias_term(false);
  attention_param.mutable_inner_product_param()->set_axis(2);
  attention_param.mutable_inner_product_param()->
      mutable_weight_filler()->CopyFrom(weight_filler);

  LayerParameter biased_attention_param(attention_param);
  biased_attention_param.mutable_inner_product_param()->set_bias_term(true);
  biased_attention_param.mutable_inner_product_param()->
      mutable_bias_filler()->CopyFrom(bias_filler);  // weight + bias

  LayerParameter sum_param;
  sum_param.set_type("Eltwise");
  sum_param.mutable_eltwise_param()->set_operation(
      EltwiseParameter_EltwiseOp_SUM);

  LayerParameter slice_param;
  slice_param.set_type("Slice");
  slice_param.mutable_slice_param()->set_axis(0);

  // Softmax over the last axis normalizes the attention scores across the
  // input regions, yielding the attention weights (the "mask" output).
  LayerParameter softmax_param;
  softmax_param.set_type("Softmax");
  softmax_param.mutable_softmax_param()->set_axis(-1);

  LayerParameter split_param;
  split_param.set_type("Split");

  LayerParameter scale_param;
  scale_param.set_type("Scale");

  LayerParameter permute_param;
  permute_param.set_type("Permute");

  LayerParameter reshape_param;
  reshape_param.set_type("Reshape");

  LayerParameter bias_layer_param;
  bias_layer_param.set_type("Bias");

  LayerParameter pool_param;
  pool_param.set_type("Pooling");

  LayerParameter reshape_layer_param;
  reshape_layer_param.set_type("Reshape");

  BlobShape input_shape;
  input_shape.add_dim(1);  // c_0 and h_0 are a single timestep
  input_shape.add_dim(this->N_);
  input_shape.add_dim(num_output);

  net_param->add_input("c_0");
  net_param->add_input_shape()->CopyFrom(input_shape);

  net_param->add_input("h_0");
  net_param->add_input_shape()->CopyFrom(input_shape);

  LayerParameter* cont_slice_param = net_param->add_layer();
  cont_slice_param->CopyFrom(slice_param);
  cont_slice_param->set_name("cont_slice");
  cont_slice_param->add_bottom("cont");
  cont_slice_param->mutable_slice_param()->set_axis(1);

  LayerParameter* x_slice_param = net_param->add_layer();
  x_slice_param->CopyFrom(slice_param);
  x_slice_param->set_name("x_slice");
  x_slice_param->add_bottom("x");

  // Add layer to transform all timesteps of x to the hidden state dimension.
  //     W_xc_x = W_xc * x + b_c
/*
  {
    LayerParameter* x_transform_param = net_param->add_layer();
    x_transform_param->CopyFrom(biased_hidden_param);
    x_transform_param->set_name("x_transform");
    x_transform_param->add_param()->set_name("W_xc");
    x_transform_param->add_param()->set_name("b_c");
    x_transform_param->add_bottom("x");
    x_transform_param->add_top("W_xc_x");
  }

  if (this->static_input_) {
    // Add layer to transform x_static to the gate dimension.
    //     W_xc_x_static = W_xc_static * x_static
    LayerParameter* x_static_transform_param = net_param->add_layer();
    x_static_transform_param->CopyFrom(hidden_param);
    x_static_transform_param->mutable_inner_product_param()->set_axis(1);
    x_static_transform_param->set_name("W_xc_x_static");
    x_static_transform_param->add_param()->set_name("W_xc_static");
    x_static_transform_param->add_bottom("x_static");
    x_static_transform_param->add_top("W_xc_x_static");
  }
*/
  // (The listing is truncated here in the original post; the remainder of
  // FillUnrolledNet, which unrolls the attention and LSTM computation for
  // each timestep, is not included.)
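To show how such a layer would be wired up once registered, here is a hypothetical prototxt snippet. It assumes the layer is registered under the type name "ALSTM" and follows Caffe's recurrent-layer convention of a time-major feature blob plus a "cont" sequence-indicator blob; all names and shapes below are illustrative, not taken from the post:

layer {
  name: "alstm1"
  type: "ALSTM"            # assumed registration name for ALSTMLayer
  bottom: "features"       # T x N x 49 x C region features (hypothetical shape)
  bottom: "cont"           # T x N continuation markers (0 starts a new sequence)
  top: "h"                 # hidden state at every timestep
  top: "mask"              # attention weights over the 49 regions per timestep
  recurrent_param {
    num_output: 512        # illustrative hidden size
    weight_filler { type: "uniform" min: -0.01 max: 0.01 }
    bias_filler { type: "constant" value: 0 }
  }
}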