MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
TextDataLayer.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using MyCaffe.basecode;
using MyCaffe.common;
using MyCaffe.param;
using MyCaffe.fillers;
using System.IO;
using MyCaffe.layers.beta.TextData;

namespace MyCaffe.layers.beta
{
    /// <summary>
    /// The TextDataLayer loads data from text data files for an encoder/decoder type model.
    /// </summary>
    public class TextDataLayer<T> : Layer<T>
    {
        DataItem m_currentData = null;
        Data m_data = null;
        Vocabulary m_vocab = null;
        ulong m_lOffset = 0;
        float[] m_rgEncInput1;
        float[] m_rgEncInput2;
        float[] m_rgEncClip;
        float[] m_rgDecInput;
        float[] m_rgDecClip;
        float[] m_rgDecTarget;

        /// <summary>
        /// The OnGetData event fires during each forward pass after the training data for the pass has been retrieved.
        /// </summary>
        public event EventHandler<OnGetDataArgs> OnGetData;

        /// <summary>
        /// The TextDataLayer constructor.
        /// </summary>
        public TextDataLayer(CudaDnn<T> cuda, Log log, LayerParameter p)
            : base(cuda, log, p)
        {
            m_type = LayerParameter.LayerType.TEXT_DATA;
        }

        /// <summary>
        /// Release all internal blobs.
        /// </summary>
        protected override void dispose()
        {
            base.dispose();
        }

        /// <summary>
        /// Returns the minimum number of required bottom (input) Blobs: 3 in the RUN phase; 0 in TRAIN and TEST, where data layers have no bottom Blobs.
        /// </summary>
        public override int MinBottomBlobs
        {
            get { return (m_phase == Phase.RUN) ? 3 : 0; }
        }

        /// <summary>
        /// Returns the maximum number of required bottom (input) Blobs: 4 in the RUN phase; 0 in TRAIN and TEST, where data layers have no bottom Blobs.
        /// </summary>
        public override int MaxBottomBlobs
        {
            get { return (m_phase == Phase.RUN) ? 4 : 0; }
        }

        /// <summary>
        /// Returns the minimum number of required top (output) Blobs: dec, dclip, enc, eclip, vocabcount, dectgt.
        /// </summary>
        public override int MinTopBlobs
        {
            get { return 6; }
        }

        /// <summary>
        /// Returns the maximum number of required top (output) Blobs: dec, dclip, enc, encr, eclip, vocabcount, dectgt.
        /// </summary>
        public override int MaxTopBlobs
        {
            get { return 7; }
        }

        /// <summary>
        /// Returns the vocabulary of the data sources.
        /// </summary>
        public Vocabulary Vocabulary
        {
            get { return m_vocab; }
        }

        /// <summary>
        /// Returns information on the current iteration.
        /// </summary>
        public IterationInfo IterationInfo
        {
            get { return (m_currentData == null) ? new IterationInfo(true, true, 0) : m_currentData.IterationInfo; }
        }

        private static string clean(string str)
        {
            string strOut = "";

            foreach (char ch in str)
            {
                if (ch == 'á')
                    strOut += 'a';
                else if (ch == 'é')
                    strOut += 'e';
                else if (ch == 'í')
                    strOut += 'i';
                else if (ch == 'ó')
                    strOut += 'o';
                else if (ch == 'ú')
                    strOut += 'u';
                else if (ch == 'Á')
                    strOut += 'A';
                else if (ch == 'É')
                    strOut += 'E';
                else if (ch == 'Í')
                    strOut += 'I';
                else if (ch == 'Ó')
                    strOut += 'O';
                else if (ch == 'Ú')
                    strOut += 'U';
                else
                    strOut += ch;
            }

            return strOut;
        }

        /// <summary>
        /// Should return true when pre-processing methods are overridden.
        /// </summary>
        public override bool SupportsPreProcessing
        {
            get { return true; }
        }

        /// <summary>
        /// Should return true when post-processing methods are overridden.
        /// </summary>
        public override bool SupportsPostProcessing
        {
            get { return true; }
        }

        private static List<string> preprocess(string str, int nMaxLen = 0)
        {
            string strInput = clean(str);
            List<string> rgstr = strInput.ToLower().Trim().Split(' ').ToList();

            if (nMaxLen > 0)
            {
                rgstr = rgstr.Take(nMaxLen).ToList();
                if (rgstr.Count < nMaxLen)
                    return null;
            }

            return rgstr;
        }
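
        // Example: for the input "Él está aquí", clean() strips the acute accents and
        // preprocess() lower-cases and splits the line, yielding the word list
        // ["el", "esta", "aqui"]; with nMaxLen = 5 the same call returns null because
        // the line contains fewer than 5 words.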

        private string getPath(string strPath)
        {
            string strTarget = "$ProgramData$";

            if (!strPath.StartsWith(strTarget))
                return strPath;

            string strProgData = Environment.GetFolderPath(Environment.SpecialFolder.CommonApplicationData);
            strProgData = strProgData.TrimEnd('\\');

            strPath = strProgData + strPath.Substring(strTarget.Length);

            return strPath;
        }
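
        // Example: on a typical Windows install, a path such as "$ProgramData$\MyCaffe\..."
        // expands to "C:\ProgramData\MyCaffe\..." because
        // Environment.SpecialFolder.CommonApplicationData usually resolves to C:\ProgramData.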

        /// <summary>
        /// Load the encoder and decoder input files and convert each into a list of lines, each containing a list of words.
        /// </summary>
        public void PreProcessInputFiles(TextDataParameter p)
        {
            List<List<string>> rgrgstrInput = new List<List<string>>();
            List<List<string>> rgrgstrTarget = new List<List<string>>();

            string strEncoderSrc = getPath(p.encoder_source);
            string strDecoderSrc = getPath(p.decoder_source);

            string[] rgstrInput = File.ReadAllLines(strEncoderSrc);
            string[] rgstrTarget = File.ReadAllLines(strDecoderSrc);

            if (rgstrInput.Length != rgstrTarget.Length)
                throw new Exception("Both the input and target files must contain the same number of lines!");

            for (int i = 0; i < p.sample_size; i++)
            {
                List<string> rgstrInput1 = preprocess(rgstrInput[i]);
                List<string> rgstrTarget1 = preprocess(rgstrTarget[i]);

                if (rgstrInput1 != null && rgstrTarget1 != null)
                {
                    rgrgstrInput.Add(rgstrInput1);
                    rgrgstrTarget.Add(rgstrTarget1);
                }
            }

            m_vocab = new Vocabulary();
            m_vocab.Load(rgrgstrInput, rgrgstrTarget);
            m_data = new Data(rgrgstrInput, rgrgstrTarget, m_vocab);
        }

        /// <summary>
        /// The PreProcessInput allows derivative data layers to convert a property set of input data into the bottom blob collection used as input for the RUN phase.
        /// </summary>
        public override BlobCollection<T> PreProcessInput(PropertySet customInput, out int nSeqLen, BlobCollection<T> colBottom = null)
        {
            nSeqLen = -1;

            if (colBottom == null)
            {
                string strInput = m_param.PrepareRunModelInputs();
                RawProto proto = RawProto.Parse(strInput);
                Dictionary<string, BlobShape> rgInput = NetParameter.InputFromProto(proto);
                colBottom = new BlobCollection<T>();

                foreach (KeyValuePair<string, BlobShape> kv in rgInput)
                {
                    Blob<T> blob = new Blob<T>(m_cuda, m_log);
                    blob.Name = kv.Key;
                    blob.Reshape(kv.Value);
                    colBottom.Add(blob);
                }
            }

            string strEncInput = customInput.GetProperty("InputData");
            if (strEncInput == null)
                throw new Exception("Could not find the expected input property 'InputData'!");

            PreProcessInput(strEncInput, null, colBottom);

            return colBottom;
        }

        /// <summary>
        /// Preprocess the input data for the RUN phase.
        /// </summary>
        public override bool PreProcessInput(string strEncInput, int? nDecInput, BlobCollection<T> colBottom)
        {
            if (nDecInput.HasValue && nDecInput.Value == (int)SPECIAL_TOKENS.EOS)
                return false;

            List<string> rgstrInput = null;
            if (strEncInput != null)
                rgstrInput = preprocess(strEncInput);

            DataItem data = Data.GetInputData(m_vocab, rgstrInput, nDecInput);

            if (m_param.text_data_param.enable_normal_encoder_output && m_param.text_data_param.enable_reverse_encoder_output)
                m_log.CHECK_EQ(colBottom.Count, 4, "The bottom collection must have 4 items: dec_input, enc_input, enc_inputr, enc_clip");
            else
                m_log.CHECK_EQ(colBottom.Count, 3, "The bottom collection must have 3 items: dec_input, enc_input | enc_inputr, enc_clip");

            int nT = (int)m_param.text_data_param.time_steps;
            int nBtmIdx = 0;

            colBottom[nBtmIdx].Reshape(new List<int>() { 1, 1, 1 });
            nBtmIdx++;

            if (m_param.text_data_param.enable_normal_encoder_output)
            {
                colBottom[nBtmIdx].Reshape(new List<int>() { nT, 1, 1 });
                nBtmIdx++;
            }

            if (m_param.text_data_param.enable_reverse_encoder_output)
            {
                colBottom[nBtmIdx].Reshape(new List<int>() { nT, 1, 1 });
                nBtmIdx++;
            }

            colBottom[nBtmIdx].Reshape(new List<int>() { nT, 1 });

            float[] rgEncInput = null;
            float[] rgEncInputR = null;
            float[] rgEncClip = null;
            float[] rgDecInput = new float[1];

            if (data.EncoderInput != null)
            {
                rgEncInput = new float[nT];
                rgEncInputR = new float[nT];
                rgEncClip = new float[nT];

                for (int i = 0; i < nT && i < data.EncoderInput.Count; i++)
                {
                    rgEncInput[i] = data.EncoderInput[i];
                    rgEncInputR[i] = data.EncoderInputReverse[i];
                    rgEncClip[i] = (i == 0) ? 0 : 1;
                }
            }

            rgDecInput[0] = data.DecoderInput;

            nBtmIdx = 0;
            colBottom[nBtmIdx].mutable_cpu_data = convert(rgDecInput);
            nBtmIdx++;

            if (m_param.text_data_param.enable_normal_encoder_output)
            {
                if (rgEncInput != null)
                    colBottom[nBtmIdx].mutable_cpu_data = convert(rgEncInput);
                nBtmIdx++;
            }

            if (m_param.text_data_param.enable_reverse_encoder_output)
            {
                if (rgEncInputR != null)
                    colBottom[nBtmIdx].mutable_cpu_data = convert(rgEncInputR);
                nBtmIdx++;
            }

            if (rgEncClip != null)
                colBottom[nBtmIdx].mutable_cpu_data = convert(rgEncClip);

            return true;
        }
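
        // Note: in the RUN phase the bottom blobs are expected in the order
        // dec_input, enc_input (when normal encoder output is enabled),
        // enc_inputr (when reverse encoder output is enabled), enc_clip.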

        public override List<Tuple<string, int, double>> PostProcessOutput(Blob<T> blobSoftmax, int nK = 1)
        {
            m_log.CHECK_EQ(blobSoftmax.channels, 1, "Currently, only batch size = 1 supported.");

            List<Tuple<string, int, double>> rgRes = new List<Tuple<string, int, double>>();

            long lPos;
            double dfProb = blobSoftmax.GetMaxData(out lPos);

            rgRes.Add(new Tuple<string, int, double>(m_vocab.IndexToWord((int)lPos), (int)lPos, dfProb));

            if (nK > 1)
            {
                m_cuda.copy(blobSoftmax.count(), blobSoftmax.gpu_data, blobSoftmax.mutable_gpu_diff);

                for (int i = 1; i < nK; i++)
                {
                    blobSoftmax.SetData(-1000000000, (int)lPos);
                    dfProb = blobSoftmax.GetMaxData(out lPos);

                    string strWord = m_vocab.IndexToWord((int)lPos);
                    if (strWord.Length > 0)
                        rgRes.Add(new Tuple<string, int, double>(strWord, (int)lPos, dfProb));
                }

                m_cuda.copy(blobSoftmax.count(), blobSoftmax.gpu_diff, blobSoftmax.mutable_gpu_data);
                blobSoftmax.SetDiff(0);
            }

            return rgRes;
        }

        public override string PostProcessOutput(int nIdx)
        {
            return m_vocab.IndexToWord(nIdx);
        }

        /// <summary>
        /// Setup the layer.
        /// </summary>
        public override void LayerSetUp(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            // Refuse transformation parameters since TextData is totally generic.
            if (m_param.transform_param != null)
                m_log.WriteLine("WARNING: " + m_type.ToString() + " does not transform data.");

            m_log.CHECK_EQ(m_param.text_data_param.batch_size, 1, "Currently, only batch_size = 1 supported.");

            if (m_param.text_data_param.enable_normal_encoder_output && m_param.text_data_param.enable_reverse_encoder_output)
                m_log.CHECK_EQ(colTop.Count, 7, "When normal and reverse encoder output used, there must be 7 tops: dec, dclip, enc, encr, eclip, vocabcount, dectgt (only valid on TEST | TRAIN)");
            else if (m_param.text_data_param.enable_normal_encoder_output || m_param.text_data_param.enable_reverse_encoder_output)
                m_log.CHECK_EQ(colTop.Count, 6, "When normal or reverse encoder output used, there must be 6 tops: dec, dclip, enc | encr, eclip, vocabcount, dectgt (only valid on TEST | TRAIN)");
            else
                m_log.FAIL("You must specify to enable either normal, reverse or both encoder inputs.");

            // Load the encoder and decoder input files into the Data and Vocabulary.
            PreProcessInputFiles(m_param.text_data_param);

            m_rgDecInput = new float[m_param.text_data_param.batch_size];
            m_rgDecClip = new float[m_param.text_data_param.batch_size];
            m_rgEncInput1 = new float[m_param.text_data_param.batch_size * m_param.text_data_param.time_steps];
            m_rgEncInput2 = new float[m_param.text_data_param.batch_size * m_param.text_data_param.time_steps];
            m_rgEncClip = new float[m_param.text_data_param.batch_size * m_param.text_data_param.time_steps];

            if (m_phase != Phase.RUN)
                m_rgDecTarget = new float[m_param.text_data_param.batch_size];

            reshape(colTop, true);
        }

        protected bool Skip()
        {
            ulong nSize = (ulong)m_param.solver_count;
            ulong nRank = (ulong)m_param.solver_rank;
            // In test mode, only rank 0 runs, so avoid skipping.
            bool bKeep = (m_lOffset % nSize) == nRank || m_param.phase == Phase.TEST;

            return !bKeep;
        }

        protected void Next()
        {
            m_currentData = m_data.GetNextData(m_param.text_data_param.shuffle);
        }

        public override void Reshape(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            reshape(colTop, false);
        }

        private void reshape(BlobCollection<T> colTop, bool bSetup)
        {
            int nBatchSize = (int)m_param.text_data_param.batch_size;
            int nT = (int)m_param.text_data_param.time_steps;
            List<int> rgTopShape = new List<int>() { nT, nBatchSize, 1 };
            int nTopIdx = 0;

            // Reshape the decoder input.
            if (!bSetup)
                colTop[nTopIdx].Reshape(new List<int>() { 1, nBatchSize, 1 });
            nTopIdx++;

            // Reshape the decoder clip.
            if (!bSetup)
                colTop[nTopIdx].Reshape(new List<int>() { 1, nBatchSize });
            nTopIdx++;

            // Reshape the encoder data | data reverse.
            if (m_param.text_data_param.enable_normal_encoder_output)
            {
                if (!bSetup)
                    colTop[nTopIdx].Reshape(rgTopShape);
                nTopIdx++;
            }

            // Reshape the encoder data reverse.
            if (m_param.text_data_param.enable_reverse_encoder_output)
            {
                if (!bSetup)
                    colTop[nTopIdx].Reshape(rgTopShape);
                nTopIdx++;
            }

            // Reshape the encoder clip for attention.
            if (!bSetup)
                colTop[nTopIdx].Reshape(new List<int>() { nT, nBatchSize });
            nTopIdx++;

            // Reshape the vocab count.
            colTop[nTopIdx].Reshape(new List<int>() { 1 });
            if (bSetup)
                colTop[nTopIdx].SetData(m_vocab.VocabularCount + 2, 0);
            nTopIdx++;

            // Reshape the decoder target.
            if (!bSetup)
                colTop[nTopIdx].Reshape(new List<int>() { 1, nBatchSize, 1 });
        }

        /// <summary>
        /// Run the forward computation, which fills the data into the top (output) Blobs.
        /// </summary>
        protected override void forward(BlobCollection<T> colBottom, BlobCollection<T> colTop)
        {
            int nBatch = (int)m_param.text_data_param.batch_size;
            int nT = (int)m_param.text_data_param.time_steps;

            Array.Clear(m_rgDecInput, 0, m_rgDecInput.Length);
            if (m_phase != Phase.RUN)
                Array.Clear(m_rgDecTarget, 0, m_rgDecTarget.Length);
            Array.Clear(m_rgDecClip, 0, m_rgDecClip.Length);
            Array.Clear(m_rgEncInput1, 0, m_rgEncInput1.Length);
            Array.Clear(m_rgEncInput2, 0, m_rgEncInput2.Length);
            Array.Clear(m_rgEncClip, 0, m_rgEncClip.Length);

            int nTopIdx = 0;

            if (m_phase != Phase.RUN)
            {
                for (int i = 0; i < nBatch; i++)
                {
                    while (Skip())
                        Next();

                    Next();

                    if (OnGetData != null)
                        OnGetData(this, new OnGetDataArgs(m_vocab, m_currentData.IterationInfo));

                    int nIdx = i * nT;

                    for (int j = 0; j < nT && j < m_currentData.EncoderInput.Count; j++)
                    {
                        m_rgEncInput1[nIdx + j] = m_currentData.EncoderInput[j];
                        m_rgEncInput2[nIdx + j] = m_currentData.EncoderInputReverse[j];
                        m_rgEncClip[nIdx + j] = (j == 0) ? 0 : 1;
                    }

                    m_rgDecClip[i] = m_currentData.DecoderClip;
                    m_rgDecInput[i] = m_currentData.DecoderInput;
                    m_rgDecTarget[i] = m_currentData.DecoderTarget;
                }

                colTop[nTopIdx].mutable_cpu_data = convert(m_rgDecInput);
                nTopIdx++;

                colTop[nTopIdx].mutable_cpu_data = convert(m_rgDecClip);
                nTopIdx++;

                if (m_param.text_data_param.enable_normal_encoder_output)
                {
                    colTop[nTopIdx].mutable_cpu_data = convert(m_rgEncInput1);
                    nTopIdx++;
                }

                if (m_param.text_data_param.enable_reverse_encoder_output)
                {
                    colTop[nTopIdx].mutable_cpu_data = convert(m_rgEncInput2);
                    nTopIdx++;
                }

                colTop[nTopIdx].mutable_cpu_data = convert(m_rgEncClip);
                nTopIdx++;

                nTopIdx++; // vocab count.

                colTop[nTopIdx].mutable_cpu_data = convert(m_rgDecTarget);
                nTopIdx++;
            }
            else
            {
                int nBtmIdx = 0;
                float fDecInput = convertF(colBottom[nBtmIdx].GetData(0));
                if (fDecInput < 0)
                    fDecInput = 1;

                nBtmIdx++;

                // Decoder input.
                colTop[nTopIdx].SetData(fDecInput, 0);
                nTopIdx++;

                // Decoder clip.
                colTop[nTopIdx].SetData((fDecInput == 1) ? 0 : 1, 0);
                nTopIdx++;

                // Encoder data.
                if (m_param.text_data_param.enable_normal_encoder_output)
                {
                    colTop[nTopIdx].CopyFrom(colBottom[nBtmIdx]);
                    nTopIdx++;
                    nBtmIdx++;
                }

                // Encoder data reverse.
                if (m_param.text_data_param.enable_reverse_encoder_output)
                {
                    colTop[nTopIdx].CopyFrom(colBottom[nBtmIdx]);
                    nTopIdx++;
                    nBtmIdx++;
                }

                // Encoder clip.
                colTop[nTopIdx].CopyFrom(colBottom[nBtmIdx]);
            }
        }

        /// <summary>
        /// Not implemented - data layers do not perform backward.
        /// </summary>
        protected override void backward(BlobCollection<T> colTop, List<bool> rgbPropagateDown, BlobCollection<T> colBottom)
        {
        }
    }

    namespace TextData
    {
#pragma warning disable 1591

        class Data
        {
            Random m_random = new Random((int)DateTime.Now.Ticks);
            List<List<string>> m_rgInput;
            List<List<string>> m_rgOutput;
            int m_nCurrentSequence = -1;
            int m_nCurrentOutputIdx = 0;
            int m_nSequenceIdx = 0;
            int m_nIxInput = 1;
            int m_nIterations = 0;
            int m_nOutputCount = 0;
            Vocabulary m_vocab;

            public Data(List<List<string>> rgInput, List<List<string>> rgOutput, Vocabulary vocab)
            {
                m_vocab = vocab;
                m_rgInput = rgInput;
                m_rgOutput = rgOutput;
            }

            public Vocabulary Vocabulary
            {
                get { return m_vocab; }
            }

            public int VocabularyCount
            {
                get { return m_vocab.VocabularCount; }
            }

            public static DataItem GetInputData(Vocabulary vocab, List<string> rgstrInput, int? nDecInput = null)
            {
                List<int> rgInput = null;

                if (rgstrInput != null)
                {
                    rgInput = new List<int>();
                    foreach (string str in rgstrInput)
                    {
                        rgInput.Add(vocab.WordToIndex(str));
                    }
                }

                int nClip = 1;

                if (!nDecInput.HasValue)
                {
                    nClip = 0;
                    nDecInput = 1;
                }

                return new DataItem(rgInput, nDecInput.Value, -1, nClip, false, true, 0);
            }

            public DataItem GetNextData(bool bShuffle)
            {
                int nDecClip = 1;

                bool bNewSequence = false;
                bool bNewEpoch = false;

                if (m_nCurrentSequence == -1)
                {
                    m_nIterations++;
                    bNewSequence = true;

                    if (bShuffle)
                    {
                        m_nCurrentSequence = m_random.Next(m_rgInput.Count);
                    }
                    else
                    {
                        m_nCurrentSequence = m_nSequenceIdx;
                        m_nSequenceIdx++;
                        if (m_nSequenceIdx == m_rgOutput.Count)
                            m_nSequenceIdx = 0;
                    }

                    m_nOutputCount = m_rgOutput[m_nCurrentSequence].Count;
                    nDecClip = 0;

                    if (m_nIterations == m_rgOutput.Count)
                    {
                        bNewEpoch = true;
                        m_nIterations = 0;
                    }
                }

                List<string> rgstrInput = m_rgInput[m_nCurrentSequence];
                List<int> rgInput = new List<int>();
                foreach (string str in rgstrInput)
                {
                    rgInput.Add(m_vocab.WordToIndex(str));
                }

                int nIxTarget = 0;

                if (m_nCurrentOutputIdx < m_rgOutput[m_nCurrentSequence].Count)
                {
                    string strTarget = m_rgOutput[m_nCurrentSequence][m_nCurrentOutputIdx];
                    nIxTarget = m_vocab.WordToIndex(strTarget);
                }

                DataItem data = new DataItem(rgInput, m_nIxInput, nIxTarget, nDecClip, bNewEpoch, bNewSequence, m_nOutputCount);
                m_nIxInput = nIxTarget;

                m_nCurrentOutputIdx++;

                if (m_nCurrentOutputIdx == m_rgOutput[m_nCurrentSequence].Count)
                {
                    m_nCurrentSequence = -1;
                    m_nCurrentOutputIdx = 0;
                    m_nIxInput = 1;
                }

                return data;
            }
        }
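
        // Note: GetNextData unrolls the decoder one step per call - each call returns
        // the full encoder sequence plus a single decoder target word, the previous
        // target as the decoder input, and a decoder clip of 0 only on the first step
        // of a new target sequence.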

        class DataItem
        {
            IterationInfo m_iter;
            List<int> m_rgInput;
            List<int> m_rgInputReverse;
            int m_nIxInput;
            int m_nIxTarget;
            int m_nDecClip;

            public DataItem(List<int> rgInput, int nIxInput, int nIxTarget, int nDecClip, bool bNewEpoch, bool bNewSequence, int nOutputCount)
            {
                m_rgInput = rgInput;
                m_nIxInput = nIxInput;
                m_nIxTarget = nIxTarget;
                m_nDecClip = nDecClip;
                m_iter = new IterationInfo(bNewEpoch, bNewSequence, nOutputCount);
                m_rgInputReverse = new List<int>();

                if (rgInput != null)
                {
                    for (int i = rgInput.Count - 1; i >= 0; i--)
                    {
                        m_rgInputReverse.Add(rgInput[i]);
                    }
                }
                else
                {
                    m_rgInputReverse = null;
                }
            }

            public List<int> EncoderInput
            {
                get { return m_rgInput; }
            }

            public List<int> EncoderInputReverse
            {
                get { return m_rgInputReverse; }
            }

            public int DecoderInput
            {
                get { return m_nIxInput; }
            }

            public int DecoderTarget
            {
                get { return m_nIxTarget; }
            }

            public int DecoderClip
            {
                get { return m_nDecClip; }
            }

            public IterationInfo IterationInfo
            {
                get { return m_iter; }
            }
        }

#pragma warning restore 1591

        /// <summary>
        /// The IterationInfo class contains information about each iteration.
        /// </summary>
        public class IterationInfo
        {
            bool m_bNewEpoch;
            bool m_bNewSequence;
            int m_nOutputCount;

            /// <summary>
            /// The constructor.
            /// </summary>
            public IterationInfo(bool bNewEpoch, bool bNewSequence, int nOutputCount)
            {
                m_bNewEpoch = bNewEpoch;
                m_bNewSequence = bNewSequence;
                m_nOutputCount = nOutputCount;
            }

            /// <summary>
            /// Returns whether or not the current iteration is in a new epoch.
            /// </summary>
            public bool NewEpoch
            {
                get { return m_bNewEpoch; }
            }

            /// <summary>
            /// Returns whether or not the current iteration is in a new sequence.
            /// </summary>
            public bool NewSequence
            {
                get { return m_bNewSequence; }
            }

            /// <summary>
            /// Returns the output count of the current sequence.
            /// </summary>
            public int OutputCount
            {
                get { return m_nOutputCount; }
            }
        }

        /// <summary>
        /// The Vocabulary object manages the overall word dictionary and the word-to-index and index-to-word mappings.
        /// </summary>
        public class Vocabulary
        {
            Dictionary<string, int> m_rgDictionary = new Dictionary<string, int>();
            Dictionary<string, int> m_rgWordToIndex = new Dictionary<string, int>();
            Dictionary<int, string> m_rgIndexToWord = new Dictionary<int, string>();
            List<string> m_rgstrVocabulary = new List<string>();

            /// <summary>
            /// The constructor.
            /// </summary>
            public Vocabulary()
            {
            }

            /// <summary>
            /// The WordToIndex method maps a word to its corresponding index value.
            /// </summary>
            public int WordToIndex(string strWord)
            {
                if (!m_rgWordToIndex.ContainsKey(strWord))
                    throw new Exception("I do not know the word '" + strWord + "'!");

                return m_rgWordToIndex[strWord];
            }

            /// <summary>
            /// The IndexToWord method maps an index value to its corresponding word.
            /// </summary>
            public string IndexToWord(int nIdx)
            {
                if (!m_rgIndexToWord.ContainsKey(nIdx))
                    return "";

                return m_rgIndexToWord[nIdx];
            }

            /// <summary>
            /// Returns the number of words in the vocabulary.
            /// </summary>
            public int VocabularCount
            {
                get { return m_rgstrVocabulary.Count; }
            }

            /// <summary>
            /// Loads the word-to-index mappings from the input and target word lists.
            /// </summary>
            public void Load(List<List<string>> rgrgstrInput, List<List<string>> rgrgstrTarget)
            {
                m_rgDictionary = new Dictionary<string, int>();

                // Count up all words.
                for (int i = 0; i < rgrgstrInput.Count; i++)
                {
                    for (int j = 0; j < rgrgstrInput[i].Count; j++)
                    {
                        string strWord = rgrgstrInput[i][j];

                        if (!m_rgDictionary.ContainsKey(strWord))
                            m_rgDictionary.Add(strWord, 1);
                        else
                            m_rgDictionary[strWord]++;
                    }

                    for (int j = 0; j < rgrgstrTarget[i].Count; j++)
                    {
                        string strWord = rgrgstrTarget[i][j];

                        if (!m_rgDictionary.ContainsKey(strWord))
                            m_rgDictionary.Add(strWord, 1);
                        else
                            m_rgDictionary[strWord]++;
                    }
                }

                // NOTE: Start at index 2 to save room for the START and END tokens, where
                // START = 0 in the model word vectors and
                // END = 0 in the next word softmax.
                int nIdx = 2;
                foreach (KeyValuePair<string, int> kv in m_rgDictionary)
                {
                    if (kv.Value > 0)
                    {
                        // Add word to vocabulary.
                        m_rgWordToIndex[kv.Key] = nIdx;
                        m_rgIndexToWord[nIdx] = kv.Key;
                        m_rgstrVocabulary.Add(kv.Key);
                        nIdx++;
                    }
                }
            }
        }

        /// <summary>
        /// Defines the arguments passed to the OnGetData event.
        /// </summary>
        public class OnGetDataArgs : EventArgs
        {
            Vocabulary m_vocab;
            IterationInfo m_iter;

            /// <summary>
            /// The constructor.
            /// </summary>
            public OnGetDataArgs(Vocabulary vocab, IterationInfo iter)
            {
                m_vocab = vocab;
                m_iter = iter;
            }

            /// <summary>
            /// Returns the vocabulary.
            /// </summary>
            public Vocabulary Vocabulary
            {
                get { return m_vocab; }
            }

            /// <summary>
            /// Returns the current iteration information.
            /// </summary>
            public IterationInfo IterationInfo
            {
                get { return m_iter; }
            }
        }
    }
}
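
A brief illustrative sketch (not part of the MyCaffe sources) showing how the public Vocabulary class defined above can be exercised on its own; the word lists below are hypothetical, and in normal use the TextDataLayer builds the vocabulary itself from the encoder_source and decoder_source files via PreProcessInputFiles.

using System;
using System.Collections.Generic;
using MyCaffe.layers.beta.TextData;

class VocabularyExample
{
    static void Main()
    {
        // Hypothetical parallel encoder/decoder word lists (one line each).
        List<List<string>> rgrgstrInput = new List<List<string>>()
        {
            new List<string>() { "como", "estas", "hoy" }
        };
        List<List<string>> rgrgstrTarget = new List<List<string>>()
        {
            new List<string>() { "how", "are", "you", "today" }
        };

        // Build the word <-> index mappings; indices start at 2 because the
        // first two slots are reserved for the special start/end tokens.
        Vocabulary vocab = new Vocabulary();
        vocab.Load(rgrgstrInput, rgrgstrTarget);

        Console.WriteLine("Vocabulary size: " + vocab.VocabularCount);

        foreach (string strWord in new string[] { "como", "how", "today" })
        {
            int nIdx = vocab.WordToIndex(strWord);
            Console.WriteLine("'{0}' -> {1} -> '{2}'", strWord, nIdx, vocab.IndexToWord(nIdx));
        }
    }
}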