MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MultiheadAttentionLayer.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using MyCaffe.basecode;
using MyCaffe.common;
using MyCaffe.param;
using MyCaffe.fillers;
using System.Diagnostics;
using MyCaffe.param.gpt;

namespace MyCaffe.layers.gpt
{
    /// <summary>
    /// The MultiheadAttention provides a vanilla multi-head layer.
    /// </summary>
    public class MultiheadAttentionLayer<T> : Layer<T>
    {
        List<int> m_rgShape = new List<int>() { 1, 1, 1, 1 };
        // Key, query, value projections for all heads, but in a batch.
        Layer<T> m_c_attnQ = null;
        Layer<T> m_c_attnK = null;
        Layer<T> m_c_attnV = null;
        // Output projection.
        Layer<T> m_c_proj = null;
        // Regularization
        Layer<T> m_attn_dropout = null;
        Layer<T> m_resid_dropout = null;
        // Transpose
        Layer<T> m_transpose;
        Layer<T> m_transposeQ;
        // Softmax
        Layer<T> m_softmax = null;
        Blob<T> m_blobX0;
        Blob<T> m_blobX1;
        Blob<T> m_blobX2;
        Blob<T> m_blobQ;
        Blob<T> m_blobK;
        Blob<T> m_blobV;
        Blob<T> m_blobQt;
        Blob<T> m_blobKt;
        Blob<T> m_blobKt1;
        Blob<T> m_blobVt;
        Blob<T> m_blobWork;
        Blob<T> m_blobAttA;
        Blob<T> m_blobAttB;
        Blob<T> m_blobY;
        // The number of heads.
        int m_nHeads;
        int m_nEmbed;
        int m_nBlockSize;
        double m_dfAttnDropout;
        double m_dfResidDropout;

        int m_nSize;
        int m_nB;
        int m_nT;
        int m_nC;

        BlobCollection<T> m_colInternalBottom = new BlobCollection<T>();
        BlobCollection<T> m_colInternalTop = new BlobCollection<T>();

        /// <summary>
        /// The MultiheadAttention constructor.
        /// </summary>
        public MultiheadAttentionLayer(CudaDnn<T> cuda, Log log, LayerParameter p)
            : base(cuda, log, p)
        {
            m_type = LayerParameter.LayerType.MULTIHEAD_ATTENTION;

            m_nHeads = (int)p.multihead_attention_param.heads;
            m_nEmbed = (int)p.multihead_attention_param.embed;
            m_nBlockSize = (int)p.multihead_attention_param.block_size;
            m_dfAttnDropout = p.multihead_attention_param.attn_dropout;
            m_dfResidDropout = p.multihead_attention_param.resid_dropout;

            log.CHECK_EQ(m_nEmbed % m_nHeads, 0, "The embedding size must be divisible by the number of heads.");

            // Query projection for all heads, but in a batch.
            // input features = m_nEmbed
            LayerParameter ipAttnQ = new LayerParameter(LayerParameter.LayerType.INNERPRODUCT, p.name + ".c_attnQ");
            ipAttnQ.inner_product_param.num_output = (uint)m_nEmbed;
            ipAttnQ.inner_product_param.bias_term = true;
            // NOTE: the weight_init condition below was omitted from the extracted listing and is reconstructed
            // from the MultiheadAttentionParameter documentation (xavier for ENCODER_DECODER, gaussian otherwise).
            if (p.multihead_attention_param.weight_init == MultiheadAttentionParameter.WEIGHT_INIT.ENCODER_DECODER)
            {
                ipAttnQ.inner_product_param.weight_filler = new FillerParameter("xavier");
                ipAttnQ.inner_product_param.bias_filler = new FillerParameter("xavier");
            }
            else
            {
                ipAttnQ.inner_product_param.weight_filler = new FillerParameter("gaussian", 0, 0, 0.02);
                ipAttnQ.inner_product_param.bias_filler = new FillerParameter("constant", 0.0);
            }
            ipAttnQ.inner_product_param.axis = 2;
            ipAttnQ.parameters.Add((m_param.parameters.Count > 0) ? m_param.parameters[0] : new ParamSpec(1.0, 1.0));
            ipAttnQ.parameters.Add((m_param.parameters.Count > 1) ? m_param.parameters[1] : new ParamSpec(1.0, 0.0));
            m_c_attnQ = Layer<T>.Create(cuda, log, convertLayerParam(ipAttnQ, p), null);

            // Key projection for all heads, but in a batch.
            // input features = m_nEmbed
            LayerParameter ipAttnK = new LayerParameter(LayerParameter.LayerType.INNERPRODUCT, p.name + ".c_attnK");
            ipAttnK.inner_product_param.num_output = (uint)m_nEmbed;
            ipAttnK.inner_product_param.bias_term = true;
            if (p.multihead_attention_param.weight_init == MultiheadAttentionParameter.WEIGHT_INIT.ENCODER_DECODER) // (condition reconstructed as above)
            {
                ipAttnK.inner_product_param.weight_filler = new FillerParameter("xavier");
                ipAttnK.inner_product_param.bias_filler = new FillerParameter("xavier");
            }
            else
            {
                ipAttnK.inner_product_param.weight_filler = new FillerParameter("gaussian", 0, 0, 0.02);
                ipAttnK.inner_product_param.bias_filler = new FillerParameter("constant", 0.0);
            }
            ipAttnK.inner_product_param.axis = 2;
            ipAttnK.parameters.Add((m_param.parameters.Count > 0) ? m_param.parameters[0] : new ParamSpec(1.0, 1.0));
            ipAttnK.parameters.Add((m_param.parameters.Count > 1) ? m_param.parameters[1] : new ParamSpec(1.0, 0.0));
            m_c_attnK = Layer<T>.Create(cuda, log, convertLayerParam(ipAttnK, p), null);

            // Value projection for all heads, but in a batch.
            // input features = m_nEmbed
            LayerParameter ipAttnV = new LayerParameter(LayerParameter.LayerType.INNERPRODUCT, p.name + ".c_attnV");
            ipAttnV.inner_product_param.num_output = (uint)m_nEmbed;
            ipAttnV.inner_product_param.bias_term = true;
            if (p.multihead_attention_param.weight_init == MultiheadAttentionParameter.WEIGHT_INIT.ENCODER_DECODER) // (condition reconstructed as above)
            {
                ipAttnV.inner_product_param.weight_filler = new FillerParameter("xavier");
                ipAttnV.inner_product_param.bias_filler = new FillerParameter("xavier");
            }
            else
            {
                ipAttnV.inner_product_param.weight_filler = new FillerParameter("gaussian", 0, 0, 0.02);
                ipAttnV.inner_product_param.bias_filler = new FillerParameter("constant", 0.0);
            }
            ipAttnV.inner_product_param.axis = 2;
            ipAttnV.parameters.Add((m_param.parameters.Count > 0) ? m_param.parameters[0] : new ParamSpec(1.0, 1.0));
            ipAttnV.parameters.Add((m_param.parameters.Count > 1) ? m_param.parameters[1] : new ParamSpec(1.0, 0.0));
            m_c_attnV = Layer<T>.Create(cuda, log, convertLayerParam(ipAttnV, p), null);

            // Output projection.
            // input features = m_nEmbed
            LayerParameter ipProj = new LayerParameter(LayerParameter.LayerType.INNERPRODUCT, p.name + ".c_proj");
            ipProj.inner_product_param.num_output = (uint)m_nEmbed;
            ipProj.inner_product_param.bias_term = true;
            if (p.multihead_attention_param.weight_init == MultiheadAttentionParameter.WEIGHT_INIT.ENCODER_DECODER) // (condition and xavier weight filler reconstructed as above)
            {
                ipProj.inner_product_param.weight_filler = new FillerParameter("xavier");
                ipProj.inner_product_param.bias_filler = new FillerParameter("xavier");
            }
            else
            {
                ipProj.inner_product_param.weight_filler = new FillerParameter("gaussian", 0, 0, 0.02 / Math.Sqrt(2 * m_param.multihead_attention_param.layers));
                ipProj.inner_product_param.bias_filler = new FillerParameter("constant", 0.0);
            }
            ipProj.inner_product_param.axis = 2;
            ipProj.parameters.Add((m_param.parameters.Count > 0) ? m_param.parameters[0] : new ParamSpec(1.0, 1.0));
            ipProj.parameters.Add((m_param.parameters.Count > 1) ? m_param.parameters[1] : new ParamSpec(1.0, 0.0));
            m_c_proj = Layer<T>.Create(cuda, log, convertLayerParam(ipProj, p), null);

            // Regularization
            if (m_dfAttnDropout > 0)
            {
                LayerParameter dropoutAttn = new LayerParameter(LayerParameter.LayerType.DROPOUT, p.name + ".drop_attn");
                dropoutAttn.dropout_param.dropout_ratio = m_dfAttnDropout;
                m_attn_dropout = Layer<T>.Create(cuda, log, convertLayerParam(dropoutAttn, p), null);
            }

            if (m_dfResidDropout > 0)
            {
                LayerParameter dropoutResid = new LayerParameter(LayerParameter.LayerType.DROPOUT, p.name + ".drop_resid");
                dropoutResid.dropout_param.dropout_ratio = m_dfResidDropout;
                m_resid_dropout = Layer<T>.Create(cuda, log, convertLayerParam(dropoutResid, p), null);
            }

            // Transpose
            LayerParameter transpose = new LayerParameter(LayerParameter.LayerType.TRANSPOSE, p.name + ".trans");
            transpose.transpose_param.dim[1] = 2;
            transpose.transpose_param.dim[2] = 1;
            m_transpose = Layer<T>.Create(cuda, log, convertLayerParam(transpose, p), null);

            LayerParameter transposeQ = new LayerParameter(LayerParameter.LayerType.TRANSPOSE, p.name + ".transQ");
            transposeQ.transpose_param.dim[2] = 3;
            transposeQ.transpose_param.dim[3] = 2;
            m_transposeQ = Layer<T>.Create(cuda, log, convertLayerParam(transposeQ, p), null);

            // Softmax
            LayerParameter softmax = new LayerParameter(LayerParameter.LayerType.SOFTMAX, p.name + ".softmax");
            softmax.softmax_param.axis = -1;
            m_softmax = Layer<T>.Create(cuda, log, convertLayerParam(softmax, p), null);

            m_blobX0 = new Blob<T>(cuda, log);
            m_blobX0.Name = m_param.name + " x0";
            m_blobX1 = new Blob<T>(cuda, log);
            m_blobX1.Name = m_param.name + " x1";
            m_blobX2 = new Blob<T>(cuda, log);
            m_blobX2.Name = m_param.name + " x2";
            m_blobQ = new Blob<T>(cuda, log);
            m_blobQ.Name = m_param.name + " Q";
            m_blobK = new Blob<T>(cuda, log);
            m_blobK.Name = m_param.name + " K";
            m_blobV = new Blob<T>(cuda, log);
            m_blobV.Name = m_param.name + " V";
            m_blobQt = new Blob<T>(cuda, log);
            m_blobQt.Name = m_param.name + " Qt";
            m_blobKt = new Blob<T>(cuda, log);
            m_blobKt.Name = m_param.name + " Kt";
            m_blobKt1 = new Blob<T>(cuda, log);
            m_blobKt1.Name = m_param.name + " Kt1";
            m_blobVt = new Blob<T>(cuda, log);
            m_blobVt.Name = m_param.name + " Vt";
            m_blobAttA = new Blob<T>(cuda, log);
            m_blobAttA.Name = m_param.name + " AttA";
            m_blobAttB = new Blob<T>(cuda, log);
            m_blobAttB.Name = m_param.name + " AttB";
            m_blobWork = new Blob<T>(cuda, log);
            m_blobWork.Name = m_param.name + " Work";
            m_blobY = new Blob<T>(cuda, log);
            m_blobY.Name = m_param.name + " Y";

            setup_internal_blobs(m_colInternalBlobs); // (reconstructed; the extracted listing omitted this line)
        }

        /// <summary>
        /// Releases all GPU and host resources used by the Layer.
        /// </summary>
        protected override void dispose()
        {
            dispose(ref m_c_attnQ);
            dispose(ref m_c_attnK);
            dispose(ref m_c_attnV);
            dispose(ref m_c_proj);
            dispose(ref m_attn_dropout);
            dispose(ref m_resid_dropout);
            dispose(ref m_transpose);
            dispose(ref m_transposeQ);
            dispose(ref m_softmax);

            dispose(ref m_blobX0);
            dispose(ref m_blobX1);
            dispose(ref m_blobX2);
            dispose(ref m_blobQ);
            dispose(ref m_blobK);
            dispose(ref m_blobV);
            dispose(ref m_blobQt);
            dispose(ref m_blobKt);
            dispose(ref m_blobKt1);
            dispose(ref m_blobVt);
            dispose(ref m_blobAttA);
            dispose(ref m_blobAttB);
            dispose(ref m_blobWork);
            dispose(ref m_blobY);

            base.dispose();
        }

        /// <summary>
        /// Derivative layers should add all internal blobs to the 'col' provided.
        /// </summary>
        protected override void setup_internal_blobs(BlobCollection<T> col)
        {
            if (col.Count > 0)
                return;

            col.Add(m_blobX0);
            col.Add(m_blobX1);
            col.Add(m_blobX2);
            col.Add(m_blobQ);
            col.Add(m_blobK);
            col.Add(m_blobV);
            col.Add(m_blobQt);
            col.Add(m_blobKt);
            col.Add(m_blobVt);
            col.Add(m_blobKt1);
            col.Add(m_blobAttA);
            col.Add(m_blobAttB);
            col.Add(m_blobWork);
            col.Add(m_blobY);

            col.Add(m_c_attnQ.internal_blobs);
            col.Add(m_c_attnK.internal_blobs);
            col.Add(m_c_attnV.internal_blobs);
            col.Add(m_transpose.internal_blobs);
            col.Add(m_transposeQ.internal_blobs);
            col.Add(m_softmax.internal_blobs);
            if (m_attn_dropout != null)
                col.Add(m_attn_dropout.internal_blobs);
            col.Add(m_c_proj.internal_blobs);
            if (m_resid_dropout != null)
                col.Add(m_resid_dropout.internal_blobs);
        }

        /// <summary>
        /// Returns the exact number of required bottom (input) Blobs: q, k, v, mask
        /// </summary>
        public override int ExactNumBottomBlobs
        {
            get { return 4; }
        }

        /// <summary>
        /// Returns the exact number of required top (output) Blobs: attn
        /// </summary>
        public override int ExactNumTopBlobs
        {
            get { return 1; }
        }

        /// <summary>
        /// Re-initialize the parameters of the layer.
        /// </summary>
        public override bool ReInitializeParameters(WEIGHT_TARGET target)
        {
            base.ReInitializeParameters(target);

            m_c_attnQ.ReInitializeParameters(target);
            m_c_attnK.ReInitializeParameters(target);
            m_c_attnV.ReInitializeParameters(target);
            m_c_proj.ReInitializeParameters(target);

            return true;
        }

        private void addInternal(Blob<T> bottom, Blob<T> top)
        {
            m_colInternalBottom.Clear();
            m_colInternalBottom.Add(bottom);

            m_colInternalTop.Clear();
            m_colInternalTop.Add(top);
        }

        private void addInternal(List<Blob<T>> rgBottom, Blob<T> top)
        {
            m_colInternalBottom.Clear();

            for (int i = 0; i < rgBottom.Count; i++)
            {
                m_colInternalBottom.Add(rgBottom[i]);
            }

            m_colInternalTop.Clear();
            m_colInternalTop.Add(top);
        }

        /// <summary>
        /// Setup the layer.
        /// </summary>
        public override void LayerSetUp(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            shareLayerBlob(m_blobX0, colBottom[0].shape());
            m_blobX0.ReshapeLike(colBottom[0]);
            shareLayerBlob(m_blobX1, colBottom[0].shape());
            m_blobX1.ReshapeLike(colBottom[1]);
            shareLayerBlob(m_blobX2, colBottom[0].shape());
            m_blobX2.ReshapeLike(colBottom[2]);

            m_nB = m_blobX0.num;      // batch size
            m_nT = m_blobX0.channels; // sequence length
            m_nC = m_blobX0.height;   // embedding dim (m_nEmbed)
            m_nSize = m_nC / m_nHeads;

            addInternal(m_blobX0, m_blobQ);
            m_c_attnQ.Setup(m_colInternalBottom, m_colInternalTop);
            addInternal(m_blobX1, m_blobK);
            m_c_attnK.Setup(m_colInternalBottom, m_colInternalTop);
            addInternal(m_blobX2, m_blobV);
            m_c_attnV.Setup(m_colInternalBottom, m_colInternalTop);

            blobs.Add(m_c_attnQ.blobs[0]);
            blobs.Add(m_c_attnQ.blobs[1]);
            blobs.Add(m_c_attnK.blobs[0]);
            blobs.Add(m_c_attnK.blobs[1]);
            blobs.Add(m_c_attnV.blobs[0]);
            blobs.Add(m_c_attnV.blobs[1]);

            m_rgShape[0] = m_nB;
            m_rgShape[1] = m_nHeads;
            m_rgShape[2] = m_nT;
            m_rgShape[3] = m_nSize;

            shareLayerBlob(m_blobQ, m_rgShape);
            m_blobQ.Reshape(m_rgShape);
            addInternal(m_blobQ, m_blobQt);
            m_transpose.Setup(m_colInternalBottom, m_colInternalTop); // (B, nh, T, hs)

            shareLayerBlob(m_blobAttA, m_blobX0.shape());
            m_blobAttA.Reshape(m_nB, m_nHeads, m_nBlockSize, m_nBlockSize);
            shareLayerBlob(m_blobAttB, m_blobX0.shape());
            m_blobAttB.Reshape(m_nB, m_nHeads, m_nBlockSize, m_nBlockSize);

            addInternal(m_blobAttA, m_blobAttB);
            m_softmax.Setup(m_colInternalBottom, m_colInternalTop);

            if (m_attn_dropout != null)
            {
                addInternal(m_blobAttB, m_blobAttB);
                m_attn_dropout.Setup(m_colInternalBottom, m_colInternalTop);
            }

            m_rgShape[0] = m_nB;
            m_rgShape[1] = m_nT;
            m_rgShape[2] = m_nC;
            m_rgShape[3] = 1;

            shareLayerBlob(m_blobY, m_rgShape);
            m_blobY.Reshape(m_rgShape);

            addInternal(m_blobY, colTop[0]);
            m_c_proj.Setup(m_colInternalBottom, m_colInternalTop);

            blobs.Add(m_c_proj.blobs[0]);
            blobs.Add(m_c_proj.blobs[1]);

            if (m_resid_dropout != null)
            {
                addInternal(colTop[0], colTop[0]);
                m_resid_dropout.Setup(m_colInternalBottom, m_colInternalTop);
            }

            foreach (Blob<T> blob in blobs)
            {
                if (!blob.Name.StartsWith(m_param.name + "_"))
                    blob.Name = m_param.name + "_" + blob.Name;
            }
        }

        /// <summary>
        /// Reshape the bottom (input) and top (output) blobs.
        /// </summary>
        public override void Reshape(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            m_blobX0.ReshapeLike(colBottom[0]);
            m_blobX1.ReshapeLike(colBottom[1]);
            m_blobX2.ReshapeLike(colBottom[2]);

            m_nB = m_blobX0.num;      // batch size
            m_nT = m_blobX0.channels; // sequence length
            m_nC = m_blobX0.height;   // embedding dim (m_nEmbed)
            m_nSize = m_nC / m_nHeads;

            m_rgShape[0] = m_nB;
            m_rgShape[1] = m_nT;
            m_rgShape[2] = m_nHeads;
            m_rgShape[3] = m_nSize;

            shareLayerBlob(m_blobK, m_rgShape);
            m_blobK.Reshape(m_rgShape);
            shareLayerBlob(m_blobKt1, m_rgShape);
            m_blobKt1.ReshapeLike(m_blobK);
            shareLayerBlob(m_blobKt, m_rgShape);

            addInternal(m_blobK, m_blobKt);
            m_transpose.Reshape(m_colInternalBottom, m_colInternalTop); // (B, nh, T, hs)
            m_blobKt1.ReshapeLike(m_blobKt);

            shareLayerBlob(m_blobQ, m_rgShape);
            m_blobQ.Reshape(m_rgShape);
            shareLayerBlob(m_blobQt, m_rgShape);

            addInternal(m_blobQ, m_blobQt);
            m_transpose.Reshape(m_colInternalBottom, m_colInternalTop); // (B, nh, T, hs)

            shareLayerBlob(m_blobV, m_rgShape);
            m_blobV.Reshape(m_rgShape);
            shareLayerBlob(m_blobVt, m_rgShape);

            addInternal(m_blobV, m_blobVt);
            m_transpose.Reshape(m_colInternalBottom, m_colInternalTop); // (B, nh, T, hs)

            m_rgShape[0] = m_nB;
            m_rgShape[1] = m_nHeads;
            m_rgShape[2] = m_nT;
            m_rgShape[3] = m_nT;

            shareLayerBlob(m_blobAttA, m_rgShape);
            m_blobAttA.Reshape(m_rgShape);
            shareLayerBlob(m_blobAttB, m_rgShape);
            m_blobAttB.Reshape(m_rgShape);

            m_rgShape[0] = m_blobVt.num;
            m_rgShape[1] = m_blobVt.channels;
            m_rgShape[2] = m_blobVt.width;  // col major
            m_rgShape[3] = m_blobVt.height;

            shareLayerBlob(m_blobWork, m_rgShape);
            m_blobWork.Reshape(m_rgShape); // col major
            addInternal(m_blobWork, m_blobY);
            m_transposeQ.Reshape(m_colInternalBottom, m_colInternalTop);

            m_rgShape[0] = m_nB;
            m_rgShape[1] = m_nT;
            m_rgShape[2] = m_nC;
            m_rgShape[3] = 1;

            shareLayerBlob(m_blobY, m_rgShape);
            m_blobY.Reshape(m_rgShape);
            addInternal(m_blobY, colTop[0]);
            m_c_proj.Reshape(m_colInternalBottom, m_colInternalTop);

            if (m_resid_dropout != null)
            {
                addInternal(colTop[0], colTop[0]);
                m_resid_dropout.Reshape(m_colInternalBottom, m_colInternalTop);
            }
        }

        /// <summary>
        /// The forward computation.
        /// </summary>
        protected override void forward(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            Blob<T> blobMask = colBottom[3];

            m_blobX0.CopyFrom(colBottom[0]);
            m_blobX1.CopyFrom(colBottom[1]);
            m_blobX2.CopyFrom(colBottom[2]);

            // Calculate query, for all heads in batch and move head forward to be the batch dim.
            // q = self.c_attnQ(x1)
            addInternal(m_blobX0, m_blobQ);
            m_c_attnQ.Forward(m_colInternalBottom, m_colInternalTop);

            // Calculate key, for all heads in batch and move head forward to be the batch dim.
            // k = self.c_attnK(x2)
            addInternal(m_blobX1, m_blobK);
            m_c_attnK.Forward(m_colInternalBottom, m_colInternalTop);

            // Calculate value, for all heads in batch and move head forward to be the batch dim.
            // v = self.c_attnV(x3)
            addInternal(m_blobX2, m_blobV);
            m_c_attnV.Forward(m_colInternalBottom, m_colInternalTop);

            // Transpose query, key and values along axes 1 & 2
            // k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            // q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            // v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            m_blobQ.Reshape(m_nB, m_nT, m_nHeads, m_nSize);
            m_blobK.Reshape(m_nB, m_nT, m_nHeads, m_nSize);
            m_blobV.Reshape(m_nB, m_nT, m_nHeads, m_nSize);

            addInternal(m_blobQ, m_blobQt);
            m_transpose.Forward(m_colInternalBottom, m_colInternalTop); // (B, nh, T, hs)
            addInternal(m_blobK, m_blobKt);
            m_transpose.Forward(m_colInternalBottom, m_colInternalTop); // (B, nh, T, hs)
            addInternal(m_blobV, m_blobVt);
            m_transpose.Forward(m_colInternalBottom, m_colInternalTop); // (B, nh, T, hs)

            // Perform Self Attention forward pass
            {
                // Multiply query and key(T) matrices and scale
                // att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
                addInternal(m_blobKt, m_blobKt1);
                m_transposeQ.Forward(m_colInternalBottom, m_colInternalTop);

                double dfScale = 1.0 / Math.Sqrt(m_nSize);
                m_blobAttA.MatMul(m_blobQt, m_blobKt1);
                m_blobAttA.scale_data(dfScale);

                // Apply mask to attention matrix
                // att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
                float fInf = 1e+29f;
                m_cuda.mask_batch(m_blobAttA.count(), m_blobAttA.num, blobMask.count(), convert(0.0), convert(-1 * fInf), m_blobAttA.gpu_data, blobMask.gpu_data, m_blobAttA.mutable_gpu_data); // all masked items set to -inf.

                // Take softmax of attention along the last axis.
                // att = F.softmax(att, dim = -1)
                addInternal(m_blobAttA, m_blobAttB);
                m_softmax.Forward(m_colInternalBottom, m_colInternalTop);

                // Apply attention dropout.
                // att = self.attn_dropout(att)
                if (m_attn_dropout != null)
                {
                    addInternal(m_blobAttB, m_blobAttB);
                    m_attn_dropout.Forward(m_colInternalBottom, m_colInternalTop);
                }

                m_blobWork.Reshape(m_blobVt.num, m_blobVt.channels, m_blobVt.height, m_blobVt.width);

                // Multiply attention matrix with values
                // y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
                m_blobWork.MatMul(m_blobAttB, m_blobVt);
            }

            // Reassemble all head outputs side by side.
            // y = y.transpose(1, 2).contiguous().view(B, T, C)
            addInternal(m_blobWork, m_blobY);
            m_transpose.Forward(m_colInternalBottom, m_colInternalTop);
            m_blobY.Reshape(m_nB, m_nT, m_nC, 1);

            // Apply output projection.
            // y = self.resid_dropout(self.c_proj(y))
            addInternal(m_blobY, colTop[0]);
            m_c_proj.Forward(m_colInternalBottom, m_colInternalTop);

            // Apply resid dropout
            if (m_resid_dropout != null)
            {
                addInternal(colTop[0], colTop[0]);
                m_resid_dropout.Forward(m_colInternalBottom, m_colInternalTop);
            }
        }

        /// <summary>
        /// Computes the loss error gradient w.r.t the outputs.
        /// </summary>
        protected override void backward(BlobCollection<T> colTop, List<bool> rgbPropagateDown, BlobCollection<T> colBottom)
        {
            // Gradient with respect to state then data.
            if (rgbPropagateDown[0])
            {
                List<bool> rgbPropagate = new List<bool>() { true, true };

                // Apply resid dropout
                if (m_resid_dropout != null)
                {
                    addInternal(colTop[0], colTop[0]);
                    m_resid_dropout.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);
                }

                // Apply output projection.
                // y = self.c_proj(y)
                addInternal(m_blobY, colTop[0]);
                m_c_proj.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);

                // Reassemble all head outputs side by side.
                // y = y.transpose(1, 2).contiguous().view(B, T, C)
                addInternal(m_blobWork, m_blobY);
                m_transpose.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);

                // Perform Self Attention backward pass
                {
                    // Copy the transposed gradient (now in Work's diff) into Y, reshaping it to (B, nh, T, hs),
                    // so it can act as the upstream gradient of y = att @ v.
                    m_blobY.CopyFrom(m_blobWork, true, true);

                    // Multiply attention matrix with values
                    // y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
                    // Gradient with respect to att
                    // att' = y' @ v^T
                    // Gradient with respect to vt
                    // vt' = att^T @ y'
                    m_blobY.MatMulGrad(m_blobAttB, m_blobVt, m_blobWork);

                    // Apply attention dropout.
                    // att = self.attn_dropout(att)
                    if (m_attn_dropout != null)
                    {
                        addInternal(m_blobAttB, m_blobAttB);
                        m_attn_dropout.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);
                    }

                    // Take softmax of attention along the last axis.
                    // att = F.softmax(att, dim = -1)
                    addInternal(m_blobAttA, m_blobAttB);
                    m_softmax.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);

                    // Multiply qt with kt^T to create attention matrix
                    // att = qt @ kt^T
                    // Gradient with respect to qt
                    // qt' = att' @ kt
                    // Gradient with respect to kt1
                    // kt1' = qt^T @ att'
                    double dfScale = 1.0 / Math.Sqrt(m_nSize);
                    m_blobAttA.MatMulGrad(m_blobQt, m_blobKt1, m_blobWork, dfScale);

                    // Transpose Kt1 back to Kt
                    addInternal(m_blobKt, m_blobKt1);
                    m_transposeQ.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);
                }

                // Transpose query, key and values along axes 1 & 2
                // k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
                // q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
                // v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
                addInternal(m_blobQ, m_blobQt);
                m_transpose.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom); // (B, nh, T, hs)
                addInternal(m_blobK, m_blobKt);
                m_transpose.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom); // (B, nh, T, hs)
                addInternal(m_blobV, m_blobVt);
                m_transpose.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom); // (B, nh, T, hs)

                // Propagate the query gradient back through the query projection.
                // q = self.c_attnQ(x1)
                addInternal(m_blobX0, m_blobQ);
                m_c_attnQ.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);

                // Propagate the key gradient back through the key projection.
                // k = self.c_attnK(x2)
                addInternal(m_blobX1, m_blobK);
                m_c_attnK.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);

                // Propagate the value gradient back through the value projection.
                // v = self.c_attnV(x3)
                addInternal(m_blobX2, m_blobV);
                m_c_attnV.Backward(m_colInternalTop, rgbPropagate, m_colInternalBottom);

                if (colBottom[0].gpu_diff == colBottom[1].gpu_diff && colBottom[0].gpu_diff == colBottom[2].gpu_diff)
                {
                    m_cuda.add(m_blobX0.count(), m_blobX0.gpu_diff, m_blobX1.gpu_diff, m_blobX2.gpu_diff, colBottom[0].mutable_gpu_diff);
                }
                else if (colBottom[1].gpu_diff == colBottom[2].gpu_diff)
                {
                    colBottom[0].CopyFrom(m_blobX0, true);
                    m_cuda.add(m_blobX1.count(), m_blobX1.gpu_diff, m_blobX2.gpu_diff, colBottom[1].mutable_gpu_diff);
                }
                else
                {
                    colBottom[0].CopyFrom(m_blobX0, true);
                    colBottom[1].CopyFrom(m_blobX1, true);
                    colBottom[2].CopyFrom(m_blobX2, true);
                }
            }
        }
    }
}
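
The forward() pass above expresses scaled dot-product attention entirely through Blob operations (MatMul, scale_data, mask_batch and the internal softmax layer). As a conceptual aid only, the following is a minimal single-head CPU sketch of the same per-head computation, att = softmax(mask(q·kᵀ/√hs))·v; it is not part of MyCaffe, and the array layout, the -1e29 substitute for -inf, and all names are illustrative.

    // Conceptual CPU reference (not part of MyCaffe): single-head scaled dot-product attention
    // over plain arrays, mirroring the Blob pipeline used in forward().
    using System;

    public static class AttentionSketch
    {
        // q, k, v: [T, hs]; mask: [T, T] with non-zero = attend, 0 = masked. Returns y: [T, hs].
        public static double[,] Attend(double[,] q, double[,] k, double[,] v, double[,] mask)
        {
            int nT = q.GetLength(0);
            int nSize = q.GetLength(1);
            double dfScale = 1.0 / Math.Sqrt(nSize);
            double[,] att = new double[nT, nT];

            // att = (q @ k^T) * 1/sqrt(hs), with masked positions pushed toward -inf
            // (the layer uses -1e29 via mask_batch for the same effect).
            for (int i = 0; i < nT; i++)
            {
                for (int j = 0; j < nT; j++)
                {
                    double dfDot = 0;
                    for (int m = 0; m < nSize; m++)
                        dfDot += q[i, m] * k[j, m];
                    att[i, j] = (mask[i, j] == 0) ? -1e29 : dfDot * dfScale;
                }
            }

            // att = softmax(att, dim = -1)
            for (int i = 0; i < nT; i++)
            {
                double dfMax = double.MinValue;
                for (int j = 0; j < nT; j++)
                    dfMax = Math.Max(dfMax, att[i, j]);

                double dfSum = 0;
                for (int j = 0; j < nT; j++)
                {
                    att[i, j] = Math.Exp(att[i, j] - dfMax);
                    dfSum += att[i, j];
                }

                for (int j = 0; j < nT; j++)
                    att[i, j] /= dfSum;
            }

            // y = att @ v
            double[,] y = new double[nT, nSize];
            for (int i = 0; i < nT; i++)
                for (int m = 0; m < nSize; m++)
                    for (int j = 0; j < nT; j++)
                        y[i, m] += att[i, j] * v[j, m];

            return y;
        }
    }

The layer performs this computation for all heads and batch items at once by reshaping Q, K and V to (B, nh, T, hs) and using batched MatMul calls, then recombines the heads with a transpose and the c_proj projection.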
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Definition: Log.cs:239
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
int Count
Returns the number of items in the collection.
void Clear(bool bDispose=false)
Remove all items from the collection.
void CopyFrom(BlobCollection< T > bSrc, bool bCopyDiff=false)
Copy the data or diff from another BlobCollection into this one.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
int channels
DEPRECATED; legacy shape accessor channels: use shape(1) instead.
Definition: Blob.cs:800
void MatMul(Blob< T > blobA, Blob< T > blobB, bool bReshape=false, bool bTransA=false, bool bTransB=false, double dfScale=1.0, bool bADiff=false, bool bBDiff=false, bool bCDiff=false)
MatMul blobA with blobB and place the result in this blob (e.g. this = matmul(A, B))....
Definition: Blob.cs:3922
int height
DEPRECATED; legacy shape accessor height: use shape(2) instead.
Definition: Blob.cs:808
void MatMulGrad(Blob< T > blobA, Blob< T > blobB, Blob< T > blobWork, double dfScale=1.0)
Calculates and propagates the gradient for blobA and BlobB given the input gradient in this blob's di...
Definition: Blob.cs:3974
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1487
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use
Definition: Blob.cs:442
void CopyFrom(Blob< T > src, int nSrcOffset, int nDstOffset, int nCount, bool bCopyData, bool bCopyDiff)
Copy from a source Blob.
Definition: Blob.cs:903
void scale_data(double df)
Scale the data by a scaling factor.
Definition: Blob.cs:1754
int width
DEPRECATED; legacy shape accessor width: use shape(3) instead.
Definition: Blob.cs:816
List< int > shape()
Returns an array where each element contains the shape of an axis of the Blob.
Definition: Blob.cs:684
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
Definition: Blob.cs:648
string Name
Get/set the name of the Blob.
Definition: Blob.cs:2184
long gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1541
int num
DEPRECATED; legacy shape accessor num: use shape(0) instead.
Definition: Blob.cs:792
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1479
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:969
An interface for the units of computation which can be composed into a Net.
Definition: Layer.cs:31
LayerParameter m_param
Specifies the LayerParameter describing the Layer.
Definition: Layer.cs:47
void convert(BlobCollection< T > col)
Convert a collection of blobs from / to half size.
Definition: Layer.cs:535
bool shareLayerBlob(Blob< T > b, List< int > rgMinShape)
Attempts to share a Layer Blob if another parameter Blob with the same name and acceptable size is fo...
Definition: Layer.cs:1170
void Backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Given the top Blob error gradients, compute the bottom Blob error gradients.
Definition: Layer.cs:815
virtual bool ReInitializeParameters(WEIGHT_TARGET target)
Re-initialize the parameters of the layer.
Definition: Layer.cs:389
double Forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Given the bottom (input) Blobs, this function computes the top (output) Blobs and the loss.
Definition: Layer.cs:728
abstract void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Adjust the shapes of top blobs and internal buffers to accommodate the shapes of the bottom blobs.
BlobCollection< T > m_colInternalBlobs
Specifies internal blobs used by the layer.
Definition: Layer.cs:59
CudaDnn< T > m_cuda
Specifies the CudaDnn connection to Cuda.
Definition: Layer.cs:39
void Setup(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Implements common Layer setup functionality.
Definition: Layer.cs:439
static Layer< T > Create(CudaDnn< T > cuda, Log log, LayerParameter p, CancelEvent evtCancel, IXDatabaseBase db=null, TransferInput trxinput=null)
Create a new Layer based on the LayerParameter.
Definition: Layer.cs:1468
LayerParameter.LayerType m_type
Specifies the Layer type.
Definition: Layer.cs:35
BlobCollection< T > blobs
Returns the collection of learnable parameter Blobs for the Layer.
Definition: Layer.cs:875
BlobCollection< T > internal_blobs
Returns the collection of internal Blobs used by the Layer.
Definition: Layer.cs:883
LayerParameter convertLayerParam(LayerParameter pChild, LayerParameter pParent)
Called to convert a parent LayerParameterEx, used in blob sharing, with a child layer parameter.
Definition: Layer.cs:1134
The MultiheadAttention provides a vanilla multi-head layer.
override void LayerSetUp(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Setup the layer.
override bool ReInitializeParameters(WEIGHT_TARGET target)
Re-initialize the parameters of the layer.
override void setup_internal_blobs(BlobCollection< T > col)
Derivative layers should add all internal blobs to the 'col' provided.
override int ExactNumTopBlobs
Returns the exact number of required top (output) Blobs: attn
override void dispose()
Releases all GPU and host resources used by the Layer.
override void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Reshape the bottom (input) and top (output) blobs.
MultiheadAttentionLayer(CudaDnn< T > cuda, Log log, LayerParameter p)
The MultiheadAttention constructor.
override void forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
The forward computation.
override int ExactNumBottomBlobs
Returns the exact number of required bottom (input) Blobs: q, k, v, mask
override void backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Computes the loss error gradient w.r.t the outputs.
double dropout_ratio
Specifies the dropout ratio. (e.g. the probability that values will be dropped out and set to zero....
Specifies whether to use the NVIDIA cuDnn version or Caffe version of a given forward/backward operat...
Engine engine
Specifies the Engine in use.
Engine
Defines the type of engine to use.
Specifies the filler parameters used to create each Filler.
FillerParameter weight_filler
The filler for the weights.
int axis
Specifies the first axis to be lumped into a single inner product computation; all preceding axes are...
FillerParameter bias_filler
The filler for the bias.
uint num_output
The number of outputs for the layer.
bool bias_term
Whether to have bias terms or not.
Specifies the base parameter for all layers.
List< ParamSpec > parameters
Specifies the ParamSpec parameters of the LayerParameter.
string name
Specifies the name of this LayerParameter.
SoftmaxParameter softmax_param
Returns the parameter set when initialized with LayerType.SOFTMAX
MultiheadAttentionParameter multihead_attention_param
Returns the parameter set when initialized with LayerType.MULTIHEAD_ATTENTION
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
TransposeParameter transpose_param
Returns the parameter set when initialized with LayerType.TRANSPOSE
LayerType
Specifies the layer type.
DropoutParameter dropout_param
Returns the parameter set when initialized with LayerType.DROPOUT
Specifies training parameters (multipliers on global learning constants, and the name of other settin...
Definition: ParamSpec.cs:19
int axis
The axis along which to perform the softmax – may be negative to index from the end (e....
Specifies the parameters for the MultiheadAttentionLayer.
WEIGHT_INIT
Defines the weight initialization strategy.
double attn_dropout
Specifies dropout probability used on the attention weights.
uint layers
The number of layers (transformer blocks) used.
double resid_dropout
Specifies dropout probability used on the residual weights.
WEIGHT_INIT weight_init
Specifies the weight initialization strategy (default = ENCODER_DECODER).
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
WEIGHT_TARGET
Defines the type of weight to target in re-initializations.
Definition: Interfaces.cs:38
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers.gpt namespace contains all GPT related layers.
Definition: LayerFactory.cs:15
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12
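
Putting the pieces together, a rough usage sketch follows. It assumes an existing CudaDnn<double> cuda, a Log log, and already-shaped blobQ, blobK, blobV, blobMask and blobOut blobs (in practice the layer is normally instantiated by a Net from a model description rather than by hand); the parameter values are illustrative.

    // Configure a MULTIHEAD_ATTENTION layer parameter (values are illustrative).
    LayerParameter p = new LayerParameter(LayerParameter.LayerType.MULTIHEAD_ATTENTION, "mh_attn");
    p.multihead_attention_param.heads = 8;
    p.multihead_attention_param.embed = 512;        // must be divisible by 'heads'
    p.multihead_attention_param.block_size = 128;   // maximum sequence length
    p.multihead_attention_param.attn_dropout = 0.1;
    p.multihead_attention_param.resid_dropout = 0.1;

    // Create the layer (assumes 'cuda' and 'log' already exist).
    Layer<double> attn = Layer<double>.Create(cuda, log, p, null);

    // The layer expects exactly four bottom blobs (q, k, v, mask) and one top blob.
    BlobCollection<double> colBottom = new BlobCollection<double>();
    colBottom.Add(blobQ);    // (B, T, C) query input
    colBottom.Add(blobK);    // (B, T, C) key input
    colBottom.Add(blobV);    // (B, T, C) value input
    colBottom.Add(blobMask); // attention mask (0 = masked position)

    BlobCollection<double> colTop = new BlobCollection<double>();
    colTop.Add(blobOut);     // (B, T, C) attention output

    attn.Setup(colBottom, colTop);
    attn.Forward(colBottom, colTop);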