MyCaffe 1.12.2.41
Deep learning software for Windows C# programmers.
TrainerRNN.cs
1 using System;
2 using System.Collections.Generic;
3 using System.Diagnostics;
4 using System.Drawing;
5 using System.Linq;
6 using System.Text;
7 using System.Threading;
8 using System.Threading.Tasks;
9 using MyCaffe.basecode;
10 using MyCaffe.common;
11 using MyCaffe.fillers;
12 using MyCaffe.layers;
13 using MyCaffe.param;
14 using MyCaffe.solvers;
15
16 namespace MyCaffe.trainers.rnn.simple
17 {
26 public class TrainerRNN<T> : IxTrainerRNN, IDisposable
27 {
28 IxTrainerCallback m_icallback;
29 MyCaffeControl<T> m_mycaffe;
30 PropertySet m_properties;
31 CryptoRandom m_random;
32 BucketCollection m_rgVocabulary = null;
33 bool m_bUsePreloadData = true;
34
35
48 public TrainerRNN(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback, BucketCollection rgVocabulary)
49 {
50 m_icallback = icallback;
51 m_mycaffe = mycaffe;
52 m_properties = properties;
53 m_random = random;
54 m_rgVocabulary = rgVocabulary;
55 m_bUsePreloadData = properties.GetPropertyAsBool("UsePreLoadData", true);
56 }
57
61 public void Dispose()
62 {
63 }
64
69 public bool Initialize()
70 {
71 m_mycaffe.CancelEvent.Reset();
72 m_icallback.OnInitialize(new InitializeArgs(m_mycaffe));
73 return true;
74 }
75
76 private void wait(int nWait)
77 {
78 int nWaitInc = 250;
79 int nTotalWait = 0;
80
81 while (nTotalWait < nWait)
82 {
83 m_icallback.OnWait(new WaitArgs(nWaitInc));
84 nTotalWait += nWaitInc;
85 }
86 }
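// Example: Shutdown(1000) calls wait(1000), which fires OnWait(250) four times,
// giving the host a chance to service the callback while the shutdown delay elapses.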
87
93 public bool Shutdown(int nWait)
94 {
95 if (m_mycaffe != null)
96 {
97 m_mycaffe.CancelEvent.Set();
98 wait(nWait);
99 }
100
101 m_icallback.OnShutdown();
102
103 return true;
104 }
105
106
113 public float[] Run(int nN, PropertySet runProp)
114 {
115 m_mycaffe.CancelEvent.Reset();
116 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.RUN, m_rgVocabulary, m_bUsePreloadData, runProp);
117 float[] rgResults = agent.Run(nN);
118 agent.Dispose();
119
120 return rgResults;
121 }
122
130 public byte[] Run(int nN, PropertySet runProp, out string type)
131 {
132 m_mycaffe.CancelEvent.Reset();
133 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.RUN, m_rgVocabulary, m_bUsePreloadData, runProp);
134 byte[] rgResults = agent.Run(nN, out type);
135 agent.Dispose();
136
137 return rgResults;
138 }
139
146 public bool Test(int nN, ITERATOR_TYPE type)
147 {
148 int nDelay = 1000;
149
150 m_mycaffe.CancelEvent.Reset();
151 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TEST, m_rgVocabulary, m_bUsePreloadData);
152 agent.Run(Phase.TEST, nN, type, TRAIN_STEP.NONE);
153
154 agent.Dispose();
155 Shutdown(nDelay);
156
157 return true;
158 }
159
167 public bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
168 {
169 m_mycaffe.CancelEvent.Reset();
170 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN, m_rgVocabulary, m_bUsePreloadData);
171 agent.Run(Phase.TRAIN, nN, type, step);
172
173 agent.Dispose();
174
175 return false;
176 }
177 }
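// A minimal usage sketch of the trainer lifecycle (hypothetical host objects;
// 'mycaffe', 'properties', 'random', 'vocab', 'runProp' and the 'icallback'
// instance implementing IxTrainerCallback are assumed to be supplied by the host):
//
//   TrainerRNN<float> trainer = new TrainerRNN<float>(mycaffe, properties, random, icallback, vocab);
//   trainer.Initialize();
//   trainer.Train(1000, ITERATOR_TYPE.ITERATION, TRAIN_STEP.NONE);
//   float[] rgOutput = trainer.Run(100, runProp);
//   trainer.Shutdown(1000);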
178
179 class Agent<T> : IDisposable
180 {
181 IxTrainerCallback m_icallback;
182 Brain<T> m_brain;
183 PropertySet m_properties;
184 CryptoRandom m_random;
185
186 public Agent(IxTrainerCallback icallback, MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase, BucketCollection rgVocabulary, bool bUsePreloadData, PropertySet runProp = null)
187 {
188 m_icallback = icallback;
189 m_brain = new Brain<T>(mycaffe, properties, random, icallback as IxTrainerCallbackRNN, phase, rgVocabulary, bUsePreloadData, runProp);
190 m_properties = properties;
191 m_random = random;
192 }
193
194 public void Dispose()
195 {
196 if (m_brain != null)
197 {
198 m_brain.Dispose();
199 m_brain = null;
200 }
201 }
202
203 private StateBase getData(Phase phase, int nAction)
204 {
205 GetDataArgs args = m_brain.getDataArgs(phase, 0, nAction, true);
206 m_icallback.OnGetData(args);
207 return args.State;
208 }
209
210
222 public void Run(Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
223 {
224 if (type != ITERATOR_TYPE.ITERATION)
225 throw new Exception("The TrainerRNN only supports the ITERATION type.");
226
227 StateBase s = getData(phase, -1);
228
229 while (!m_brain.Cancel.WaitOne(0) && !s.Done)
230 {
231 if (phase == Phase.TEST)
232 m_brain.Test(s, nN);
233 else if (phase == Phase.TRAIN)
234 m_brain.Train(s, nN, step);
235
236 s = getData(phase, 1);
237 }
238 }
239
245 public float[] Run(int nN)
246 {
247 return m_brain.Run(nN);
248 }
249
256 public byte[] Run(int nN, out string type)
257 {
258 float[] rgResults = m_brain.Run(nN);
259
260 ConvertOutputArgs args = new ConvertOutputArgs(nN, rgResults);
261 IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
262 if (icallback == null)
263 throw new Exception("The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");
264
265 icallback.OnConvertOutput(args);
266
267 type = args.RawType;
268 return args.RawOutput;
269 }
270 }
271
272 class Brain<T> : IDisposable
273 {
274 IxTrainerCallbackRNN m_icallback;
275 MyCaffeControl<T> m_mycaffe;
276 Net<T> m_net;
277 Solver<T> m_solver;
278 PropertySet m_properties;
279 PropertySet m_runProperties = null;
280 Blob<T> m_blobData;
281 Blob<T> m_blobClip;
282 Blob<T> m_blobLabel;
283 Blob<T> m_blobOutput = null;
284 int m_nSequenceLength;
285 int m_nSequenceLengthLabel;
286 int m_nBatchSize;
287 int m_nVocabSize = 1;
288 CryptoRandom m_random;
289 T[] m_rgDataInput;
290 T[] m_rgLabelInput;
291 T m_tZero;
292 T m_tOne;
293 double m_dfRunTemperature = 0;
294 double m_dfTestTemperature = 0;
295 byte[] m_rgTestData = null;
296 byte[] m_rgTrainData = null;
297 float[] m_rgfTestData = null;
298 float[] m_rgfTrainData = null;
299 bool m_bIsDataReal = false;
300 Stopwatch m_sw = new Stopwatch();
301 double m_dfLastLoss = 0;
302 double m_dfLastLearningRate = 0;
303 BucketCollection m_rgVocabulary = null;
304 bool m_bUsePreloadData = true;
305 bool m_bDisableVocabulary = false;
306 Phase m_phaseOnRun = Phase.NONE;
307 LayerParameter.LayerType m_lstmType = LayerParameter.LayerType.LSTM;
308 int m_nSolverSequenceLength = -1;
309 int m_nThreads = 1;
310 DataCollectionPool m_dataPool = new DataCollectionPool();
311 double m_dfScale = 1.0;
312
313 public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallbackRNN icallback, Phase phase, BucketCollection rgVocabulary, bool bUsePreloadData, PropertySet runProp = null)
314 {
315 string strOutputBlob = null;
316
317 if (runProp != null)
318 m_runProperties = runProp;
319
320 m_icallback = icallback;
321 m_mycaffe = mycaffe;
322 m_properties = properties;
323 m_random = random;
324 m_rgVocabulary = rgVocabulary;
325 m_bUsePreloadData = bUsePreloadData;
326 m_nSolverSequenceLength = m_properties.GetPropertyAsInt("SequenceLength", -1);
327 m_bDisableVocabulary = m_properties.GetPropertyAsBool("DisableVocabulary", false);
328 m_nThreads = m_properties.GetPropertyAsInt("Threads", 1);
329 m_dfScale = m_properties.GetPropertyAsDouble("Scale", 1.0);
330 m_dfTestTemperature = properties.GetPropertyAsDouble("TestTemperature", 0);
331
332 if (m_nThreads > 1)
333 m_dataPool.Initialize(m_nThreads, icallback);
334
335 if (m_runProperties != null)
336 {
337 m_dfRunTemperature = Math.Abs(m_runProperties.GetPropertyAsDouble("Temperature", 0));
338 if (m_dfRunTemperature > 1.0)
339 m_dfRunTemperature = 1.0;
340
341 string strPhaseOnRun = m_runProperties.GetProperty("PhaseOnRun", false);
342 switch (strPhaseOnRun)
343 {
344 case "RUN":
345 m_phaseOnRun = Phase.RUN;
346 break;
347
348 case "TEST":
349 m_phaseOnRun = Phase.TEST;
350 break;
351
352 case "TRAIN":
353 m_phaseOnRun = Phase.TRAIN;
354 break;
355 }
356
357 if (phase == Phase.RUN && m_phaseOnRun != Phase.NONE)
358 {
359 if (m_phaseOnRun != Phase.RUN)
360 m_mycaffe.Log.WriteLine("Warning: Running on the '" + m_phaseOnRun.ToString() + "' network.");
361
362 strOutputBlob = m_runProperties.GetProperty("OutputBlob", false);
363 if (strOutputBlob == null)
364 throw new Exception("You must specify the 'OutputBlob' when Running with a phase other than RUN.");
365
366 strOutputBlob = Utility.Replace(strOutputBlob, '~', ';');
367
368 phase = m_phaseOnRun;
369 }
370 }
371
372 m_net = mycaffe.GetInternalNet(phase);
373 if (m_net == null)
374 {
375 mycaffe.Log.WriteLine("WARNING: Test net does not exist, set test_iteration > 0. Using TRAIN phase instead.");
376 m_net = mycaffe.GetInternalNet(Phase.TRAIN);
377 }
378
379 // Find the first LSTM layer to determine how to load the data.
380 // NOTE: Only LSTM has a special loading order, other layers use the standard N, C, H, W ordering.
381 LSTMLayer<T> lstmLayer = null;
382 LSTMAttentionLayer<T> lstmAttentionLayer = null;
383 LSTMSimpleLayer<T> lstmSimpleLayer = null;
384 foreach (Layer<T> layer1 in m_net.layers)
385 {
386 if (layer1.layer_param.type == LayerParameter.LayerType.LSTM)
387 {
388 lstmLayer = layer1 as LSTMLayer<T>;
389 m_lstmType = layer1.layer_param.type;
390 break;
391 }
392 else if (layer1.layer_param.type == LayerParameter.LayerType.LSTM_ATTENTION)
393 {
394 lstmAttentionLayer = layer1 as LSTMAttentionLayer<T>;
395 m_lstmType = LayerParameter.LayerType.LSTM_ATTENTION;
396 break;
397 }
398 // DEPRECATED
399 else if (layer1.layer_param.type == LayerParameter.LayerType.LSTM_SIMPLE)
400 {
401 lstmSimpleLayer = layer1 as LSTMSimpleLayer<T>;
402 m_lstmType = LayerParameter.LayerType.LSTM_SIMPLE;
403 break;
404 }
405 }
406
407 if (lstmLayer == null && lstmAttentionLayer == null && lstmSimpleLayer == null)
408 throw new Exception("Could not find the required LSTM or LSTM_ATTENTION or LSTM_SIMPLE layer!");
409
410 if (m_phaseOnRun != Phase.NONE && m_phaseOnRun != Phase.RUN && strOutputBlob != null)
411 {
412 if ((m_blobOutput = m_net.FindBlob(strOutputBlob)) == null)
413 throw new Exception("Could not find the 'Output' layer top named '" + strOutputBlob + "'!");
414 }
415
416 if ((m_blobData = m_net.FindBlob("data")) == null)
417 throw new Exception("Could not find the 'Input' layer top named 'data'!");
418
419 if ((m_blobClip = m_net.FindBlob("clip")) == null)
420 throw new Exception("Could not find the 'Input' layer top named 'clip'!");
421
422 Layer<T> layer = m_net.FindLastLayer(LayerParameter.LayerType.INNERPRODUCT);
423 m_mycaffe.Log.CHECK(layer != null, "Could not find an ending INNERPRODUCT layer!");
424
425 if (!m_bDisableVocabulary)
426 {
427 m_nVocabSize = (int)layer.layer_param.inner_product_param.num_output;
428 if (rgVocabulary != null)
429 m_mycaffe.Log.CHECK_EQ(m_nVocabSize, rgVocabulary.Count, "The vocabulary count = '" + rgVocabulary.Count.ToString() + "' and last inner product output count = '" + m_nVocabSize.ToString() + "' - these do not match but they should!");
430 }
431
432 if (m_lstmType == LayerParameter.LayerType.LSTM || m_lstmType == LayerParameter.LayerType.LSTM_ATTENTION)
433 {
434 m_nSequenceLength = m_blobData.shape(0);
435 m_nBatchSize = m_blobData.shape(1);
436 }
437 else
438 {
439 m_nBatchSize = (int)lstmSimpleLayer.layer_param.lstm_simple_param.batch_size;
440 m_nSequenceLength = m_blobData.shape(0) / m_nBatchSize;
441
442 if (phase == Phase.RUN)
443 {
444 m_nBatchSize = 1;
445
446 List<int> rgNewShape = new List<int>() { m_nSequenceLength, 1 };
447 m_blobData.Reshape(rgNewShape);
448 m_blobClip.Reshape(rgNewShape);
449 m_net.Reshape();
450 }
451 }
452
453 m_mycaffe.Log.CHECK_EQ(m_nSequenceLength, m_blobData.num, "The data num must equal the sequence length of " + m_nSequenceLength.ToString());
454
455 m_rgDataInput = new T[m_nSequenceLength * m_nBatchSize];
456
457 T[] rgClipInput = new T[m_nSequenceLength * m_nBatchSize];
458 m_mycaffe.Log.CHECK_EQ(rgClipInput.Length, m_blobClip.count(), "The clip count must equal the sequence length * batch size: " + rgClipInput.Length.ToString());
459 m_tZero = (T)Convert.ChangeType(0, typeof(T));
460 m_tOne = (T)Convert.ChangeType(1, typeof(T));
461
462 for (int i = 0; i < rgClipInput.Length; i++)
463 {
464 if (m_lstmType == LayerParameter.LayerType.LSTM || m_lstmType == LayerParameter.LayerType.LSTM_ATTENTION)
465 rgClipInput[i] = (i < m_nBatchSize) ? m_tZero : m_tOne;
466 else
467 rgClipInput[i] = (i % m_nSequenceLength == 0) ? m_tZero : m_tOne;
468 }
469
470 m_blobClip.mutable_cpu_data = rgClipInput;
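// Worked example (assuming m_nSequenceLength = 3 and m_nBatchSize = 2):
// LSTM/LSTM_ATTENTION data is time-major, so the first m_nBatchSize entries
// mark the start of every sequence: clip = { 0, 0, 1, 1, 1, 1 }.
// LSTM_SIMPLE data is sequence-major, so each sequence start is zeroed:
// clip = { 0, 1, 1, 0, 1, 1 }.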
471
472 if (phase != Phase.RUN)
473 {
474 m_solver = mycaffe.GetInternalSolver();
475 m_solver.OnStart += m_solver_OnStart;
476 m_solver.OnTestStart += m_solver_OnTestStart;
477 m_solver.OnTestResults += m_solver_OnTestResults;
478 m_solver.OnTestingIteration += m_solver_OnTestingIteration;
479 m_solver.OnTrainingIteration += m_solver_OnTrainingIteration;
480
481 if ((m_blobLabel = m_net.FindBlob("label")) == null)
482 throw new Exception("Could not find the 'Input' layer top named 'label'!");
483
484 m_nSequenceLengthLabel = m_blobLabel.count(0, 2);
485 m_rgLabelInput = new T[m_nSequenceLengthLabel];
486 m_mycaffe.Log.CHECK_EQ(m_rgLabelInput.Length, m_blobLabel.count(), "The label count must equal the label sequence length * batch size: " + m_rgLabelInput.Length.ToString());
487 m_mycaffe.Log.CHECK(m_nSequenceLengthLabel == m_nSequenceLength * m_nBatchSize || m_nSequenceLengthLabel == 1, "The label sequence length must be 1 or equal the length of the sequence: " + m_nSequenceLength.ToString());
488 }
489 }
490
491 private void m_solver_OnTrainingIteration(object sender, TrainingIterationArgs<T> e)
492 {
493 if (m_sw.Elapsed.TotalMilliseconds > 1000)
494 {
495 m_dfLastLoss = e.SmoothedLoss;
496 m_dfLastLearningRate = e.LearningRate;
497 updateStatus(e.Iteration, m_solver.MaximumIteration, e.Accuracy, e.SmoothedLoss, e.LearningRate);
498 m_sw.Restart();
499 }
500 }
501
502 private void m_solver_OnTestingIteration(object sender, TestingIterationArgs<T> e)
503 {
504 if (m_sw.Elapsed.TotalMilliseconds > 1000)
505 {
506 updateStatus(e.Iteration, m_solver.MaximumIteration, e.Accuracy, m_dfLastLoss, m_dfLastLearningRate);
507 m_sw.Restart();
508 }
509 }
510
511 private void dispose(ref Blob<T> b)
512 {
513 if (b != null)
514 {
515 b.Dispose();
516 b = null;
517 }
518 }
519
520 public void Dispose()
521 {
522 if (m_dataPool != null)
523 {
524 m_dataPool.Shutdown();
525 m_dataPool = null;
526 }
527 }
528
529 private void updateStatus(int nIteration, int nMaxIteration, double dfAccuracy, double dfLoss, double dfLearningRate)
530 {
531 GetStatusArgs args = new GetStatusArgs(0, nIteration, nIteration, nMaxIteration, dfAccuracy, 0, 0, 0, dfLoss, dfLearningRate);
532 m_icallback.OnUpdateStatus(args);
533 }
534
535 public GetDataArgs getDataArgs(Phase phase, int nIdx, int nAction, bool bGetLabel = false, int nBatchSize = 1)
536 {
537 bool bReset = (nAction == -1) ? true : false;
538 return new GetDataArgs(phase, nIdx, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction, false, bGetLabel, (nBatchSize > 1) ? true : false);
539 }
540
541 public Log Log
542 {
543 get { return m_mycaffe.Log; }
544 }
545
546 public CancelEvent Cancel
547 {
548 get { return m_mycaffe.CancelEvent; }
549 }
550
551 private void copyData(SimpleDatum sd, int nSrcOffset, float[] rgfDst, int nCount)
552 {
553 int nDim = sd.Height * sd.Width;
554 float[] rgfSrc = sd.GetData<float>();
555
556 if (nDim == 0)
557 {
558 Array.Copy(rgfSrc, nSrcOffset, rgfDst, 0, nCount);
559 }
560 else
561 {
562 for (int i = 0; i < nCount; i++)
563 {
564 rgfDst[i] = rgfSrc[(nSrcOffset + i) * nDim];
565 }
566 }
567 }
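// Example: when the SimpleDatum has nDim = Height * Width > 0 the copy is strided -
// rgfDst[i] receives the first element of item (nSrcOffset + i); with nDim = 4 and
// nSrcOffset = 2, rgfDst = { rgfSrc[8], rgfSrc[12], rgfSrc[16], ... }.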
568
569 private void getRawData(StateBase s)
570 {
571 int nTestLen = (int)(s.Data.Channels * s.TestingPercent);
572 int nTrainLen = s.Data.Channels - nTestLen;
573
574 if (s.Data.IsRealData)
575 {
576 m_bIsDataReal = true;
577
578 m_rgfTrainData = new float[nTrainLen];
579 copyData(s.Data, 0, m_rgfTrainData, nTrainLen);
580
581 if (nTestLen == 0)
582 {
583 m_rgfTestData = m_rgfTrainData;
584 }
585 else
586 {
587 m_rgfTestData = new float[nTestLen];
588 copyData(s.Data, nTrainLen, m_rgfTestData, nTestLen);
589 }
590 }
591 else
592 {
593 int nDim = s.Data.Height * s.Data.Width;
594 if (nDim != 1)
595 throw new Exception("When training on binary data the height and width must = 1.");
596
597 m_bIsDataReal = false;
598 m_rgTrainData = new byte[nTrainLen];
599 Array.Copy(s.Data.ByteData, 0, m_rgTrainData, 0, nTrainLen);
600
601 if (nTestLen == 0)
602 {
603 m_rgTestData = m_rgTrainData;
604 }
605 else
606 {
607 m_rgTestData = new byte[nTestLen];
608 Array.Copy(s.Data.ByteData, nTrainLen, m_rgTestData, 0, nTestLen);
609 }
610 }
611 }
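// Example: with s.Data.Channels = 100 and s.TestingPercent = 0.25, nTestLen = 25
// and nTrainLen = 75; the first 75 items become the training set and the last 25
// the test set. When TestingPercent = 0, the full set is shared by both phases.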
612
613 public void Test(StateBase s, int nIterations)
614 {
615 if (nIterations <= 0)
616 {
617 nIterations = 20;
618
619 if (m_solver.parameter.test_iter.Count > 0)
620 nIterations = m_solver.parameter.test_iter[0];
621 }
622
623 getRawData(s);
624 m_sw.Start();
625 m_solver.TestAll(nIterations);
626 }
627
628 private void m_solver_OnTestStart(object sender, EventArgs e)
629 {
630 FeedNet(false);
631 }
632
640 private int getLabel(float[] rgfScores, int nIdx, int nDim)
641 {
642 float[] rgfLastScores = new float[nDim];
643 int nStartIdx = nIdx * nDim;
644
645 for (int i = 0; i < nDim; i++)
646 {
647 rgfLastScores[i] = rgfScores[nStartIdx + i];
648 }
649
650 return getLastPrediction(rgfLastScores, m_dfTestTemperature);
651 }
652
653 private void m_solver_OnTestResults(object sender, TestResultArgs<T> e)
654 {
655 if (e.Results.Count != 2)
656 return;
657
658 int nNum = e.Results[1].num;
659 if (nNum != m_rgLabelInput.Length)
660 return;
661
662 int nDim = e.Results[1].count(1);
663
664 float[] rgfScores = Utility.ConvertVecF<T>(e.Results[1].mutable_cpu_data);
665 int nCorrectCount = 0;
666 for (int i = 0; i < m_nBatchSize; i++)
667 {
668 int nIdx = (m_lstmType == LayerParameter.LayerType.LSTM_SIMPLE) ? (i * m_nSequenceLength + m_nSequenceLength - 1) : ((nNum - m_nBatchSize) + i);
669 int nExpectedLabel = (int)Utility.ConvertVal<T>(m_rgLabelInput[nIdx]);
670 int nActualLabel = getLabel(rgfScores, nIdx, nDim);
671 bool bHandled = false;
672
673 TestAccuracyUpdateArgs args = new TestAccuracyUpdateArgs(nActualLabel, nExpectedLabel);
674 m_icallback.OnTestAccuracyUpdate(args);
675 if (args.Handled)
676 {
677 if (args.IsCorrect)
678 nCorrectCount++;
679 bHandled = true;
680 }
681
682 if (!bHandled)
683 {
684 if (nExpectedLabel == nActualLabel)
685 nCorrectCount++;
686 }
687 }
688
689 e.Accuracy = (double)nCorrectCount / m_nBatchSize;
690 }
691
692 public void Train(StateBase s, int nIterations, TRAIN_STEP step)
693 {
694 if (nIterations <= 0)
695 nIterations = m_solver.parameter.max_iter;
696
697 getRawData(s);
698 m_sw.Start();
699 m_solver.Solve(nIterations, null, null, step);
700 }
701
702 private void m_solver_OnStart(object sender, EventArgs e)
703 {
704 FeedNet(true);
705 }
706
707 public void FeedNet(bool bTrain)
708 {
709 bool bFound;
710 int nIdx;
711 Phase phase = (bTrain) ? Phase.TRAIN : Phase.TEST;
712
713 // Real Data
714 if (m_bIsDataReal)
715 {
716 if (m_bUsePreloadData)
717 {
718 float[] rgfData = (bTrain) ? m_rgfTrainData : m_rgfTestData;
719
720 // Re-order the data according to the Caffe input specification for the LSTM layer.
721 for (int i = 0; i < m_nBatchSize; i++)
722 {
723 int nCurrentValIdx = m_random.Next(rgfData.Length - m_nSequenceLength - 1);
724
725 for (int j = 0; j < m_nSequenceLength; j++)
726 {
727 // Feed the net with input data and labels (clips are always the same)
728 double dfData = rgfData[nCurrentValIdx + j];
729 // Labels are the same with an offset of +1
730 double dfLabel = rgfData[nCurrentValIdx + j + 1]; // predict next value
731 float fDataIdx = findIndex(dfData, out bFound);
732 float fLabelIdx = findIndex(dfLabel, out bFound);
733
734 // LSTM or LSTM_ATTENTION: Create input data, the data must be in the order
735 // seq1_val1, seq2_val1, ..., seqBatch_Size_val1, seq1_val2, seq2_val2, ..., seqBatch_Size_valSequence_Length
736 if (m_lstmType == LayerParameter.LayerType.LSTM || m_lstmType == LayerParameter.LayerType.LSTM_ATTENTION)
737 nIdx = m_nBatchSize * j + i;
738
739 // [DEPRECATED] LSTM_SIMPLE: Create input data, the data must be in the order
740 // seq1_val1, seq1_val2, ..., seq1_valSequence_Length, seq2_val1, seq2_val2, ..., seqBatch_Size_valSequence_Length
741 else
742 nIdx = i * m_nSequenceLength + j;
743
744 m_rgDataInput[nIdx] = (T)Convert.ChangeType(fDataIdx, typeof(T));
745
746 if (m_nSequenceLengthLabel == (m_nSequenceLength * m_nBatchSize) || j == m_nSequenceLength - 1)
747 m_rgLabelInput[nIdx] = (T)Convert.ChangeType(fLabelIdx, typeof(T));
748 }
749 }
750
751 m_blobData.mutable_cpu_data = m_rgDataInput;
752 m_blobLabel.mutable_cpu_data = m_rgLabelInput;
753 }
754 else
755 {
756 m_mycaffe.Log.CHECK_EQ(m_nBatchSize, m_nThreads, "The 'Threads' setting of " + m_nThreads.ToString() + " must match the batch size = " + m_nBatchSize.ToString() + "!");
757
758 List<GetDataArgs> rgDataArgs = new List<GetDataArgs>();
759
760 if (m_nBatchSize == 1)
761 {
762 GetDataArgs e = getDataArgs(phase, 0, 0, true, m_nBatchSize);
763 m_icallback.OnGetData(e);
764 rgDataArgs.Add(e);
765 }
766 else
767 {
768 for (int i = 0; i < m_nBatchSize; i++)
769 {
770 rgDataArgs.Add(getDataArgs(phase, i, 0, true, m_nBatchSize));
771 }
772
773 if (!m_dataPool.Run(rgDataArgs))
774 m_mycaffe.Log.FAIL("Data Time Out - Failed to collect all data to build the RNN batch!");
775 }
776
777 double[] rgData = rgDataArgs[0].State.Data.GetData<double>();
778 double[] rgLabel = rgDataArgs[0].State.Label.GetData<double>();
779 double[] rgClip = rgDataArgs[0].State.Clip.GetData<double>();
780
781 int nDataLen = rgData.Length;
782 int nLabelLen = rgLabel.Length;
783 int nClipLen = rgClip.Length;
784 int nDataItem = nDataLen / nLabelLen;
785
786 if (m_nBatchSize > 1)
787 {
788 rgData = new double[nDataLen * m_nBatchSize];
789 rgLabel = new double[nLabelLen * m_nBatchSize];
790 rgClip = new double[nClipLen * m_nBatchSize];
791
792 for (int i = 0; i < m_nBatchSize; i++)
793 {
794 for (int j = 0; j < m_nSequenceLength; j++)
795 {
796 // LSTM or LSTM_ATTENTION: Create input data, the data must be in the order
797 // seq1_val1, seq2_val1, ..., seqBatch_Size_val1, seq1_val2, seq2_val2, ..., seqBatch_Size_valSequence_Length
798 if (m_lstmType == LayerParameter.LayerType.LSTM || m_lstmType == LayerParameter.LayerType.LSTM_ATTENTION)
799 nIdx = m_nBatchSize * j + i;
800
801 // [DEPRECATED] LSTM_SIMPLE: Create input data, the data must be in the order
802 // seq1_val1, seq1_val2, ..., seq1_valSequence_Length, seq2_val1, seq2_val2, ..., seqBatch_Size_valSequence_Length
803 else
804 nIdx = i * m_nSequenceLength + j;
805
806 Array.Copy(rgDataArgs[i].State.Data.GetData<double>(), 0, rgData, nIdx * nDataItem, nDataItem);
807 rgLabel[nIdx] = rgDataArgs[i].State.Label.GetDataAtD(j);
808 rgClip[nIdx] = rgDataArgs[i].State.Clip.GetDataAtD(j);
809 }
810 }
811 }
812
813 string strSolverErr = "";
814 if (m_nSolverSequenceLength >= 0 && m_nSolverSequenceLength != m_nSequenceLength)
815 strSolverErr = "The solver parameter 'SequenceLength' length of " + m_nSolverSequenceLength.ToString() + " must match the model sequence length of " + m_nSequenceLength.ToString() + ". ";
816
817 int nExpectedCount = m_blobData.count();
818 m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgData.Length, strSolverErr + "The size of the data received ('" + rgData.Length.ToString() + "') does not match the expected data count of '" + nExpectedCount.ToString() + "'!");
819 m_blobData.mutable_cpu_data = Utility.ConvertVec<T>(rgData);
820
821 nExpectedCount = m_blobLabel.count();
822 m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgLabel.Length, strSolverErr + "The size of the label received ('" + rgLabel.Length.ToString() + "') does not match the expected label count of '" + nExpectedCount.ToString() + "'!");
823 m_blobLabel.mutable_cpu_data = Utility.ConvertVec<T>(rgLabel);
824
825 nExpectedCount = m_blobClip.count();
826 m_mycaffe.Log.CHECK_EQ(nExpectedCount, rgClip.Length, strSolverErr + "The size of the clip received ('" + rgClip.Length.ToString() + "') does not match the expected clip count of '" + nExpectedCount.ToString() + "'!");
827 m_blobClip.mutable_cpu_data = Utility.ConvertVec<T>(rgClip);
828 }
829 }
830 // Byte Data (uses a vocabulary if available)
831 else
832 {
833 byte[] rgData = (bTrain) ? m_rgTrainData : m_rgTestData;
834 // Create input data, the data must be in the order
835 // seq1_char1, seq2_char1, ..., seqBatch_Size_char1, seq1_char2, seq2_char2, ..., seqBatch_Size_charSequence_Length
836 // As seq1_charSequence_Length == seq2_charSequence_Length-1 == seq3_charSequence_Length-2 == ... we can perform a block copy for efficiency.
837 // Labels are the same with an offset of +1
838
839 // Re-order the data according to the Caffe input specification for the LSTM layer.
840 for (int i = 0; i < m_nBatchSize; i++)
841 {
842 int nCurrentCharIdx = m_random.Next(rgData.Length - m_nSequenceLength - 2);
843
844 for (int j = 0; j < m_nSequenceLength; j++)
845 {
846 // Feed the net with input data and labels (clips are always the same)
847 byte bData = rgData[nCurrentCharIdx + j];
848 // Labels are the same with an offset of +1
849 byte bLabel = rgData[nCurrentCharIdx + j + 1]; // predict next character
850 float fDataIdx = findIndex(bData, out bFound);
851 float fLabelIdx = findIndex(bLabel, out bFound);
852
853 // LSTM or LSTM_ATTENTION: Create input data, the data must be in the order
854 // seq1_val1, seq2_val1, ..., seqBatch_Size_val1, seq1_val2, seq2_val2, ..., seqBatch_Size_valSequence_Length
855 if (m_lstmType == LayerParameter.LayerType.LSTM || m_lstmType == LayerParameter.LayerType.LSTM_ATTENTION)
856 nIdx = m_nBatchSize * j + i;
857
858 // [DEPRECATED] LSTM_SIMPLE: Create input data, the data must be in the order
859 // seq1_val1, seq1_val2, ..., seq1_valSequence_Length, seq2_val1, seq2_val2, ..., seqBatch_Size_valSequence_Length
860 else
861 nIdx = i * m_nSequenceLength + j;
862
863 m_rgDataInput[nIdx] = (T)Convert.ChangeType(fDataIdx, typeof(T));
864
865 if (m_nSequenceLengthLabel == (m_nSequenceLength * m_nBatchSize) || j == m_nSequenceLength - 1)
866 m_rgLabelInput[nIdx] = (T)Convert.ChangeType(fLabelIdx, typeof(T));
867 }
868 }
869
870 m_blobData.mutable_cpu_data = m_rgDataInput;
871 m_blobLabel.mutable_cpu_data = m_rgLabelInput;
872 }
873 }
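// Worked ordering example (m_nBatchSize = 2, m_nSequenceLength = 3):
// LSTM/LSTM_ATTENTION, nIdx = m_nBatchSize * j + i (time-major):
//   { s1v1, s2v1, s1v2, s2v2, s1v3, s2v3 }
// LSTM_SIMPLE, nIdx = i * m_nSequenceLength + j (sequence-major):
//   { s1v1, s1v2, s1v3, s2v1, s2v2, s2v3 }
// which matches the clip layouts set in the constructor above.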
874
875 private float findIndex(byte b, out bool bFound)
876 {
877 bFound = false;
878
879 if (m_rgVocabulary == null || m_bDisableVocabulary)
880 return b;
881
882 bFound = true;
883
884 return m_rgVocabulary.FindIndex(b);
885 }
886
887 private float findIndex(double df, out bool bFound)
888 {
889 bFound = false;
890
891 if (m_rgVocabulary == null || m_bDisableVocabulary)
892 return (float)df;
893
894 return m_rgVocabulary.FindIndex(df);
895 }
896
897 private List<T> getInitialInput(bool bIsReal)
898 {
899 List<T> rgInput = new List<T>();
900 float[] rgCorrectLengthSequence = new float[m_nSequenceLength];
901
902 for (int i = 0; i < m_nSequenceLength; i++)
903 {
904 rgCorrectLengthSequence[i] = (int)m_random.Next(m_nVocabSize);
905 }
906
907 // If a seed is specified, add it to the end of the sequence.
908 bool bDataNeeded = true;
909 if (!bIsReal && m_runProperties != null)
910 {
911 byte[] rgSeed = m_runProperties.GetPropertyBlob("Seed", false);
912
913 if (rgSeed != null && rgSeed.Length > 0)
914 {
915 int nLen = rgSeed.Length;
916 if (rgSeed[nLen - 1] == 0)
917 nLen--;
918
919 int nStart = rgCorrectLengthSequence.Length - nLen;
920 if (nStart < 0)
921 nStart = 0;
922
923 for (int i = nStart; i < rgCorrectLengthSequence.Length; i++)
924 {
925 byte bVal = rgSeed[i - nStart];
926 bool bFound;
927 int nIdx = (int)findIndex(bVal, out bFound);
928
929 if (bFound)
930 rgCorrectLengthSequence[i] = nIdx;
931 }
932
933 bDataNeeded = false;
934 }
935 }
936
937 if (bDataNeeded && m_runProperties != null)
938 {
939 GetDataArgs e = getDataArgs(Phase.RUN, 0, 0, false, m_nSequenceLength);
940 e.ExtraProperties = m_runProperties;
941 e.ExtraProperties.SetProperty("DataCountRequested", m_nSequenceLength.ToString());
942 m_icallback.OnGetData(e);
943
944 if (e.State.Data != null)
945 {
946 float[] rgf = e.State.Data.GetData<float>();
947 int nDim = e.State.Data.Height * e.State.Data.Width;
948
949 //if (e.State.Data.Channels != rgCorrectLengthSequence.Length)
950 // throw new Exception("The data length received is incorrect!");
951
952 for (int i = 0; i < rgCorrectLengthSequence.Length; i++)
953 {
954 float fChar = rgf[i * nDim];
955 // Tokenize
956 fChar = m_rgVocabulary.FindIndex(fChar);
957 rgCorrectLengthSequence[i] = fChar;
958 }
959
960 bDataNeeded = false;
961 }
962 }
963
964 if (bDataNeeded)
965 m_mycaffe.Log.WriteLine("WARNING: No seed data found - using random data.");
966
967 for (int i = 0; i < rgCorrectLengthSequence.Length; i++)
968 {
969 rgInput.Add((T)Convert.ChangeType(rgCorrectLengthSequence[i], typeof(T)));
970 }
971
972 return rgInput;
973 }
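// Example: with m_nSequenceLength = 8 and a 5-byte "Seed" property, the seed is
// right-aligned into the sequence - positions 0..2 keep their random values and
// positions 3..7 receive the seed bytes (tokenized through the vocabulary when
// one is present).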
974
975 public float[] Run(int nN)
976 {
977 try
978 {
979 Stopwatch sw = new Stopwatch();
980 float[] rgPredictions = new float[nN];
981
982 sw.Start();
983
984 m_bIsDataReal = true;
985
986 if (m_rgVocabulary != null)
987 m_bIsDataReal = m_rgVocabulary.IsDataReal;
988
989 m_mycaffe.Log.Enable = false;
990
991 if (m_bIsDataReal && !m_bUsePreloadData)
992 {
993 string strSolverErr = "";
994 int nLookahead = 1;
995 if (m_nSolverSequenceLength >= 0 && m_nSolverSequenceLength < m_nSequenceLength)
996 nLookahead = m_nSequenceLength - m_nSolverSequenceLength;
997
998 rgPredictions = new float[nN * 2 * nLookahead];
999
1000 for (int i = 0; i < nN; i++)
1001 {
1002 GetDataArgs e = getDataArgs(Phase.RUN, 0, 0, true);
1003 m_icallback.OnGetData(e);
1004
1005 int nExpectedCount = m_blobData.count();
1006 m_mycaffe.Log.CHECK_EQ(nExpectedCount, e.State.Data.ItemCount, strSolverErr + "The size of the data received ('" + e.State.Data.ItemCount.ToString() + "') does not match the expected data count of '" + nExpectedCount.ToString() + "'!");
1007 m_blobData.mutable_cpu_data = e.State.Data.GetData<T>();
1008
1009 if (m_blobLabel != null)
1010 {
1011 nExpectedCount = m_blobLabel.count();
1012 m_mycaffe.Log.CHECK_EQ(nExpectedCount, e.State.Label.ItemCount, strSolverErr + "The size of the label received ('" + e.State.Label.ItemCount.ToString() + "') does not match the expected label count of '" + nExpectedCount.ToString() + "'!");
1013 m_blobLabel.mutable_cpu_data = e.State.Label.GetData<T>();
1014 }
1015
1016 double dfLoss;
1017 BlobCollection<T> colResults = m_net.Forward(out dfLoss);
1018 Blob<T> blobOutput = colResults[0];
1019
1020 if (m_blobOutput != null)
1021 blobOutput = m_blobOutput;
1022
1023 float[] rgResults = Utility.ConvertVecF<T>(blobOutput.update_cpu_data());
1024
1025 for (int j = nLookahead; j > 0; j--)
1026 {
1027 float fPrediction = getLastPrediction(rgResults, m_rgVocabulary, j);
1028 int nIdx = e.State.Label.ItemCount - j;
1029 float fActual = (float)e.State.Label.GetDataAtF(nIdx);
1030
1031 int nIdx0 = ((nLookahead - j) * nN * 2);
1032 int nIdx1 = nIdx0 + nN;
1033
1034 if (m_dfScale != 1.0 && m_dfScale > 0)
1035 fActual /= (float)m_dfScale;
1036
1037 if (m_rgVocabulary == null || m_bDisableVocabulary)
1038 {
1039 if (m_dfScale != 1.0 && m_dfScale > 0)
1040 fPrediction /= (float)m_dfScale;
1041
1042 rgPredictions[nIdx0 + i] = fPrediction;
1043 rgPredictions[nIdx1 + i] = fActual;
1044 }
1045 else
1046 {
1047 rgPredictions[nIdx0 + i] = (float)m_rgVocabulary.GetValueAt((int)fPrediction, true);
1048 rgPredictions[nIdx1 + i] = (float)m_rgVocabulary.GetValueAt((int)fActual, true);
1049 }
1050 }
1051
1052 if (sw.Elapsed.TotalMilliseconds > 1000)
1053 {
1054 double dfPct = (double)i / (double)nN;
1055 m_mycaffe.Log.Enable = true;
1056 m_mycaffe.Log.Progress = dfPct;
1057 m_mycaffe.Log.WriteLine("Running at " + dfPct.ToString("P") + " complete...");
1058 m_mycaffe.Log.Enable = false;
1059 sw.Restart();
1060 }
1061
1062 if (m_mycaffe.CancelEvent.WaitOne(0))
1063 break;
1064 }
1065 }
1066 else
1067 {
1068 int nIdx = 0;
1069 List<T> rgInput = getInitialInput(m_bIsDataReal);
1070 Blob<T> blobLossBtm = null;
1071
1072 for (int i = 0; i < nN; i++)
1073 {
1074 T[] rgInputVector = new T[m_blobData.count()];
1075 for (int j = 0; j < m_nSequenceLength; j++)
1076 {
1077 // The batch is filled with 0 except for the first sequence which is the one we want to use for prediction.
1078 nIdx = j * m_nBatchSize;
1079 rgInputVector[nIdx] = rgInput[j];
1080 }
1081
1082 m_blobData.mutable_cpu_data = rgInputVector;
1083
1084 double dfLoss;
1085 BlobCollection<T> colResults = m_net.Forward(out dfLoss);
1086 Blob<T> blobOutput = colResults[0];
1087
1088 if (m_blobOutput != null)
1089 blobOutput = m_blobOutput;
1090
1091 if (blobOutput.type == BLOB_TYPE.LOSS)
1092 {
1093 if (blobLossBtm == null)
1094 {
1095 Blob<T> blob = m_net.FindLossBottomBlob();
1096 if (blob != null)
1097 blobLossBtm = blob;
1098 }
1099
1100 if (blobLossBtm != null)
1101 blobOutput = blobLossBtm;
1102 }
1103
1104 float[] rgResults = Utility.ConvertVecF<T>(blobOutput.update_cpu_data());
1105 float fPrediction = getLastPrediction(rgResults, m_rgVocabulary, 1);
1106
1107 // Add the new prediction and discard the oldest one.
1108 rgInput.Add((T)Convert.ChangeType(fPrediction, typeof(T)));
1109 rgInput.RemoveAt(0);
1110
1111 if (m_rgVocabulary == null || m_bDisableVocabulary)
1112 rgPredictions[i] = fPrediction;
1113 else
1114 rgPredictions[i] = (float)m_rgVocabulary.GetValueAt((int)fPrediction);
1115
1116 if (sw.Elapsed.TotalMilliseconds > 1000)
1117 {
1118 double dfPct = (double)i / (double)nN;
1119 m_mycaffe.Log.Enable = true;
1120 m_mycaffe.Log.Progress = dfPct;
1121 m_mycaffe.Log.WriteLine("Running at " + dfPct.ToString("P") + " complete...");
1122 m_mycaffe.Log.Enable = false;
1123 sw.Restart();
1124 }
1125
1126 if (m_mycaffe.CancelEvent.WaitOne(0))
1127 break;
1128 }
1129 }
1130
1131 return rgPredictions;
1132 }
1133 catch (Exception)
1134 {
1135 throw;
1136 }
1137 finally
1138 {
1139 m_mycaffe.Log.Enable = true;
1140 }
1141 }
1142
1143 private float getLastPrediction(float[] rgDataRaw, BucketCollection rgVocabulary, int nLookahead)
1144 {
1145 // Get the probabilities for the last character of the first sequence in the batch
1146 int nOffset = (m_nSequenceLength - nLookahead) * m_nBatchSize * m_nVocabSize;
1147
1148 if (m_bDisableVocabulary)
1149 return rgDataRaw[nOffset];
1150
1151 float[] rgData = new float[m_nVocabSize];
1152
1153 for (int i = 0; i < rgData.Length; i++)
1154 {
1155 rgData[i] = rgDataRaw[nOffset + i];
1156 }
1157
1158 return getLastPrediction(rgData, m_dfRunTemperature);
1159 }
1160
1161 private int getLastPrediction(float[] rgData, double dfTemperature)
1162 {
1163 int nIdx = m_nVocabSize - 1;
1164
1165 // If there is no temperature, directly return the character with the best score.
1166 if (dfTemperature == 0)
1167 {
1168 nIdx = ArgMax(rgData, 0, m_nVocabSize);
1169 }
1170 else
1171 {
1172 // Otherwise, compute the probabilities with the temperature and select the character according to the probabilities.
1173 double[] rgAccumulatedProba = new double[m_nVocabSize];
1174 double[] rgProba = new double[m_nVocabSize];
1175 double dfExpoSum = 0;
1176
1177 double dfMax = rgData.Max();
1178 for (int i = 0; i < m_nVocabSize; i++)
1179 {
1180 // The max value is subtracted for numerical stability
1181 rgProba[i] = Math.Exp((rgData[i] - dfMax) / dfTemperature);
1182 dfExpoSum += rgProba[i];
1183 }
1184
1185 rgProba[0] /= dfExpoSum;
1186 rgAccumulatedProba[0] = rgProba[0];
1187
1188 double dfRandom = m_random.NextDouble();
1189
1190 for (int i = 1; i < rgProba.Length; i++)
1191 {
1192 // Return the first index for which the accumulated probability is bigger than the random number.
1193 if (rgAccumulatedProba[i - 1] > dfRandom)
1194 {
1195 nIdx = i - 1;
1196 break;
1197 }
1198
1199 rgProba[i] /= dfExpoSum;
1200 rgAccumulatedProba[i] = rgAccumulatedProba[i - 1] + rgProba[i];
1201 }
1202 }
1203
1204 if (nIdx < 0 || nIdx > m_nVocabSize)
1205 throw new Exception("Invalid index - out of the vocabulary range of [0," + m_nVocabSize.ToString() + "]");
1206
1207 return nIdx;
1208 }
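// With temperature T > 0 the selection probability is the softmax
//   p[i] = exp((s[i] - max) / T) / sum_j exp((s[j] - max) / T)
// sampled by walking the cumulative distribution against a single uniform draw.
// As T -> 0 the distribution sharpens toward ArgMax; T = 1 samples the raw
// softmax. See the standalone sketch following the listing for a runnable version.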
1209
1210 private int ArgMax(float[] rg, int nOffset, int nCount)
1211 {
1212 if (nCount == 0)
1213 return -1;
1214
1215 int nMaxIdx = nOffset;
1216 float fMax = rg[nOffset];
1217
1218 for (int i = nOffset; i < nOffset + nCount; i++)
1219 {
1220 if (rg[i] > fMax)
1221 {
1222 nMaxIdx = i;
1223 fMax = rg[i];
1224 }
1225 }
1226
1227 return nMaxIdx - nOffset;
1228 }
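// Example: ArgMax(new float[] { 5, 9, 7 }, 0, 3) == 1; the returned index is
// relative to nOffset, so it maps directly to a vocabulary index.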
1229 }
1230
1231 class DataCollectionPool
1232 {
1233 List<DataCollector> m_rgCollectors = new List<DataCollector>();
1234
1235 public DataCollectionPool()
1236 {
1237 }
1238
1239 public void Initialize(int nThreads, IxTrainerCallback icallback)
1240 {
1241 for (int i = 0; i < nThreads; i++)
1242 {
1243 m_rgCollectors.Add(new DataCollector(icallback));
1244 }
1245 }
1246
1247 public void Shutdown()
1248 {
1249 foreach (DataCollector col in m_rgCollectors)
1250 {
1251 col.CleanUp();
1252 }
1253 }
1254
1255 public bool Run(List<GetDataArgs> rgStartup)
1256 {
1257 List<ManualResetEvent> rgWait = new List<ManualResetEvent>();
1258
1259 if (rgStartup.Count != m_rgCollectors.Count)
1260 throw new Exception("The startup count does not match the collector count.");
1261
1262 for (int i = 0; i < rgStartup.Count; i++)
1263 {
1264 rgWait.Add(rgStartup[i].DataReady);
1265 m_rgCollectors[i].Run(rgStartup[i]);
1266 }
1267
1268 return WaitHandle.WaitAll(rgWait.ToArray(), 10000);
1269 }
1270 }
1271
1272 class DataCollector
1273 {
1274 ManualResetEvent m_evtAbort = new ManualResetEvent(false);
1275 AutoResetEvent m_evtRun = new AutoResetEvent(false);
1276 Thread m_thread;
1277 GetDataArgs m_args;
1278 IxTrainerCallback m_icallback;
1279
1280 public DataCollector(IxTrainerCallback icallback)
1281 {
1282 m_icallback = icallback;
1283 m_thread = new Thread(new ThreadStart(doWork));
1284 m_thread.Start();
1285 }
1286
1287 public void CleanUp()
1288 {
1289 m_evtAbort.Set();
1290 }
1291
1292 public void Run(GetDataArgs args)
1293 {
1294 m_args = args;
1295 m_evtRun.Set();
1296 }
1297
1298 private void doWork()
1299 {
1300 bool bDone = false;
1301 List<WaitHandle> rgWait = new List<WaitHandle>();
1302 rgWait.Add(m_evtAbort);
1303 rgWait.Add(m_evtRun);
1304
1305 while (!bDone)
1306 {
1307 int nWait = WaitHandle.WaitAny(rgWait.ToArray());
1308 if (nWait == 0)
1309 return;
1310
1311 m_icallback.OnGetData(m_args);
1312 m_args.DataReady.Set();
1313 }
1314 }
1315 }
1316}
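For reference, the temperature-scaled selection used by getLastPrediction reduces to the following self-contained sketch. The names (TemperatureSamplingSketch, SampleWithTemperature) are illustrative, not part of MyCaffe: it samples the cumulative softmax distribution with one uniform draw and falls back to a greedy argmax when the temperature is zero.

using System;
using System.Linq;

// Standalone sketch of temperature-scaled sampling (illustrative, not MyCaffe code).
public static class TemperatureSamplingSketch
{
    public static int SampleWithTemperature(float[] rgfScores, double dfTemperature, Random random)
    {
        // No temperature: return the index of the best raw score (greedy argmax).
        if (dfTemperature == 0)
        {
            int nMaxIdx = 0;
            for (int i = 1; i < rgfScores.Length; i++)
            {
                if (rgfScores[i] > rgfScores[nMaxIdx])
                    nMaxIdx = i;
            }
            return nMaxIdx;
        }

        // Softmax of scores / temperature; the max is subtracted for numerical stability.
        double dfMax = rgfScores.Max();
        double[] rgProba = rgfScores.Select(f => Math.Exp((f - dfMax) / dfTemperature)).ToArray();
        double dfSum = rgProba.Sum();

        // Walk the cumulative (unnormalized) distribution against one uniform draw.
        double dfRandom = random.NextDouble() * dfSum;
        double dfAccum = 0;
        for (int i = 0; i < rgProba.Length; i++)
        {
            dfAccum += rgProba[i];
            if (dfRandom < dfAccum)
                return i;
        }

        return rgProba.Length - 1;
    }

    public static void Main()
    {
        Random random = new Random(0);
        float[] rgfScores = { 1.0f, 2.0f, 3.0f };

        // T = 0 always picks index 2; higher temperatures sample a flatter softmax.
        Console.WriteLine(SampleWithTemperature(rgfScores, 0.0, random));
        Console.WriteLine(SampleWithTemperature(rgfScores, 0.5, random));
        Console.WriteLine(SampleWithTemperature(rgfScores, 1.0, random));
    }
}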
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
The BucketCollection contains a set of Buckets.
int FindIndex(double dfVal)
Finds the index of the Bucket containing the value.
int Count
Returns the number of Buckets.
bool IsDataReal
Get/set whether or not the Buckets hold Real values.
double GetValueAt(int nIdx, bool bUseMidPoint=false)
Returns the average of the Bucket at a given index.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
void Reset()
Resets the event clearing any signaled state.
Definition: CancelEvent.cs:279
CancelEvent()
The CancelEvent constructor.
Definition: CancelEvent.cs:28
void Set()
Sets the event to the signaled state.
Definition: CancelEvent.cs:270
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Definition: CryptoRandom.cs:14
int Next(int nMinVal, int nMaxVal, bool bMaxInclusive=true)
Returns a random int within the range
double NextDouble()
Returns a random double within the range.
Definition: CryptoRandom.cs:83
The Log class provides general output in text form.
Definition: Log.cs:13
Log(string strSrc)
The Log constructor.
Definition: Log.cs:33
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
byte[] GetPropertyBlob(string strName, bool bThrowExceptions=true)
Returns a property blob as a byte array value.
Definition: PropertySet.cs:184
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Definition: PropertySet.cs:267
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as a double value.
Definition: PropertySet.cs:307
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
Definition: PropertySet.cs:211
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
void Copy(SimpleDatum d, bool bCopyData, int? nHeight=null, int? nWidth=null)
Copy another SimpleDatum into this one.
int Width
Return the width of the data.
int Height
Return the height of the data.
The Utility class provides general utility functions.
Definition: Utility.cs:35
static string Replace(string str, char ch1, char ch2)
Replaces each instance of one character with another character in a given string.
Definition: Utility.cs:864
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
The BlobCollection contains a list of Blobs.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
Definition: Blob.cs:1461
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECATED; use
Definition: Blob.cs:442
BLOB_TYPE type
Returns the BLOB_TYPE of the Blob.
Definition: Blob.cs:2761
List< int > shape()
Returns an array where each element contains the shape of an axis of the Blob.
Definition: Blob.cs:684
T[] update_cpu_data()
Update the CPU data by transferring the GPU data over to the Host.
Definition: Blob.cs:1470
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
int num
DEPRECATED; legacy shape accessor num: use shape(0) instead.
Definition: Blob.cs:792
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
List< Layer< T > > layers
Returns the layers.
Definition: Net.cs:2003
void Reshape()
Reshape all layers from the bottom to the top.
Definition: Net.cs:1800
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1445
Layer< T > FindLastLayer(LayerParameter.LayerType type)
Find the last layer with the matching type.
Definition: Net.cs:2806
Blob< T > FindBlob(string strName)
Finds a Blob in the Net by name.
Definition: Net.cs:2592
Blob< T > FindLossBottomBlob()
Find the bottom blob of the Loss layer if it exists, otherwise null is returned.
Definition: Net.cs:2664
The TestResultArgs are passed to the Solver::OnTestResults event.
Definition: EventArgs.cs:116
BlobCollection< T > Results
Returns the results from the test.
Definition: EventArgs.cs:135
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Definition: EventArgs.cs:143
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
Definition: EventArgs.cs:216
double Accuracy
Return the accuracy of the test cycle.
Definition: EventArgs.cs:238
int Iteration
Return the iteration of the test cycle.
Definition: EventArgs.cs:246
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
Definition: EventArgs.cs:264
double LearningRate
Return the current learning rate.
Definition: EventArgs.cs:375
double SmoothedLoss
Returns the average loss after the training cycle.
Definition: EventArgs.cs:319
The LSTMAttentionLayer adds attention to the long-short term memory layer and is used in encoder/deco...
The LSTMLayer processes sequential inputs using a 'Long Short-Term Memory' (LSTM) [1] style recurrent...
Definition: LSTMLayer.cs:59
[DEPRECATED - use LSTMAttentionLayer instead with enable_attention = false] The LSTMSimpleLayer is a...
An interface for the units of computation which can be composed into a Net.
Definition: Layer.cs:31
LayerParameter layer_param
Returns the LayerParameter for this Layer.
Definition: Layer.cs:899
uint num_output
The number of outputs for the layer.
uint batch_size
Specifies the batch size, default = 1.
Specifies the base parameter for all layers.
LayerType type
Specifies the type of this LayerParameter.
LSTMSimpleParameter lstm_simple_param
[DEPRECATED] Returns the parameter set when initialized with LayerType.LSTM_SIMPLE
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
Definition: Solver.cs:134
int MaximumIteration
Returns the maximum training iterations.
Definition: Solver.cs:700
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
Definition: Solver.cs:138
SolverParameter parameter
Returns the SolverParameter used.
Definition: Solver.cs:1221
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
Definition: Solver.cs:1322
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
Definition: Solver.cs:142
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Definition: Solver.cs:150
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
Definition: Solver.cs:744
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
Definition: Solver.cs:118
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
The TrainerRNN implements a simple RNN trainer inspired by adepierre's referenced GitHub site.
Definition: TrainerRNN.cs:27
bool Initialize()
Initialize the trainer.
Definition: TrainerRNN.cs:69
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using the RNN training algorithm.
Definition: TrainerRNN.cs:167
bool Shutdown(int nWait)
Shutdown the trainer.
Definition: TrainerRNN.cs:93
float[] Run(int nN, PropertySet runProp)
Run a single cycle on the environment after the delay.
Definition: TrainerRNN.cs:113
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle.
Definition: TrainerRNN.cs:146
void Dispose()
Releases all resources used.
Definition: TrainerRNN.cs:61
TrainerRNN(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback, BucketCollection rgVocabulary)
The constructor.
Definition: TrainerRNN.cs:48
byte[] Run(int nN, PropertySet runProp, out string type)
Run a single cycle on the environment after the delay.
Definition: TrainerRNN.cs:130
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
Definition: Interfaces.cs:303
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
Definition: Interfaces.cs:348
The IxTrainerRNN interface is implemented by each RNN Trainer.
Definition: Interfaces.cs:279
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12