MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MultiHeadAttentionInterpParameter.cs
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text;
using MyCaffe.basecode;
7
namespace MyCaffe.param.tft
{
    /// <summary>
    /// Specifies the parameters for the MultiHeadAttentionInterpLayer (Interpretable Multi-Head Attention Layer).
    /// </summary>
    [Serializable]
    [TypeConverter(typeof(ExpandableObjectConverter))]
    public class MultiHeadAttentionInterpParameter : LayerParameterBase
    {
        FillerParameter m_fillerParam_weights = new FillerParameter("xavier");
        FillerParameter m_fillerParam_bias = new FillerParameter("constant", 0.1);
        bool m_bEnableNoise = false;
        double m_dfSigmaInit = 0.017;
        uint m_nEmbedDim;
        uint m_nNumHeads;
        uint m_nNumHistoricalSteps = 0;
        uint m_nNumFutureSteps = 0;
        bool m_bEnableSelfAttention = true;

        /// <summary>
        /// The constructor. All settings start at their field initializer values
        /// (xavier weight filler, constant 0.1 bias filler, noise disabled, sigma_init = 0.017,
        /// self attention enabled, zero historical/future steps).
        /// </summary>
        public MultiHeadAttentionInterpParameter()
        {
        }

        /// <summary>
        /// Specifies to enable self attention (one input, default = true).
        /// </summary>
        [Description("Specifies to enable self attention (one input, default = true).")]
        public bool enable_self_attention
        {
            get { return m_bEnableSelfAttention; }
            set { m_bEnableSelfAttention = value; }
        }

        /// <summary>
        /// Specifies the number of historical steps (default = 0).
        /// </summary>
        [Description("Specifies the number of historical steps.")]
        public uint num_historical_steps
        {
            get { return m_nNumHistoricalSteps; }
            set { m_nNumHistoricalSteps = value; }
        }

        /// <summary>
        /// Specifies the number of future steps (default = 0).
        /// </summary>
        [Description("Specifies the number of future steps.")]
        public uint num_future_steps
        {
            get { return m_nNumFutureSteps; }
            set { m_nNumFutureSteps = value; }
        }

        /// <summary>
        /// Specifies the state size corresponding to both the input and output sizes.
        /// </summary>
        [Description("Specifies the state size corresponding to both the input and output sizes.")]
        public uint embed_dim
        {
            get { return m_nEmbedDim; }
            set { m_nEmbedDim = value; }
        }

        /// <summary>
        /// Specifies number of attention heads used in the multi-attention.
        /// </summary>
        [Description("Specifies number of attention heads used in the multi-attention.")]
        public uint num_heads
        {
            get { return m_nNumHeads; }
            set { m_nNumHeads = value; }
        }

        /// <summary>
        /// Enable/disable noise in the inner-product layer (default = false).
        /// </summary>
        [Description("Enable/disable noise in the inner-product layer (default = false).")]
        public bool enable_noise
        {
            get { return m_bEnableNoise; }
            set { m_bEnableNoise = value; }
        }

        /// <summary>
        /// Specifies the initialization value for the sigma weight and sigma bias used when 'enable_noise' = true
        /// (default = 0.017).
        /// </summary>
        [Description("Specifies the initialization value for the sigma weight and sigma bias used when 'enable_noise' = true.")]
        public double sigma_init
        {
            get { return m_dfSigmaInit; }
            set { m_dfSigmaInit = value; }
        }

        /// <summary>
        /// The filler for the weights (default = xavier).
        /// </summary>
        [Category("Fillers")]
        [Description("The filler for the weights.")]
        public FillerParameter weight_filler
        {
            get { return m_fillerParam_weights; }
            set { m_fillerParam_weights = value; }
        }

        /// <summary>
        /// The filler for the bias (default = constant 0.1).
        /// </summary>
        [Category("Fillers")]
        [Description("The filler for the bias.")]
        public FillerParameter bias_filler
        {
            get { return m_fillerParam_bias; }
            set { m_fillerParam_bias = value; }
        }

        /// <summary>
        /// Load the parameter from a binary reader.
        /// </summary>
        /// <param name="br">Specifies the binary reader; one string containing the prototxt text is read from it.</param>
        /// <param name="bNewInstance">When <i>false</i>, the loaded settings are also copied into this instance.</param>
        /// <returns>The newly parsed parameter instance.</returns>
        public override object Load(System.IO.BinaryReader br, bool bNewInstance = true)
        {
            RawProto proto = RawProto.Parse(br.ReadString());
            MultiHeadAttentionInterpParameter p = FromProto(proto);

            if (!bNewInstance)
                Copy(p);

            return p;
        }

        /// <summary>
        /// Copy one parameter to another.
        /// </summary>
        /// <param name="src">Specifies the parameter to copy, which must be a MultiHeadAttentionInterpParameter
        /// (any other type causes an InvalidCastException).</param>
        /// <remarks>
        /// Fillers are copied with FillerParameter.Clone. When a filler on the source is null,
        /// this instance keeps its current filler rather than becoming null.
        /// </remarks>
        public override void Copy(LayerParameterBase src)
        {
            MultiHeadAttentionInterpParameter p = (MultiHeadAttentionInterpParameter)src;

            m_bEnableSelfAttention = p.enable_self_attention;
            m_nNumHistoricalSteps = p.num_historical_steps;
            m_nNumFutureSteps = p.num_future_steps;

            m_nEmbedDim = p.embed_dim;
            m_nNumHeads = p.num_heads;

            if (p.m_fillerParam_bias != null)
                m_fillerParam_bias = p.m_fillerParam_bias.Clone();

            if (p.m_fillerParam_weights != null)
                m_fillerParam_weights = p.m_fillerParam_weights.Clone();

            m_bEnableNoise = p.m_bEnableNoise;
            m_dfSigmaInit = p.m_dfSigmaInit;
        }

        /// <summary>
        /// Creates a new copy of this instance of the parameter.
        /// </summary>
        /// <returns>A new MultiHeadAttentionInterpParameter holding the same settings.</returns>
        public override LayerParameterBase Clone()
        {
            MultiHeadAttentionInterpParameter p = new MultiHeadAttentionInterpParameter();
            p.Copy(this);
            return p;
        }

        /// <summary>
        /// Convert the parameter into a RawProto.
        /// </summary>
        /// <param name="strName">Specifies the name to give the RawProto.</param>
        /// <returns>The new RawProto.</returns>
        /// <remarks>
        /// Null fillers are omitted. 'enable_noise' and 'sigma_init' are only written when noise is enabled;
        /// FromProto then falls back to the defaults when they are absent.
        /// </remarks>
        public override RawProto ToProto(string strName)
        {
            RawProtoCollection rgChildren = new RawProtoCollection();

            rgChildren.Add("enable_self_attention", enable_self_attention.ToString());
            rgChildren.Add("num_historical_steps", num_historical_steps.ToString());
            rgChildren.Add("num_future_steps", num_future_steps.ToString());

            rgChildren.Add("embed_dim", embed_dim.ToString());
            rgChildren.Add("num_heads", num_heads.ToString());

            if (weight_filler != null)
                rgChildren.Add(weight_filler.ToProto("weight_filler"));

            if (bias_filler != null)
                rgChildren.Add(bias_filler.ToProto("bias_filler"));

            if (m_bEnableNoise)
            {
                rgChildren.Add("enable_noise", m_bEnableNoise.ToString());
                // Write with the invariant culture so the text always uses '.' as the decimal
                // separator; a culture-specific ',' could be misread when the proto is loaded
                // under a different culture.
                rgChildren.Add("sigma_init", m_dfSigmaInit.ToString(System.Globalization.CultureInfo.InvariantCulture));
            }

            return new RawProto(strName, "", rgChildren);
        }

        /// <summary>
        /// Parses the parameter from a RawProto.
        /// </summary>
        /// <param name="rp">Specifies the RawProto to parse.</param>
        /// <returns>A new instance of the parameter. Any setting missing from the proto keeps its constructor default.</returns>
        public static MultiHeadAttentionInterpParameter FromProto(RawProto rp)
        {
            string strVal;
            MultiHeadAttentionInterpParameter p = new MultiHeadAttentionInterpParameter();

            if ((strVal = rp.FindValue("enable_self_attention")) != null)
                p.enable_self_attention = bool.Parse(strVal);

            if ((strVal = rp.FindValue("embed_dim")) != null)
                p.embed_dim = uint.Parse(strVal);

            if ((strVal = rp.FindValue("num_heads")) != null)
                p.num_heads = uint.Parse(strVal);

            RawProto rpWeightFiller = rp.FindChild("weight_filler");
            if (rpWeightFiller != null)
                p.weight_filler = FillerParameter.FromProto(rpWeightFiller);

            RawProto rpBiasFiller = rp.FindChild("bias_filler");
            if (rpBiasFiller != null)
                p.bias_filler = FillerParameter.FromProto(rpBiasFiller);

            if ((strVal = rp.FindValue("enable_noise")) != null)
                p.enable_noise = bool.Parse(strVal);

            if ((strVal = rp.FindValue("sigma_init")) != null)
                p.sigma_init = ParseDouble(strVal);

            if ((strVal = rp.FindValue("num_historical_steps")) != null)
                p.num_historical_steps = uint.Parse(strVal);

            if ((strVal = rp.FindValue("num_future_steps")) != null)
                p.num_future_steps = uint.Parse(strVal);

            return p;
        }
    }
}
static double ParseDouble(string strVal)
Parses double values using the US culture if the decimal separator is '.', and otherwise using the native culture.
The RawProtoCollection class is a list of RawProto objects.
void Add(RawProto p)
Adds a RawProto to the collection.
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
RawProto FindChild(string strName)
Searches for a given node.
Definition: RawProto.cs:231
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
string FindValue(string strName)
Searches for the value of a node within this node's children.
Definition: RawProto.cs:105
Specifies the filler parameters used to create each Filler.
static FillerParameter FromProto(RawProto rp)
Parses the parameter from a RawProto.
override RawProto ToProto(string strName)
Convert the parameter into a RawProto.
FillerParameter Clone()
Creates a new copy of this instance of the parameter.
The LayerParameterBase is the base class for all other layer specific parameters.
Specifies the parameters for the MultiHeadAttentionInterpLayer (Interpretable Multi-Head Attention Layer).
bool enable_self_attention
Specifies to enable self attention (one input, default = true).
uint num_historical_steps
Specifies the number of historical steps.
bool enable_noise
Enable/disable noise in the inner-product layer (default = false).
uint num_heads
Specifies number of attention heads used in the multi-attention.
override object Load(System.IO.BinaryReader br, bool bNewInstance=true)
Load the parameter from a binary reader.
uint embed_dim
Specifies the state size corresponding to both the input and output sizes.
static MultiHeadAttentionInterpParameter FromProto(RawProto rp)
Parses the parameter from a RawProto.
override void Copy(LayerParameterBase src)
Copy one parameter to another.
double sigma_init
Specifies the initialization value for the sigma weight and sigma bias used when 'enable_noise' = true.
override RawProto ToProto(string strName)
Convert the parameter into a RawProto.
override LayerParameterBase Clone()
Creates a new copy of this instance of the parameter.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the open-source C++ Caffe project.
Definition: Annotation.cs:12