MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
TrainerPG.cs
1using System;
2using System.Collections;
3using System.Collections.Generic;
4using System.Diagnostics;
5using System.Drawing;
6using System.Linq;
7using System.Text;
8using System.Threading;
9using System.Threading.Tasks;
10using MyCaffe.basecode;
11using MyCaffe.common;
12using MyCaffe.fillers;
13using MyCaffe.layers;
14using MyCaffe.param;
15using MyCaffe.solvers;
16
18{
/// <summary>
/// The TrainerPG implements a Policy Gradient trainer, optionally running multiple
/// agent threads in parallel with a shared optimizer that merges their updates.
/// </summary>
/// <typeparam name="T">Specifies the base type (float or double).</typeparam>
public class TrainerPG<T> : IxTrainerRL, IDisposable
{
    IxTrainerCallback m_icallback;
    CryptoRandom m_random = new CryptoRandom();
    MyCaffeControl<T> m_mycaffe;
    PropertySet m_properties;
    int m_nThreads = 1;
    List<int> m_rgGpuID = new List<int>();
    Optimizer<T> m_optimizer = null;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="mycaffe">Specifies the MyCaffeControl used for training.</param>
    /// <param name="properties">Specifies the property set containing the trainer settings.</param>
    /// <param name="random">Specifies the random number generator to use.</param>
    /// <param name="icallback">Specifies the callback used for external communication.</param>
    public TrainerPG(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
    {
        m_icallback = icallback;
        m_mycaffe = mycaffe;
        m_properties = properties;
        m_random = random;

        m_nThreads = m_properties.GetPropertyAsInt("Threads", 1);
        m_rgGpuID.Add(m_mycaffe.Cuda.GetDeviceID());

        // When running multi-threaded, an optional 'GPUIDs' list assigns a device to each agent.
        string strGpuID = m_properties.GetProperty("GPUIDs", false);
        if (strGpuID != null && m_nThreads > 1)
        {
            int nDeviceCount = m_mycaffe.Cuda.GetDeviceCount();

            m_rgGpuID.Clear();

            foreach (string strID in strGpuID.Split(','))
            {
                int nDevId = int.Parse(strID);

                if (nDevId < 0 || nDevId >= nDeviceCount)
                    throw new Exception("Invalid device ID - value must be within the range [0," + (nDeviceCount - 1).ToString() + "].");

                m_rgGpuID.Add(nDevId);
            }
        }
    }

    /// <summary>
    /// Release all resources used.
    /// </summary>
    public void Dispose()
    {
    }

    /// <summary>
    /// Initialize the trainer by resetting the cancel event and firing the OnInitialize callback.
    /// </summary>
    /// <returns>Returns <i>true</i>.</returns>
    public bool Initialize()
    {
        m_mycaffe.CancelEvent.Reset();
        m_icallback.OnInitialize(new InitializeArgs(m_mycaffe));
        return true;
    }

    // Waits 'nWait' ms in 250 ms increments, firing the OnWait callback each increment.
    private void wait(int nWait)
    {
        const int nWaitInc = 250;

        for (int nTotalWait = 0; nTotalWait < nWait; nTotalWait += nWaitInc)
        {
            m_icallback.OnWait(new WaitArgs(nWaitInc));
        }
    }

    /// <summary>
    /// Shutdown the trainer, canceling any outstanding work and firing OnShutdown.
    /// </summary>
    /// <param name="nWait">Specifies the maximum time to wait (in ms) for work to stop.</param>
    /// <returns>Returns <i>true</i>.</returns>
    public bool Shutdown(int nWait)
    {
        if (m_mycaffe != null)
        {
            m_mycaffe.CancelEvent.Set();
            wait(nWait);
        }

        m_icallback.OnShutdown();

        return true;
    }

    /// <summary>
    /// Run a single cycle on the environment after the delay.
    /// </summary>
    /// <param name="nDelay">Specifies the delay (in ms) to wait before running.</param>
    /// <returns>The results of the run are returned, with a probability of 1.0 on the selected action.</returns>
    public ResultCollection RunOne(int nDelay = 1000)
    {
        m_mycaffe.CancelEvent.Reset();

        Agent<T> agent = new Agent<T>(0, m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN, 0, 1);
        Tuple<int, int> res = agent.Run(nDelay);

        // Mark the selected action (Item1) with probability 1 among Item2 actions.
        List<Result> rgActions = Enumerable.Range(0, res.Item2)
                                           .Select(i => new Result(i, (i == res.Item1) ? 1.0 : 0.0))
                                           .ToList();

        agent.Dispose();

        return new ResultCollection(rgActions, LayerParameter.LayerType.SOFTMAX);
    }

    /// <summary>
    /// Run a set of iterations and return the raw results.
    /// </summary>
    /// <param name="nN">Specifies the number of samples to run.</param>
    /// <param name="runProp">Optionally specifies properties to use when running.</param>
    /// <param name="type">Returns the data type contained in the byte stream.</param>
    /// <returns>The results of the run are returned in the native format.</returns>
    public byte[] Run(int nN, PropertySet runProp, out string type)
    {
        m_mycaffe.CancelEvent.Reset();

        Agent<T> agent = new Agent<T>(0, m_icallback, m_mycaffe, m_properties, m_random, Phase.RUN, 0, 1);
        byte[] rgResults = agent.Run(nN, out type);
        agent.Dispose();

        return rgResults;
    }

    /// <summary>
    /// Run the test cycle - currently this implementation runs the TEST phase on a single agent.
    /// </summary>
    /// <param name="nN">Specifies the number of iterations (based on the ITERATION_TYPE) to run.</param>
    /// <param name="type">Specifies the iteration type (default = ITERATION).</param>
    /// <returns>Returns <i>true</i>.</returns>
    public bool Test(int nN, ITERATOR_TYPE type)
    {
        int nDelay = 1000;

        // Turn off the num-skip to run at normal speed.
        string strProp = m_properties.ToString();
        strProp += "EnableNumSkip=False;";
        PropertySet properties = new PropertySet(strProp);

        m_mycaffe.CancelEvent.Reset();

        Agent<T> agent = new Agent<T>(0, m_icallback, m_mycaffe, properties, m_random, Phase.TRAIN, 0, 1);
        agent.Run(Phase.TEST, nN, type, TRAIN_STEP.NONE);
        agent.Dispose();

        Shutdown(nDelay);

        return true;
    }

    /// <summary>
    /// Train the network using the training technique implemented by this trainer.
    /// </summary>
    /// <param name="nN">Specifies the number of iterations (based on the ITERATION_TYPE) to run.</param>
    /// <param name="type">Specifies the iteration type (default = ITERATION).</param>
    /// <param name="step">Specifies the stepping mode to use (when debugging).</param>
    /// <returns>Returns <i>false</i>.</returns>
    public bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
    {
        List<Agent<T>> rgAgents = new List<Agent<T>>();

        m_mycaffe.CancelEvent.Reset();

        // A shared optimizer is only needed when several agent threads must merge gradients.
        if (m_nThreads > 1)
            m_optimizer = new Optimizer<T>(m_mycaffe);

        for (int i = 0; i < m_nThreads; i++)
        {
            // Assign GPUs to agent threads round-robin.
            int nGpuID = m_rgGpuID[i % m_rgGpuID.Count];

            Agent<T> agent = new Agent<T>(i, m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN, nGpuID, m_nThreads);
            agent.OnApplyUpdates += Agent_OnApplyUpdates;
            rgAgents.Add(agent);
        }

        if (m_optimizer != null)
            m_optimizer.Start(new WorkerStartArgs(0, Phase.TRAIN, nN, type, step));

        WorkerStartArgs args = new WorkerStartArgs(1, Phase.TRAIN, nN, type, step);

        foreach (Agent<T> agent in rgAgents)
            agent.Start(args);

        // Block until the training session is canceled.
        while (!m_mycaffe.CancelEvent.WaitOne(250))
        {
        }

        foreach (Agent<T> agent in rgAgents)
        {
            agent.Stop(1000);
            agent.Dispose();
        }

        if (m_optimizer != null)
        {
            m_optimizer.Stop(1000);
            m_optimizer.Dispose();
            m_optimizer = null;
        }

        Shutdown(3000);

        return false;
    }

    // Routes an agent's gradient update request to the shared optimizer (if any).
    private void Agent_OnApplyUpdates(object sender, ApplyUpdateArgs<T> e)
    {
        if (m_optimizer != null)
            m_optimizer.ApplyUpdates(e.MyCaffeWorker, e.Iteration);
    }
}
255
260 {
261 int m_nCycleDelay;
262 Phase m_phase;
263 int m_nN;
264 ITERATOR_TYPE m_type;
265 TRAIN_STEP m_step = TRAIN_STEP.NONE;
266
/// <summary>
/// The constructor.
/// </summary>
/// <param name="nCycleDelay">Specifies the cycle delay (in ms) used between work cycles.</param>
/// <param name="phase">Specifies the phase on which to run.</param>
/// <param name="nN">Specifies the number of iterations (based on the iterator type) to run.</param>
/// <param name="type">Specifies the iterator type.</param>
/// <param name="step">Specifies the training stepping mode (used when debugging).</param>
public WorkerStartArgs(int nCycleDelay, Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
{
    m_step = step;
    m_type = type;
    m_nN = nN;
    m_phase = phase;
    m_nCycleDelay = nCycleDelay;
}
283
288 {
289 get { return m_step; }
290 }
291
/// <summary>
/// Returns the cycle delay (in ms) used between work cycles.
/// </summary>
public int CycleDelay => m_nCycleDelay;
299
304 {
305 get { return m_phase; }
306 }
307
/// <summary>
/// Returns the number of iterations (based on the iterator type) to run.
/// </summary>
public int N => m_nN;
315
320 {
321 get { return m_type; }
322 }
323 }
324
/// <summary>
/// The Worker is the base class of all worker threads used by the trainer.
/// </summary>
class Worker
{
    /// <summary>
    /// Specifies the index of the worker.
    /// </summary>
    protected int m_nIndex = -1;
    /// <summary>
    /// Specifies the cancel event used to stop the worker.
    /// </summary>
    protected AutoResetEvent m_evtCancel = new AutoResetEvent(false);
    /// <summary>
    /// Specifies the done event signaled when the worker completes its work.
    /// </summary>
    protected ManualResetEvent m_evtDone = new ManualResetEvent(false);
    /// <summary>
    /// Specifies the task running the worker thread.
    /// </summary>
    protected Task m_workTask = null;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="nIdx">Specifies the index of the worker.</param>
    public Worker(int nIdx)
    {
        m_nIndex = nIdx;
    }

    /// <summary>
    /// Overridden by derived classes to perform the actual work.
    /// </summary>
    /// <param name="arg">Specifies the WorkerStartArgs passed to Start.</param>
    protected virtual void doWork(object arg)
    {
    }

    /// <summary>
    /// Start the worker thread (has no effect if the worker is already running).
    /// </summary>
    /// <param name="args">Specifies the start arguments.</param>
    public void Start(WorkerStartArgs args)
    {
        if (m_workTask != null)
            return;

        m_workTask = Task.Factory.StartNew(new Action<object>(doWork), args);
    }

    /// <summary>
    /// Signal the worker to stop and wait up to 'nWait' ms for it to complete.
    /// </summary>
    /// <param name="nWait">Specifies the maximum time to wait (in ms).</param>
    public void Stop(int nWait)
    {
        m_evtCancel.Set();
        m_workTask = null;
        m_evtDone.WaitOne(nWait);
    }
}
385
/// <summary>
/// The Optimizer runs on its own thread and applies worker gradients to the primary
/// MyCaffe instance, copying the updated weights back to the requesting worker.
/// </summary>
/// <typeparam name="T">Specifies the base type (float or double).</typeparam>
class Optimizer<T> : Worker, IDisposable
{
    MyCaffeControl<T> m_mycaffePrimary;
    MyCaffeControl<T> m_mycaffeWorker;
    int m_nIteration;
    double m_dfLearningRate;
    AutoResetEvent m_evtApplyUpdates = new AutoResetEvent(false);
    ManualResetEvent m_evtDoneApplying = new ManualResetEvent(false);
    object m_syncObj = new object();

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="mycaffePrimary">Specifies the primary MyCaffe instance that owns the weights.</param>
    public Optimizer(MyCaffeControl<T> mycaffePrimary)
        : base(0)
    {
        m_mycaffePrimary = mycaffePrimary;
    }

    /// <summary>
    /// Release all resources used.
    /// </summary>
    public void Dispose()
    {
    }

    /// <summary>
    /// The worker loop: wait for an apply-updates request (or cancel), merge the worker's
    /// gradients into the primary instance, and hand the new weights back to the worker.
    /// </summary>
    /// <param name="arg">Specifies the WorkerStartArgs passed to Start.</param>
    protected override void doWork(object arg)
    {
        WorkerStartArgs args = arg as WorkerStartArgs;

        m_mycaffePrimary.Cuda.SetDeviceID();

        // Index 0 = apply-updates request; any other index = cancel.
        List<WaitHandle> rgWait = new List<WaitHandle> { m_evtApplyUpdates };
        rgWait.AddRange(m_mycaffePrimary.CancelEvent.Handles);

        int nWait = WaitHandle.WaitAny(rgWait.ToArray());

        while (nWait == 0)
        {
            if (args.Step != TRAIN_STEP.FORWARD)
            {
                m_mycaffePrimary.CopyGradientsFrom(m_mycaffeWorker);
                // Silence the log while applying the update to avoid noisy per-step output.
                m_mycaffePrimary.Log.Enable = false;
                m_dfLearningRate = m_mycaffePrimary.ApplyUpdate(m_nIteration);
                m_mycaffePrimary.Log.Enable = true;
                m_mycaffeWorker.CopyWeightsFrom(m_mycaffePrimary);
            }

            m_evtDoneApplying.Set();

            nWait = WaitHandle.WaitAny(rgWait.ToArray());

            // When single-stepping, only service one request.
            if (args.Step != TRAIN_STEP.NONE)
                break;
        }
    }

    /// <summary>
    /// Request that the optimizer thread apply the worker's gradients, blocking until done.
    /// </summary>
    /// <param name="mycaffeWorker">Specifies the worker whose gradients are to be applied.</param>
    /// <param name="nIteration">Specifies the iteration of the update.</param>
    /// <returns>The learning rate used is returned, or 0 when canceled.</returns>
    public double ApplyUpdates(MyCaffeControl<T> mycaffeWorker, int nIteration)
    {
        lock (m_syncObj)
        {
            m_mycaffeWorker = mycaffeWorker;
            m_nIteration = nIteration;

            m_evtDoneApplying.Reset();
            m_evtApplyUpdates.Set();

            // Index 0 = done applying; any other index = cancel.
            List<WaitHandle> rgWait = new List<WaitHandle> { m_evtDoneApplying };
            rgWait.AddRange(m_mycaffePrimary.CancelEvent.Handles);

            if (WaitHandle.WaitAny(rgWait.ToArray()) != 0)
                return 0;

            return m_dfLearningRate;
        }
    }
}
484
/// <summary>
/// The Agent interacts with the environment to build episodes and drives the Brain
/// to train on them using policy gradients.
/// </summary>
/// <typeparam name="T">Specifies the base type (float or double).</typeparam>
class Agent<T> : Worker, IDisposable
{
    IxTrainerCallback m_icallback;
    Brain<T> m_brain;
    PropertySet m_properties;
    CryptoRandom m_random;
    float m_fGamma;
    bool m_bAllowDiscountReset = false;
    bool m_bUseRawInput = false;
    int m_nEpsSteps = 0;
    double m_dfEpsStart = 0;
    double m_dfEpsEnd = 0;
    double m_dfExplorationRate = 0;
    int m_nEpisodeBatchSize = 1;
    double m_dfEpisodeElitePercentile = 1;
    static object m_syncObj = new object();
    bool m_bShowActionProb = false;
    bool m_bVerbose = false;

    /// <summary>
    /// The OnApplyUpdates event fires when the Agent needs its gradient updates applied.
    /// </summary>
    public event EventHandler<ApplyUpdateArgs<T>> OnApplyUpdates;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="nIdx">Specifies the index of this agent.</param>
    /// <param name="icallback">Specifies the callback used for external communication.</param>
    /// <param name="mycaffe">Specifies the MyCaffeControl used.</param>
    /// <param name="properties">Specifies the property set containing the agent settings.</param>
    /// <param name="random">Specifies the random number generator used.</param>
    /// <param name="phase">Specifies the phase under which to run.</param>
    /// <param name="nGpuID">Specifies the GPU on which this agent runs.</param>
    /// <param name="nThreadCount">Specifies the total number of agent threads.</param>
    public Agent(int nIdx, IxTrainerCallback icallback, MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase, int nGpuID, int nThreadCount)
        : base(nIdx)
    {
        m_icallback = icallback;
        m_brain = new Brain<T>(mycaffe, properties, random, phase, nGpuID, nThreadCount);
        m_brain.OnApplyUpdate += brain_OnApplyUpdate;
        m_properties = properties;
        m_random = random;

        m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", 0.99);
        m_bAllowDiscountReset = properties.GetPropertyAsBool("AllowDiscountReset", false);
        m_bUseRawInput = properties.GetPropertyAsBool("UseRawInput", false);
        m_nEpsSteps = properties.GetPropertyAsInt("EpsSteps", 0);
        m_dfEpsStart = properties.GetPropertyAsDouble("EpsStart", 0);
        m_dfEpsEnd = properties.GetPropertyAsDouble("EpsEnd", 0);
        m_nEpisodeBatchSize = m_properties.GetPropertyAsInt("EpisodeBatchSize", 1);
        m_dfEpisodeElitePercentile = properties.GetPropertyAsDouble("EpisodeElitePercent", 1.0);
        m_bShowActionProb = properties.GetPropertyAsBool("ShowActionProb", false);
        m_bVerbose = properties.GetPropertyAsBool("Verbose", false);

        if (m_dfEpsStart < 0 || m_dfEpsStart > 1)
            throw new Exception("The 'EpsStart' is out of range - please specify a real number in the range [0,1]");

        if (m_dfEpsEnd < 0 || m_dfEpsEnd > 1)
            throw new Exception("The 'EpsEnd' is out of range - please specify a real number in the range [0,1]");

        if (m_dfEpsEnd > m_dfEpsStart)
            throw new Exception("The 'EpsEnd' must be less than the 'EpsStart' value.");
    }

    // Forwards the Brain's apply-update request to anyone listening on this agent.
    private void brain_OnApplyUpdate(object sender, ApplyUpdateArgs<T> e)
    {
        if (OnApplyUpdates != null)
            OnApplyUpdates(sender, e);
    }

    /// <summary>
    /// Release all resources used.
    /// </summary>
    public void Dispose()
    {
        if (m_brain != null)
        {
            m_brain.Dispose();
            m_brain = null;
        }
    }

    /// <summary>
    /// The worker thread: create the brain (serialized across agents) and run the training loop.
    /// </summary>
    /// <param name="arg">Specifies the WorkerStartArgs passed to Start.</param>
    protected override void doWork(object arg)
    {
        try
        {
            WorkerStartArgs args = arg as WorkerStartArgs;

            // Brain creation is serialized across all agent threads.
            lock (m_syncObj)
            {
                m_brain.Create();
            }

            m_evtDone.Reset();
            m_evtCancel.Reset();
            Run(args.Phase, args.N, args.IterationType, args.Step);
            m_evtDone.Set();
        }
        catch (Exception excpt)
        {
            m_brain.OutputLog.WriteError(excpt);
        }

        m_brain.Cancel.Set();
    }

    // Returns the linearly annealed exploration rate for the given episode.
    private double getEpsilon(int nEpisode)
    {
        if (m_nEpsSteps == 0)
            return 0;

        if (nEpisode >= m_nEpsSteps)
            return m_dfEpsEnd;

        return m_dfEpsStart + nEpisode * (m_dfEpsEnd - m_dfEpsStart) / m_nEpsSteps;
    }

    // Queries the environment (via the callback) for the next state.
    private StateBase getData(Phase phase, int nIdx, int nAction, bool? bResetOverride = null)
    {
        GetDataArgs args = m_brain.getDataArgs(phase, nIdx, nAction, bResetOverride);
        m_icallback.OnGetData(args);
        return args.State;
    }

    // Selects an action: either a random (exploration) action, or the brain's sampled action.
    private int getAction(int nEpisode, SimpleDatum sd, SimpleDatum sdClip, int nActionCount, TRAIN_STEP step, out float[] rgfAprob)
    {
        if (step == TRAIN_STEP.NONE)
        {
            m_dfExplorationRate = getEpsilon(nEpisode);

            if (m_dfExplorationRate > 0 && m_random.NextDouble() < m_dfExplorationRate)
            {
                // Explore: pick a uniformly random action with probability 1.
                rgfAprob = new float[nActionCount];
                int nAction = m_random.Next(nActionCount);
                rgfAprob[nAction] = 1.0f;
                return nAction;
            }
        }

        return m_brain.act(sd, sdClip, out rgfAprob);
    }

    // Reports progress via the callback and returns the (possibly updated) episode count.
    private int updateStatus(int nIteration, int nEpisodeCount, double dfRunningReward, double dfRewardSum, double dfLoss, double dfLearningRate)
    {
        GetStatusArgs args = new GetStatusArgs(m_nIndex, nIteration, nEpisodeCount, 1000000, dfRunningReward, dfRewardSum, m_dfExplorationRate, 0, dfLoss, dfLearningRate);
        m_icallback.OnUpdateStatus(args);
        return args.NewFrameCount;
    }

    /// <summary>
    /// Run a single action on the environment after the delay.
    /// </summary>
    /// <param name="nDelay">Specifies the delay (in ms) to wait before getting the state.</param>
    /// <returns>A tuple of the selected action and the total action count is returned.</returns>
    public Tuple<int, int> Run(int nDelay = 1000)
    {
        // Reset the environment and get the initial state.
        getData(Phase.RUN, m_nIndex, -1);
        Thread.Sleep(nDelay);

        StateBase state = getData(Phase.RUN, m_nIndex, -1, false);

        m_brain.Create();

        float[] rgfAprob;
        int a = m_brain.act(state.Data, state.Clip, out rgfAprob);

        return new Tuple<int, int>(a, state.ActionCount);
    }

    /// <summary>
    /// Run the agent for 'nIterations' steps, collecting results in the native format.
    /// </summary>
    /// <param name="nIterations">Specifies the number of steps to run (-1 = until canceled).</param>
    /// <param name="type">Returns the data type contained in the byte stream.</param>
    /// <returns>The results converted to the native format are returned.</returns>
    public byte[] Run(int nIterations, out string type)
    {
        IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
        if (icallback == null)
            throw new Exception("The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");

        m_brain.Create();

        StateBase s = getData(Phase.RUN, m_nIndex, -1);
        int nIteration = 0;
        List<float> rgResults = new List<float>();
        int nLookahead = m_properties.GetPropertyAsInt("Lookahead", 0);

        while (!m_brain.Cancel.WaitOne(0) && (nIterations == -1 || nIteration < nIterations))
        {
            // Preprocess the observation.
            SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);

            // Forward the policy network and sample an action.
            float[] rgfAprob;
            int nAction = m_brain.act(x, s.Clip, out rgfAprob);

            if (m_bShowActionProb && m_bVerbose)
            {
                string strOut = "Action Prob: " + Utility.ToString<float>(rgfAprob.ToList(), 4) + " -> " + nAction.ToString();
                m_brain.OutputLog.WriteLine(strOut);
            }

            // Record timestamp, the data value at the lookahead offset, and the action taken.
            int nSeqLen = m_brain.RecurrentSequenceLength;
            int nItemLen = s.Data.ItemCount / nSeqLen;
            int nData1Idx = s.Data.ItemCount - (nItemLen * (nLookahead + 1));

            rgResults.Add(s.Data.TimeStamp.ToFileTime());
            rgResults.Add((float)s.Data.GetDataAtF(nData1Idx));
            rgResults.Add(nAction);

            // Take the next step using the action
            s = getData(Phase.RUN, m_nIndex, nAction);
            nIteration++;

            m_brain.OutputLog.Progress = ((double)nIteration / (double)nIterations);
        }

        ConvertOutputArgs args = new ConvertOutputArgs(nIterations, rgResults.ToArray());
        icallback.OnConvertOutput(args);

        type = args.RawType;
        return args.RawOutput;
    }

    // Returns true when the requested number of iterations/episodes has been reached.
    private bool isAtIteration(int nN, ITERATOR_TYPE type, int nIteration, int nEpisode)
    {
        if (nN == -1)
            return false;

        if (type == ITERATOR_TYPE.EPISODE)
            return (nEpisode >= nN);

        return (nIteration >= nN);
    }

    /// <summary>
    /// The main episode loop: collect experience, and when training, learn from the
    /// completed episodes held in the episode cache.
    /// </summary>
    /// <param name="phase">Specifies the phase under which to run.</param>
    /// <param name="nN">Specifies the number of iterations (based on the iterator type) to run.</param>
    /// <param name="type">Specifies the iterator type.</param>
    /// <param name="step">Specifies the stepping mode (used when debugging).</param>
    public void Run(Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
    {
        MemoryCache rgMemoryCache = new MemoryCache(m_nEpisodeBatchSize);
        Memory rgMemory = new Memory();
        double? dfRunningReward = null;
        double dfEpisodeReward = 0;
        int nEpisode = 0;
        int nIteration = 0;

        m_brain.Create();

        StateBase s = getData(phase, m_nIndex, -1);

        while (!m_brain.Cancel.WaitOne(0) && !isAtIteration(nN, type, nIteration, nEpisode))
        {
            // Preprocess the observation.
            SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);

            // Forward the policy network and sample an action.
            float[] rgfAprob;
            int action = getAction(nIteration, x, s.Clip, s.ActionCount, step, out rgfAprob);

            if (m_bShowActionProb && m_bVerbose)
            {
                string strOut = "Action Prob: " + Utility.ToString<float>(rgfAprob.ToList(), 4) + " -> " + action.ToString();
                m_brain.OutputLog.WriteLine(strOut);
            }

            if (step == TRAIN_STEP.FORWARD)
                return;

            // Take the next step using the action
            StateBase s_ = getData(phase, m_nIndex, action);
            dfEpisodeReward += s_.Reward;

            if (phase != Phase.TRAIN)
            {
                if (s_.Done)
                {
                    nEpisode++;

                    // Update reward running
                    if (!dfRunningReward.HasValue)
                        dfRunningReward = dfEpisodeReward;
                    else
                        dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;

                    nEpisode = updateStatus(nIteration, nEpisode, dfRunningReward.Value, dfEpisodeReward, m_brain.LastLoss, m_brain.LearningRate);
                    dfEpisodeReward = 0;

                    s = getData(phase, m_nIndex, -1);
                }
                else
                {
                    s = s_;
                }

                nIteration++;
                continue;
            }

            // Build up episode memory, using reward for taking the action.
            rgMemory.Add(new MemoryItem(s, x, action, rgfAprob, (float)s_.Reward));

            if (!s_.Done)
            {
                s = s_;
                continue;
            }

            // An episode has finished.
            nEpisode++;
            nIteration++;

            if (rgMemoryCache.Add(rgMemory))
            {
                if (m_bShowActionProb)
                    m_brain.OutputLog.WriteLine("---learning---");

                rgMemoryCache.PurgeNonElite(m_dfEpisodeElitePercentile);

                for (int i = 0; i < rgMemoryCache.Count; i++)
                {
                    Memory rgMemory1 = rgMemoryCache[i];

                    m_brain.Reshape(rgMemory1);

                    // Compute the discounted reward (backwards through time)
                    float[] rgDiscountedR = rgMemory1.GetDiscountedRewards(m_fGamma, m_bAllowDiscountReset);
                    // Rewards are normalized when set to be unit normal (helps control the gradient estimator variance)
                    m_brain.SetDiscountedR(rgDiscountedR);

                    // Sigmoid models, set the probabilities up font.
                    if (!m_brain.UsesSoftMax)
                    {
                        // The action probabilities are used to calculate the initial gradient within the loss function.
                        float[] rgfAprobSet = rgMemory1.GetActionProbabilities();
                        m_brain.SetActionProbabilities(rgfAprobSet);
                    }

                    // Get the action one-hot vectors.  When using Softmax, this contains the one-hot vector containing
                    // each action set (e.g. 3 actions with action 0 set would return a vector <1,0,0>).
                    // When using a binary probability (e.g. with Sigmoid), each action set only contains a
                    // single element which is set to the action value itself (e.g. 0 for action '0' and 1 for action '1')
                    float[] rgfAonehotSet = rgMemory1.GetActionOneHotVectors();
                    m_brain.SetActionOneHotVectors(rgfAonehotSet);

                    // Train for one iteration, which triggers the loss function.
                    List<Datum> rgData = rgMemory1.GetData();
                    List<Datum> rgClip = rgMemory1.GetClip();

                    m_brain.SetData(rgData, rgClip);

                    // Only apply the accumulated gradients on the last cached episode.
                    bool bApplyGradients = (i == rgMemoryCache.Count - 1);
                    m_brain.Train(nIteration, step, bApplyGradients);

                    // Update reward running
                    if (!dfRunningReward.HasValue)
                        dfRunningReward = dfEpisodeReward;
                    else
                        dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;

                    nEpisode = updateStatus(nIteration, nEpisode, dfRunningReward.Value, dfEpisodeReward, m_brain.LastLoss, m_brain.LearningRate);
                    dfEpisodeReward = 0;
                }

                rgMemoryCache.Clear();
            }

            s = getData(phase, m_nIndex, -1);
            rgMemory = new Memory();

            if (step != TRAIN_STEP.NONE)
                return;
        }
    }
}
892
897 class Brain<T> : IDisposable
898 {
899 MyCaffeControl<T> m_mycaffePrimary;
900 MyCaffeControl<T> m_mycaffeWorker;
901 Net<T> m_net;
902 Solver<T> m_solver;
903 MemoryDataLayer<T> m_memData;
904 MemoryLossLayer<T> m_memLoss;
905 SoftmaxLayer<T> m_softmax = null;
906 SoftmaxCrossEntropyLossLayer<T> m_softmaxCe = null;
907 bool m_bSoftmaxCeSetup = false;
908 PropertySet m_properties;
909 CryptoRandom m_random;
910 BlobCollection<T> m_colAccumulatedGradients = new BlobCollection<T>();
911 Blob<T> m_blobDiscountedR;
912 Blob<T> m_blobPolicyGradient;
913 Blob<T> m_blobActionOneHot;
914 Blob<T> m_blobDiscountedR1;
915 Blob<T> m_blobPolicyGradient1;
916 Blob<T> m_blobActionOneHot1;
917 Blob<T> m_blobLoss;
918 Blob<T> m_blobAprobLogit;
919 bool m_bSkipLoss;
920 int m_nMiniBatch = 10;
921 SimpleDatum m_sdLast = null;
922 double m_dfLastLoss = 0;
923 double m_dfLearningRate = 0;
924 Phase m_phase;
925 int m_nGpuID = 0;
926 int m_nThreadCount = 1;
927 bool m_bCreated = false;
928 bool m_bUseAcceleratedTraining = false;
929 int m_nRecurrentSequenceLength = 0;
930 List<Datum> m_rgData = null;
931 List<Datum> m_rgClip = null;
932
936 public event EventHandler<ApplyUpdateArgs<T>> OnApplyUpdate;
937
/// <summary>
/// The constructor.
/// </summary>
/// <param name="mycaffe">Specifies the primary MyCaffeControl used.</param>
/// <param name="properties">Specifies the property set containing the brain settings.</param>
/// <param name="random">Specifies the random number generator used.</param>
/// <param name="phase">Specifies the phase under which to run.</param>
/// <param name="nGpuID">Specifies the GPU on which this brain's worker runs.</param>
/// <param name="nThreadCount">Specifies the total number of agent threads.</param>
public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase, int nGpuID, int nThreadCount)
{
    m_mycaffePrimary = mycaffe;
    m_properties = properties;
    m_random = random;
    m_phase = phase;
    m_nGpuID = nGpuID;
    m_nThreadCount = nThreadCount;

    // The project's batch size (when non-zero) seeds the mini-batch, which the
    // 'MiniBatch' property may then override.
    int nMiniBatch = mycaffe.CurrentProject.GetBatchSize(phase);
    if (nMiniBatch != 0)
        m_nMiniBatch = nMiniBatch;

    m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);

    double? dfRate = mycaffe.CurrentProject.GetSolverSettingAsNumeric("base_lr");
    if (dfRate.HasValue)
        m_dfLearningRate = dfRate.Value;

    m_bUseAcceleratedTraining = properties.GetPropertyAsBool("UseAcceleratedTraining", false);
}
968
/// <summary>
/// Create the worker instance, locate the required layers, and allocate the
/// policy-gradient working blobs.  Safe to call more than once.
/// </summary>
public void Create()
{
    if (m_bCreated)
        return;

    m_mycaffePrimary.Log.Enable = false;

    // Single-threaded training shares the primary instance; otherwise clone onto this brain's GPU.
    if (m_nThreadCount == 1)
    {
        m_mycaffeWorker = m_mycaffePrimary;
        m_mycaffePrimary.Cuda.SetDeviceID();
    }
    else
    {
        m_mycaffeWorker = m_mycaffePrimary.Clone(m_nGpuID);
    }

    m_mycaffePrimary.Log.Enable = true;

    m_mycaffeWorker.Cuda.SetDeviceID();

    m_net = m_mycaffeWorker.GetInternalNet(m_phase);
    m_solver = m_mycaffeWorker.GetInternalSolver();

    m_memData = m_net.FindLayer(LayerParameter.LayerType.MEMORYDATA, null) as MemoryDataLayer<T>;
    m_memLoss = m_net.FindLayer(LayerParameter.LayerType.MEMORY_LOSS, null) as MemoryLossLayer<T>;
    m_softmax = m_net.FindLayer(LayerParameter.LayerType.SOFTMAX, null) as SoftmaxLayer<T>;

    if (m_memData == null)
        throw new Exception("Could not find the MemoryData Layer!");

    // The loss layer is only required outside the RUN phase.
    if (m_memLoss == null && m_phase != Phase.RUN)
        throw new Exception("Could not find the MemoryLoss Layer!");

    m_memData.OnDataPack += memData_OnDataPack;

    if (m_memLoss != null)
        m_memLoss.OnGetLoss += memLoss_OnGetLoss;

    m_blobDiscountedR = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);
    m_blobPolicyGradient = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);
    m_blobActionOneHot = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);
    m_blobDiscountedR1 = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);
    m_blobPolicyGradient1 = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);
    m_blobActionOneHot1 = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);
    m_blobLoss = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);
    m_blobAprobLogit = new Blob<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log);

    // Softmax models use a softmax cross-entropy loss helper layer.
    if (m_softmax != null)
    {
        LayerParameter p = new LayerParameter(LayerParameter.LayerType.SOFTMAXCROSSENTROPY_LOSS);
        p.loss_weight.Add(1);
        p.loss_weight.Add(0);
        m_softmaxCe = new SoftmaxCrossEntropyLossLayer<T>(m_mycaffeWorker.Cuda, m_mycaffeWorker.Log, p);
    }

    m_colAccumulatedGradients = m_net.learnable_parameters.Clone();
    m_colAccumulatedGradients.SetDiff(0);

    m_bCreated = true;
}
1034
// Disposes the blob (when non-null) and clears the reference.
private void dispose(ref Blob<T> b)
{
    if (b == null)
        return;

    b.Dispose();
    b = null;
}
1043
/// <summary>
/// Release all resources used: unhook layer events, free the working blobs and
/// accumulated gradients, and dispose the worker when it was a clone.
/// </summary>
public void Dispose()
{
    if (m_memLoss != null)
        m_memLoss.OnGetLoss -= memLoss_OnGetLoss;

    if (m_memData != null)
        m_memData.OnDataPack -= memData_OnDataPack;

    dispose(ref m_blobDiscountedR);
    dispose(ref m_blobPolicyGradient);
    dispose(ref m_blobActionOneHot);
    dispose(ref m_blobDiscountedR1);
    dispose(ref m_blobPolicyGradient1);
    dispose(ref m_blobActionOneHot1);
    dispose(ref m_blobLoss);
    dispose(ref m_blobAprobLogit);

    if (m_colAccumulatedGradients != null)
    {
        m_colAccumulatedGradients.Dispose();
        m_colAccumulatedGradients = null;
    }

    // Only dispose the worker when it is a clone (the primary is owned by the caller).
    if (m_mycaffeWorker != m_mycaffePrimary && m_mycaffeWorker != null)
        m_mycaffeWorker.Dispose();

    m_mycaffeWorker = null;
}
1075
1080 {
1081 get { return m_nRecurrentSequenceLength; }
1082 }
1083
1088 {
1089 get { return m_mycaffePrimary.Log; }
1090 }
1091
/// <summary>
/// Returns <i>true</i> when the network contains a Softmax layer (found during Create),
/// and <i>false</i> otherwise (e.g. Sigmoid-based models).
/// </summary>
public bool UsesSoftMax
{
    // Idiom fix: '(x == null) ? false : true' simplified to a direct null test.
    get { return (m_softmax != null); }
}
1099
/// <summary>
/// Reshape all policy-gradient working blobs to fit the given episode memory.
/// </summary>
/// <param name="mem">Specifies the memory of completed episode items (must be non-empty).</param>
/// <returns>The number of action probabilities output by the network is returned.</returns>
public int Reshape(Memory mem)
{
    int nNum = mem.Count;
    // NOTE(review): these three locals are not used below; retained for reference.
    int nChannels = mem[0].Data.Channels;
    int nHeight = mem[0].Data.Height;
    int nWidth = mem[0].Data.Width; // BUGFIX: previously assigned mem[0].Data.Height (copy-paste error).
    int nActionProbs = 1;
    int nFound = 0;

    // The action-probability count is the channel count of the largest non-loss output.
    for (int i = 0; i < m_net.output_blobs.Count; i++)
    {
        if (m_net.output_blobs[i].type != BLOB_TYPE.LOSS)
        {
            int nCh = m_net.output_blobs[i].channels;
            nActionProbs = Math.Max(nCh, nActionProbs);
            nFound++;
        }
    }

    if (nFound == 0)
        throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");

    m_blobDiscountedR.Reshape(nNum, nActionProbs, 1, 1);
    m_blobPolicyGradient.Reshape(nNum, nActionProbs, 1, 1);
    m_blobActionOneHot.Reshape(nNum, nActionProbs, 1, 1);
    m_blobDiscountedR1.Reshape(nNum, nActionProbs, 1, 1);
    m_blobPolicyGradient1.Reshape(nNum, nActionProbs, 1, 1);
    m_blobActionOneHot1.Reshape(nNum, nActionProbs, 1, 1);
    m_blobLoss.Reshape(1, 1, 1, 1);

    return nActionProbs;
}
1137
/// <summary>
/// Set the discounted rewards into the discounted-reward blob, normalizing them to be
/// unit normal (which helps control the gradient estimator variance).
/// </summary>
/// <param name="rg">Specifies one discounted reward per episode step.</param>
public void SetDiscountedR(float[] rg)
{
    // The mean/std are computed on the raw per-step values, before any channel expansion.
    double dfMean = m_blobDiscountedR.mean(rg);
    double dfStd = m_blobDiscountedR.std(dfMean, rg);
    int nC = m_blobDiscountedR.channels;

    // Fill all items in each channel with the same discount value.
    if (nC > 1)
        rg = rg.SelectMany(f => Enumerable.Repeat(f, nC)).ToArray();

    m_blobDiscountedR.SetData(Utility.ConvertVec<T>(rg));
    m_blobDiscountedR.NormalizeData(dfMean, dfStd);
}
1167
/// <summary>
/// Set the action probabilities into the policy-gradient blob.
/// </summary>
/// <param name="rg">Specifies the action probabilities to set.</param>
public void SetActionProbabilities(float[] rg)
{
    T[] rgData = Utility.ConvertVec<T>(rg);
    m_blobPolicyGradient.SetData(rgData);
}
1176
/// <summary>
/// Set the action one-hot vectors into the action one-hot blob.
/// </summary>
/// <param name="rg">Specifies the action one-hot vectors to set.</param>
public void SetActionOneHotVectors(float[] rg)
{
    T[] rgData = Utility.ConvertVec<T>(rg);
    m_blobActionOneHot.SetData(rgData);
}
1185
/// <summary>
/// Set the episode data (and clip) for training.  Recurrent multi-item episodes are held
/// back and fed one at a time by Train; otherwise the data is loaded immediately.
/// </summary>
/// <param name="rgData">Specifies the data to set.</param>
/// <param name="rgClip">Specifies the clip data to set, if any.</param>
public void SetData(List<Datum> rgData, List<Datum> rgClip)
{
    bool bDeferToTrain = (m_nRecurrentSequenceLength != 1 && rgData.Count > 1 && rgClip != null);

    if (bDeferToTrain)
    {
        m_rgData = rgData;
        m_rgClip = rgClip;
        return;
    }

    m_memData.AddDatumVector(rgData, rgClip, 1, true, true);
    m_rgData = null;
    m_rgClip = null;
}
1205
/// <summary>
/// Returns the GetDataArgs used to retrieve the next state from the environment.
/// </summary>
/// <param name="phase">Specifies the phase under which to get the data.</param>
/// <param name="nIdx">Specifies the agent index.</param>
/// <param name="nAction">Specifies the action to run, or -1 to reset the environment.</param>
/// <param name="bResetOverride">Optionally overrides the reset flag (currently unused here).</param>
public GetDataArgs getDataArgs(Phase phase, int nIdx, int nAction, bool? bResetOverride = null)
{
    bool bReset = (nAction == -1);
    return new GetDataArgs(phase, nIdx, m_mycaffePrimary, m_mycaffePrimary.Log, m_mycaffePrimary.CancelEvent, bReset, nAction, false, true);
}
1219
/// <summary>
/// Returns the last loss received during training.
/// </summary>
public double LastLoss => m_dfLastLoss;
1227
/// <summary>
/// Returns the learning rate used during the last update.
/// </summary>
public double LearningRate => m_dfLearningRate;
1235
/// <summary>
/// Returns the primary instance's output log.
/// </summary>
public Log Log => m_mycaffePrimary.Log;
1243
1248 {
1249 get { return m_mycaffePrimary.CancelEvent; }
1250 }
1251
/// <summary>
/// Preprocess the state data, either returning a copy of the raw input or the
/// difference from the previously seen state (difference framing).
/// </summary>
/// <param name="s">Specifies the state to preprocess.</param>
/// <param name="bUseRawInput">When <i>true</i>, the raw data copy is returned unchanged.</param>
/// <returns>The preprocessed SimpleDatum is returned.</returns>
public SimpleDatum Preprocess(StateBase s, bool bUseRawInput)
{
    SimpleDatum sd = new SimpleDatum(s.Data, true);

    if (bUseRawInput)
        return sd;

    // First frame has no previous state to diff against, so zero it.
    if (m_sdLast == null)
        sd.Zero();
    else
        sd.Sub(m_sdLast);

    m_sdLast = s.Data;

    return sd;
}
1274
/// <summary>
/// Run the policy network forward on the given state and sample an action from the
/// resulting probability distribution.
/// </summary>
/// <param name="sd">Specifies the preprocessed state data.</param>
/// <param name="sdClip">Optionally specifies the clip data (used by recurrent models).</param>
/// <param name="rgfAprob">Returns the action probabilities output by the network.</param>
/// <returns>The sampled action index is returned.</returns>
public int act(SimpleDatum sd, SimpleDatum sdClip, out float[] rgfAprob)
{
    List<Datum> rgData = new List<Datum>() { new Datum(sd) };
    List<Datum> rgClip = null;

    if (sdClip != null)
        rgClip = new List<Datum>() { new Datum(sdClip) };

    double dfLoss;
    float fRandom = (float)m_random.NextDouble(); // Roll the dice.

    m_memData.AddDatumVector(rgData, rgClip, 1, true, true);

    // Skip the loss computation on this inference-only forward pass.
    m_bSkipLoss = true;
    BlobCollection<T> res = m_net.Forward(out dfLoss);
    m_bSkipLoss = false;

    rgfAprob = null;

    foreach (Blob<T> blob in res)
    {
        if (blob.type == BLOB_TYPE.LOSS)
            continue;

        int nStart = 0;

        // When using recurrent learning, only act on the last outputs.
        if (m_nRecurrentSequenceLength > 1 && blob.num > 1)
        {
            int nCount = blob.count();
            int nOutput = nCount / blob.num;
            nStart = nCount - nOutput;

            if (nStart < 0)
                throw new Exception("The start must be zero or greater!");
        }

        rgfAprob = Utility.ConvertVecF<T>(blob.update_cpu_data(), nStart);
        break;
    }

    if (rgfAprob == null)
        throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");

    // Select the action from the probability distribution.
    float fSum = 0;
    for (int i = 0; i < rgfAprob.Length; i++)
    {
        fSum += rgfAprob[i];

        if (fRandom < fSum)
            return i;
    }

    // Single-output (sigmoid) models fall through to action 1; otherwise take the last action.
    if (rgfAprob.Length == 1)
        return 1;

    return rgfAprob.Length - 1;
}
1344
// Saves the full contents of 'b' into 'b1' (data and diff), then shrinks 'b' down
// to a single item so per-item slices can be copied back in during training.
private void prepareBlob(Blob<T> b1, Blob<T> b)
{
    b1.CopyFrom(b, 0, 0, b1.count(), true, true);
    b.Reshape(1, b.channels, b.height, b.width);
}
1350
// Copies the 'nIdx'-th item-sized slice of 'src' (data only) into 'dst'.
private void copyBlob(int nIdx, Blob<T> src, Blob<T> dst)
{
    int nCount = dst.count();
    dst.CopyFrom(src, nIdx * nCount, 0, nCount, true, false);
}
1356
/// <summary>
/// Train the brain on the currently loaded episode data, accumulating gradients and
/// applying them either locally (single-threaded) or via the OnApplyUpdate event.
/// </summary>
/// <param name="nIteration">Specifies the current iteration.</param>
/// <param name="step">Specifies the stepping mode (used when debugging).</param>
/// <param name="bApplyGradients">When <i>true</i>, the accumulated gradients are applied.</param>
public void Train(int nIteration, TRAIN_STEP step, bool bApplyGradients = true)
{
    // Run data/clip groups > 1 in non batch mode.
    bool bRunPerItem = (m_nRecurrentSequenceLength != 1 && m_rgData != null && m_rgData.Count > 1 && m_rgClip != null);

    if (bRunPerItem)
    {
        // Stash the full blobs, then feed one episode item at a time.
        prepareBlob(m_blobActionOneHot1, m_blobActionOneHot);
        prepareBlob(m_blobDiscountedR1, m_blobDiscountedR);
        prepareBlob(m_blobPolicyGradient1, m_blobPolicyGradient);

        for (int i = 0; i < m_rgData.Count; i++)
        {
            copyBlob(i, m_blobActionOneHot1, m_blobActionOneHot);
            copyBlob(i, m_blobDiscountedR1, m_blobDiscountedR);
            copyBlob(i, m_blobPolicyGradient1, m_blobPolicyGradient);

            List<Datum> rgData1 = new List<Datum>() { m_rgData[i] };
            List<Datum> rgClip1 = new List<Datum>() { m_rgClip[i] };

            m_memData.AddDatumVector(rgData1, rgClip1, 1, true, true);

            m_solver.Step(1, step, true, m_bUseAcceleratedTraining, true, true);
            m_colAccumulatedGradients.Accumulate(m_mycaffeWorker.Cuda, m_net.learnable_parameters, true);
        }

        // Restore the full blob shapes for the next episode.
        m_blobActionOneHot.ReshapeLike(m_blobActionOneHot1);
        m_blobDiscountedR.ReshapeLike(m_blobDiscountedR1);
        m_blobPolicyGradient.ReshapeLike(m_blobPolicyGradient1);

        m_rgData = null;
        m_rgClip = null;
    }
    else
    {
        m_solver.Step(1, step, true, m_bUseAcceleratedTraining, true, true);
        m_colAccumulatedGradients.Accumulate(m_mycaffeWorker.Cuda, m_net.learnable_parameters, true);
    }

    bool bApplyNow = (nIteration % m_nMiniBatch == 0 || bApplyGradients || step == TRAIN_STEP.BACKWARD || step == TRAIN_STEP.BOTH);

    if (bApplyNow)
    {
        // Move the accumulated gradients into the parameter diffs and reset the accumulator.
        m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
        m_colAccumulatedGradients.SetDiff(0);

        if (m_mycaffePrimary == m_mycaffeWorker)
        {
            m_dfLearningRate = m_solver.ApplyUpdate(nIteration);
        }
        else
        {
            // Multi-threaded: hand the update off to the shared optimizer.
            ApplyUpdateArgs<T> args = new ApplyUpdateArgs<T>(nIteration, m_mycaffeWorker);
            OnApplyUpdate(this, args);
            m_dfLearningRate = args.LearningRate;
        }

        m_net.ClearParamDiffs();
    }
}
1419
/// <summary>
/// Unpack the label values stored in a Datum's DataCriteria blob.
/// </summary>
/// <param name="d">Specifies the Datum whose DataCriteria holds the packed labels.</param>
/// <returns>The unpacked label values, or null when no criteria data exists or the format is not a float/double list.</returns>
private T[] unpackLabel(Datum d)
{
    // No criteria data attached - nothing to unpack.
    if (d.DataCriteria == null)
        return null;

    switch (d.DataCriteriaFormat)
    {
        case SimpleDatum.DATA_FORMAT.LIST_FLOAT:
            List<float> rgFloat = BinaryData.UnPackFloatList(d.DataCriteria, SimpleDatum.DATA_FORMAT.LIST_FLOAT);
            return Utility.ConvertVec<T>(rgFloat.ToArray());

        case SimpleDatum.DATA_FORMAT.LIST_DOUBLE:
            List<double> rgDouble = BinaryData.UnPackDoubleList(d.DataCriteria, SimpleDatum.DATA_FORMAT.LIST_DOUBLE);
            return Utility.ConvertVec<T>(rgDouble.ToArray());

        default:
            // Unsupported packing format.
            return null;
    }
}
1438
/// <summary>
/// The OnDataPack event handler packs the observation, clip and label data into
/// the ordering expected by the recurrent layer type (LSTM or LSTM_SIMPLE) and
/// writes the packed arrays into the Data, Clip and Label blobs.
/// </summary>
/// <param name="sender">Specifies the event sender (the MemoryDataLayer).</param>
/// <param name="e">Specifies the pack-data arguments holding the blobs to fill and the raw data/clip items.</param>
private void memData_OnDataPack(object sender, MemoryDataLayerPackDataArgs<T> e)
{
    List<int> rgDataShape = e.Data.shape();
    List<int> rgClipShape = e.Clip.shape();
    List<int> rgLabelShape = e.Label.shape();
    int nBatch = e.DataItems.Count;
    int nSeqLen = rgDataShape[0];

    // Fix: error message typo ("lenth" -> "length").
    e.Data.Log.CHECK_GT(nSeqLen, 0, "The sequence length must be greater than zero!");
    e.Data.Log.CHECK_EQ(nBatch, e.ClipItems.Count, "The data and clip should have the same number of items.");
    e.Data.Log.CHECK_EQ(nSeqLen, rgClipShape[0], "The data and clip should have the same sequence count.");

    rgDataShape[1] = nBatch; // LSTM uses sizing: seq, batch, data1, data2
    rgClipShape[1] = nBatch;
    rgLabelShape[1] = nBatch;

    e.Data.Reshape(rgDataShape);
    e.Clip.Reshape(rgClipShape);
    e.Label.Reshape(rgLabelShape);

    T[] rgRawData = new T[e.Data.count()];
    T[] rgRawClip = new T[e.Clip.count()];
    T[] rgRawLabel = new T[e.Label.count()];

    int nDataSize = e.Data.count(2);
    T[] rgDataItem = new T[nDataSize];
    T dfClip;
    int nIdx;

    for (int i = 0; i < nBatch; i++)
    {
        Datum data = e.DataItems[i];
        Datum clip = e.ClipItems[i];

        T[] rgLabel = unpackLabel(data);

        for (int j = 0; j < nSeqLen; j++)
        {
            dfClip = clip.GetDataAt<T>(j);

            // Gather the j'th step of this item into the scratch buffer.
            for (int k = 0; k < nDataSize; k++)
            {
                rgDataItem[k] = data.GetDataAt<T>(j * nDataSize + k);
            }

            // LSTM: Create input data, the data must be in the order
            // seq1_val1, seq2_val1, ..., seqBatch_Size_val1, seq1_val2, seq2_val2, ..., seqBatch_Size_valSequence_Length
            if (e.LstmType == LayerParameter.LayerType.LSTM)
                nIdx = nBatch * j + i;

            // LSTM_SIMPLE: Create input data, the data must be in the order
            // seq1_val1, seq1_val2, ..., seq1_valBatchSize, seq2_val1, seq2_val2, ..., seqSequenceLength_valBatchSize
            // NOTE(review): this stride uses nBatch where the ordering described above suggests
            // nSeqLen (i.e. i * nSeqLen + j); the two agree only when nBatch == nSeqLen - confirm.
            else
                nIdx = i * nBatch + j;

            Array.Copy(rgDataItem, 0, rgRawData, nIdx * nDataSize, nDataSize);
            rgRawClip[nIdx] = dfClip;

            if (rgLabel != null)
            {
                if (rgLabel.Length == nSeqLen)
                    rgRawLabel[nIdx] = rgLabel[j];
                else if (rgLabel.Length == 1)
                {
                    // SINGLE label: only the last step of the sequence carries the label.
                    // NOTE(review): always writes index 0 regardless of batch item 'i' - verify for nBatch > 1.
                    if (j == nSeqLen - 1)
                        rgRawLabel[0] = rgLabel[0];
                }
                else
                {
                    throw new Exception("The Solver SequenceLength parameter does not match the actual sequence length! The label length '" + rgLabel.Length.ToString() + "' must be either '1' for SINGLE labels, or the sequence length of '" + nSeqLen.ToString() + "' for MULTI labels. Stopping training.");
                }
            }
        }
    }

    e.Data.mutable_cpu_data = rgRawData;
    e.Clip.mutable_cpu_data = rgRawClip;
    e.Label.mutable_cpu_data = rgRawLabel;

    // Remember the detected sequence length for the recurrent training path.
    m_nRecurrentSequenceLength = nSeqLen;
}
1528
1529
/// <summary>
/// The OnGetLoss event handler for the MemoryLossLayer - computes the policy
/// gradient loss and writes the gradient directly into the bottom blob's diff.
/// </summary>
/// <param name="sender">Specifies the event sender (the MemoryLossLayer).</param>
/// <param name="e">Specifies the loss arguments holding the bottom blobs and receiving the loss.</param>
private void memLoss_OnGetLoss(object sender, MemoryLossLayerGetLossArgs<T> e)
{
    // Skipped entirely when loss calculation is disabled (e.g. during action runs).
    if (m_bSkipLoss)
        return;

    int nCount = m_blobActionOneHot.count();
    long hActionOneHot = m_blobActionOneHot.gpu_data;
    long hPolicyGrad = 0;
    long hDiscountedR = m_blobDiscountedR.gpu_data;
    double dfLoss;
    int nDataSize = e.Bottom[0].count(1);
    bool bUsingEndData = false;

    // When using a recurrent model and receiving data with more than one sequence,
    // copy and only use the last sequence data.
    if (m_nRecurrentSequenceLength > 1)
    {
        if (e.Bottom[0].num > 1)
        {
            // Save the full output aside: copy the data (with reshape), then the diff.
            m_blobAprobLogit.CopyFrom(e.Bottom[0], false, true);
            m_blobAprobLogit.CopyFrom(e.Bottom[0], true);

            // Shrink the bottom to a single item and fill it with the last item's data+diff.
            List<int> rgShape = e.Bottom[0].shape();
            rgShape[0] = 1;
            e.Bottom[0].Reshape(rgShape);
            e.Bottom[0].CopyFrom(m_blobAprobLogit, (m_blobAprobLogit.num - 1) * nDataSize, 0, nDataSize, true, true);
            bUsingEndData = true;
        }
    }

    long hBottomDiff = e.Bottom[0].mutable_gpu_diff;

    // Calculate the initial gradients (policy grad initially just contains the action probabilities)
    if (m_softmax != null)
    {
        // Softmax model: use the softmax cross-entropy layer to produce loss and gradient.
        BlobCollection<T> colBottom = new BlobCollection<T>();
        BlobCollection<T> colTop = new BlobCollection<T>();

        colBottom.Add(e.Bottom[0]); // aprob logit
        colBottom.Add(m_blobActionOneHot); // action one-hot vectors
        colTop.Add(m_blobLoss);
        colTop.Add(m_blobPolicyGradient);

        // Lazy one-time setup of the softmax cross-entropy layer.
        if (!m_bSoftmaxCeSetup)
        {
            m_softmaxCe.Setup(colBottom, colTop);
            m_bSoftmaxCeSetup = true;
        }

        // Only the logits (bottom[0]) receive a gradient; the one-hot targets do not.
        dfLoss = m_softmaxCe.Forward(colBottom, colTop);
        m_softmaxCe.Backward(colTop, new List<bool>() { true, false }, colBottom);
        hPolicyGrad = colBottom[0].gpu_diff;
    }
    else
    {
        // Non-softmax (e.g. sigmoid) model: compute the gradient by hand on the GPU.
        hPolicyGrad = m_blobPolicyGradient.mutable_gpu_data;

        // Calculate (a=0) ? 1-aprob : 0-aprob
        m_mycaffeWorker.Cuda.add_scalar(nCount, -1.0, hActionOneHot); // invert one hot
        m_mycaffeWorker.Cuda.abs(nCount, hActionOneHot, hActionOneHot);
        m_mycaffeWorker.Cuda.mul_scalar(nCount, -1.0, hPolicyGrad); // negate Aprob
        m_mycaffeWorker.Cuda.add(nCount, hActionOneHot, hPolicyGrad, hPolicyGrad); // gradient = ((a=0)?1:0) - Aprob
        dfLoss = Utility.ConvertVal<T>(m_blobPolicyGradient.sumsq_data());

        m_mycaffeWorker.Cuda.mul_scalar(nCount, -1.0, hPolicyGrad); // invert for ApplyUpdate subtracts the gradients
    }

    // Modulate the gradient with the advantage (PG magic happens right here.)
    m_mycaffeWorker.Cuda.mul(nCount, hPolicyGrad, hDiscountedR, hPolicyGrad);

    e.Loss = dfLoss;
    e.EnableLossUpdate = false; // dont apply loss to loss weight.

    // Only copy when the gradient was not computed in-place in the bottom diff.
    if (hPolicyGrad != hBottomDiff)
        m_mycaffeWorker.Cuda.copy(nCount, hPolicyGrad, hBottomDiff);

    // When using recurrent model with more than one sequence of data, only
    // copy the diff to the last in the sequence and zero out the rest in the sequence.
    if (m_nRecurrentSequenceLength > 1 && bUsingEndData)
    {
        m_blobAprobLogit.SetDiff(0);
        m_blobAprobLogit.CopyFrom(e.Bottom[0], 0, (m_blobAprobLogit.num - 1) * nDataSize, nDataSize, false, true);
        // Restore the bottom to its full shape/data and hand back the sparse diff.
        e.Bottom[0].CopyFrom(m_blobAprobLogit, false, true);
        e.Bottom[0].CopyFrom(m_blobAprobLogit, true);
    }

    m_dfLastLoss = e.Loss;
}
1633 }
1634
/// <summary>
/// Contains the best memory episodes (best by highest total rewards).
/// </summary>
class MemoryCache : IEnumerable<Memory>
{
    int m_nMax;
    List<Memory> m_rgMemory = new List<Memory>();

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="nMax">Specifies the episode count at which Add signals that the cache is full.</param>
    public MemoryCache(int nMax)
    {
        m_nMax = nMax;
    }

    /// <summary>
    /// Returns the number of items in the cache.
    /// </summary>
    public int Count
    {
        get { return m_rgMemory.Count; }
    }

    /// <summary>
    /// Get the episode at a given index.
    /// </summary>
    /// <param name="nIdx">Specifies the index.</param>
    public Memory this[int nIdx]
    {
        get { return m_rgMemory[nIdx]; }
    }

    /// <summary>
    /// Add a new episode to the memory cache.
    /// </summary>
    /// <param name="mem">Specifies the episode to add.</param>
    /// <returns>Returns true when the cache has reached (or exceeded) its maximum count, false otherwise.</returns>
    public bool Add(Memory mem)
    {
        m_rgMemory.Add(mem);

        // Fix: use >= (was ==) so that overshooting the maximum (e.g. a missed Clear)
        // cannot leave the 'full' signal permanently unreachable.
        return m_rgMemory.Count >= m_nMax;
    }

    /// <summary>
    /// Clear all items from the memory cache.
    /// </summary>
    public void Clear()
    {
        m_rgMemory.Clear();
    }

    /// <summary>
    /// Purge all non elite episodes, keeping only the episodes whose reward sum falls
    /// within the top 'dfElitePercent' portion of the observed reward range.
    /// </summary>
    /// <param name="dfElitePercent">Specifies the elite percent in (0,1); values outside this range disable purging.</param>
    public void PurgeNonElite(double dfElitePercent)
    {
        if (dfElitePercent <= 0.0 || dfElitePercent >= 1.0)
            return;

        // Fix: guard the empty cache - Min()/Max() throw InvalidOperationException on an empty sequence.
        if (m_rgMemory.Count == 0)
            return;

        double dfMin = m_rgMemory.Min(p => p.RewardSum);
        double dfMax = m_rgMemory.Max(p => p.RewardSum);
        double dfRange = dfMax - dfMin;
        double dfCutoff = dfMin + ((1.0 - dfElitePercent) * dfRange);
        List<Memory> rgMem = m_rgMemory.OrderByDescending(p => p.RewardSum).ToList();
        List<Memory> rgElite = new List<Memory>();

        // The list is sorted descending, so stop at the first below-cutoff episode.
        for (int i = 0; i < rgMem.Count; i++)
        {
            double dfSum = rgMem[i].RewardSum;

            if (dfSum >= dfCutoff)
                rgElite.Add(rgMem[i]);
            else
                break;
        }

        m_rgMemory = rgElite;
    }

    /// <summary>
    /// Returns the enumerator.
    /// </summary>
    public IEnumerator<Memory> GetEnumerator()
    {
        return m_rgMemory.GetEnumerator();
    }

    /// <summary>
    /// Returns the non-generic enumerator.
    /// </summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return m_rgMemory.GetEnumerator();
    }
}
1740
1745 {
1746 List<MemoryItem> m_rgItems = new List<MemoryItem>();
1747 int m_nEpisodeNumber = 0;
1748 double m_dfRewardSum = 0;
1749
/// <summary>
/// The constructor.
/// </summary>
public Memory()
{
}
1756
/// <summary>
/// Returns the number of memory items in the memory.
/// </summary>
public int Count
{
    get { return m_rgItems.Count; }
}
1764
/// <summary>
/// Add a new item to the memory, accumulating its reward into the running reward sum.
/// </summary>
/// <param name="item">Specifies the memory item (episode step) to add.</param>
public void Add(MemoryItem item)
{
    m_dfRewardSum += item.Reward;
    m_rgItems.Add(item);
}
1777
/// <summary>
/// Remove all items in the list and reset the reward sum.
/// </summary>
public void Clear()
{
    m_dfRewardSum = 0;
    m_rgItems.Clear();
}
1786
/// <summary>
/// Get/set the memory item at a given index.
/// </summary>
/// <param name="nIdx">Specifies the index.</param>
public MemoryItem this[int nIdx]
{
    get { return m_rgItems[nIdx]; }
    set { m_rgItems[nIdx] = value; }
}
1797
/// <summary>
/// Get/set the episode number of this memory.
/// </summary>
public int EpisodeNumber
{
    get { return m_nEpisodeNumber; }
    set { m_nEpisodeNumber = value; }
}
1806
/// <summary>
/// Get/set the reward sum of this memory.
/// Note: setting this value directly does not recompute the sum from the stored items.
/// </summary>
public double RewardSum
{
    get { return m_dfRewardSum; }
    set { m_dfRewardSum = value; }
}
1815
/// <summary>
/// Retrieve the discounted rewards for this episode, computed by walking the
/// episode backwards and accumulating gamma-discounted reward.
/// </summary>
/// <param name="fGamma">Specifies the discount factor applied per step.</param>
/// <param name="bAllowReset">When true, a non-zero reward resets the accumulator (game boundary) before it is added in.</param>
/// <returns>The per-step discounted rewards, one entry per episode step.</returns>
public float[] GetDiscountedRewards(float fGamma, bool bAllowReset)
{
    int nSteps = m_rgItems.Count;
    float[] rgResult = new float[nSteps];
    float fAccum = 0;

    for (int t = nSteps - 1; t >= 0; t--)
    {
        float fReward = m_rgItems[t].Reward;

        // A non-zero reward marks a boundary when resets are allowed.
        if (bAllowReset && fReward != 0)
            fAccum = 0;

        fAccum = fAccum * fGamma + fReward;
        rgResult[t] = fAccum;
    }

    return rgResult;
}
1839
/// <summary>
/// Retrieve the action probabilities of the episode, concatenated across all steps in order.
/// </summary>
/// <returns>All steps' Aprob values flattened into a single array.</returns>
public float[] GetActionProbabilities()
{
    return m_rgItems.SelectMany(p => p.Aprob).ToArray();
}
1858
/// <summary>
/// Retrieve the action one-hot vectors for the episode, one vector per step,
/// concatenated into a single array.
/// </summary>
/// <returns>The flattened one-hot vectors (or raw action values for single-output models).</returns>
public float[] GetActionOneHotVectors()
{
    List<float> rgfVectors = new List<float>();

    foreach (MemoryItem item in m_rgItems)
    {
        // Each vector is sized by the first step's action dimension.
        float[] rgfOneHot = new float[m_rgItems[0].Aprob.Length];

        if (rgfOneHot.Length == 1)
            rgfOneHot[0] = item.Action;    // single-output model: store the raw action value.
        else
            rgfOneHot[item.Action] = 1;    // multi-output model: one-hot the selected action.

        rgfVectors.AddRange(rgfOneHot);
    }

    return rgfVectors.ToArray();
}
1881
1886 public List<Datum> GetData()
1887 {
1888 List<Datum> rgData = new List<Datum>();
1889
1890 for (int i = 0; i < m_rgItems.Count; i++)
1891 {
1892 rgData.Add(new Datum(m_rgItems[i].Data));
1893 }
1894
1895 return rgData;
1896 }
1897
1902 public List<Datum> GetClip()
1903 {
1904 if (m_rgItems.Count == 0)
1905 return null;
1906
1907 if (m_rgItems[0].State.Clip == null)
1908 return null;
1909
1910 List<Datum> rgData = new List<Datum>();
1911
1912 for (int i = 0; i < m_rgItems.Count; i++)
1913 {
1914 if (m_rgItems[i].State.Clip == null)
1915 return null;
1916
1917 rgData.Add(new Datum(m_rgItems[i].State.Clip));
1918 }
1919
1920 return rgData;
1921 }
1922 }
1923
1928 {
1929 StateBase m_state;
1930 SimpleDatum m_x;
1931 int m_nAction;
1932 float[] m_rgfAprob;
1933 float m_fReward;
1934
/// <summary>
/// The constructor.
/// </summary>
/// <param name="s">Specifies the state of this episode step.</param>
/// <param name="x">Specifies the pre-processed data (run through the model) of this step.</param>
/// <param name="nAction">Specifies the action taken.</param>
/// <param name="rgfAprob">Specifies the action probabilities (used with non-Softmax models).</param>
/// <param name="fReward">Specifies the reward for taking the action.</param>
public MemoryItem(StateBase s, SimpleDatum x, int nAction, float[] rgfAprob, float fReward)
{
    m_state = s;
    m_x = x;
    m_nAction = nAction;
    m_rgfAprob = rgfAprob;
    m_fReward = fReward;
}
1951
1956 {
1957 get { return m_state; }
1958 }
1959
1964 {
1965 get { return m_x; }
1966 }
1967
/// <summary>
/// Returns the action of this episode step.
/// </summary>
public int Action
{
    get { return m_nAction; }
}
1975
/// <summary>
/// Returns the reward for taking the action in this episode step.
/// </summary>
public float Reward
{
    get { return m_fReward; }
}
1983
/// <summary>
/// Returns the action probabilities, which are only used with non-Softmax models.
/// </summary>
public float[] Aprob
{
    get { return m_rgfAprob; }
}
1991
/// <summary>
/// Returns the string representation of this episode step.
/// </summary>
/// <returns>A string listing the action, reward (2 decimals) and action probabilities.</returns>
public override string ToString()
{
    return string.Format("action = {0} reward = {1:N2} aprob = {2}", m_nAction, m_fReward, tostring(m_rgfAprob));
}
2000
/// <summary>
/// Render a float array as "{v1,v2,...}" with 5 decimal places per value.
/// </summary>
/// <param name="rg">Specifies the values to render.</param>
/// <returns>The formatted string (an empty array yields "{}").</returns>
private string tostring(float[] rg)
{
    return "{" + string.Join(",", rg.Select(f => f.ToString("N5"))) + "}";
}
2016 }
2017}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
void CopyWeightsFrom(MyCaffeControl< T > src)
Copy the learnable parameter data from the source MyCaffeControl into this one.
Solver< T > GetInternalSolver()
Get the internal solver.
MyCaffeControl< T > Clone(int nGpuID)
Clone the current instance of the MyCaffeControl creating a second instance.
Log Log
Returns the Log (for output) used.
CudaDnn< T > Cuda
Returns the CudaDnn connection used.
ProjectEx CurrentProject
Returns the name of the currently loaded project.
The BinaryData class is used to pack and unpack DataCriteria binary data, optionally stored within ea...
Definition: BinaryData.cs:15
static List< double > UnPackDoubleList(byte[] rg, DATA_FORMAT fmtExpected)
Unpack the byte array into a list of double values.
Definition: BinaryData.cs:75
static List< float > UnPackFloatList(byte[] rg, DATA_FORMAT fmtExpected)
Unpack the byte array into a list of float values.
Definition: BinaryData.cs:132
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
void Reset()
Resets the event clearing any signaled state.
Definition: CancelEvent.cs:279
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
CancelEvent()
The CancelEvent constructor.
Definition: CancelEvent.cs:28
void Set()
Sets the event to the signaled state.
Definition: CancelEvent.cs:270
The CryptoRandom is a random number generator that can use either the standard .Net Random objec or t...
Definition: CryptoRandom.cs:14
int Next(int nMinVal, int nMaxVal, bool bMaxInclusive=true)
Returns a random int within the range
double NextDouble()
Returns a random double within the range .
Definition: CryptoRandom.cs:83
The Datum class is a simple wrapper to the SimpleDatum class to ensure compatibility with the origina...
Definition: Datum.cs:12
The Log class provides general output in text form.
Definition: Log.cs:13
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
double Progress
Get/set the progress associated with the Log.
Definition: Log.cs:147
void WriteError(Exception e)
Write an error as output.
Definition: Log.cs:130
Log(string strSrc)
The Log constructor.
Definition: Log.cs:33
double? GetSolverSettingAsNumeric(string strParam)
Get a setting from the solver descriptor as a double value.
Definition: ProjectEx.cs:470
int GetBatchSize(Phase phase)
Returns the batch size of the project used in a given Phase.
Definition: ProjectEx.cs:359
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Definition: PropertySet.cs:267
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as an double value.
Definition: PropertySet.cs:307
override string ToString()
Returns the string representation of the properties.
Definition: PropertySet.cs:325
The Result class contains a single result.
Definition: Result.cs:14
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
void Copy(SimpleDatum d, bool bCopyData, int? nHeight=null, int? nWidth=null)
Copy another SimpleDatum into this one.
float GetDataAtF(int nIdx)
Returns the item at a specified index in the float type.
bool Sub(SimpleDatum sd, bool bSetNegativeToZero=false)
Subtract the data of another SimpleDatum from this one, so this = this - sd.
void Zero()
Zero out all data in the datum but keep the size and other settings.
int ItemCount
Returns the number of data items.
DateTime TimeStamp
Get/set the Timestamp.
byte[] DataCriteria
Get/set data criteria associated with the data.
DATA_FORMAT
Defines the data format of the DebugData and DataCriteria when specified.
Definition: SimpleDatum.cs:223
DATA_FORMAT DataCriteriaFormat
Get/set the data format of the data criteria.
The Utility class provides general utility functions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
The BlobCollection contains a list of Blobs.
void Dispose()
Release all resource used by the collection and its Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
void Accumulate(CudaDnn< T > cuda, BlobCollection< T > src, bool bAccumulateDiff)
Accumulate the diffs from one BlobCollection into another.
void SetDiff(double df)
Set all blob diff to the value specified.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
int channels
DEPRECIATED; legacy shape accessor channels: use shape(1) instead.
Definition: Blob.cs:800
void SetData(T[] rgData, int nCount=-1, bool bSetCount=true)
Sets a number of items within the Blob's data.
Definition: Blob.cs:1922
int height
DEPRECIATED; legacy shape accessor height: use shape(2) instead.
Definition: Blob.cs:808
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1487
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECIATED; use
Definition: Blob.cs:442
double std(double? dfMean=null, float[] rgDf=null)
Calculate the standard deviation of the blob data.
Definition: Blob.cs:3007
double mean(float[] rgDf=null, bool bDiff=false)
Calculate the mean of the blob data.
Definition: Blob.cs:2965
void CopyFrom(Blob< T > src, int nSrcOffset, int nDstOffset, int nCount, bool bCopyData, bool bCopyDiff)
Copy from a source Blob.
Definition: Blob.cs:903
int width
DEPRECIATED; legacy shape accessor width: use shape(3) instead.
Definition: Blob.cs:816
T sumsq_data()
Calculate the sum of squares (L2 norm squared) of the data.
Definition: Blob.cs:1730
void NormalizeData(double? dfMean=null, double? dfStd=null)
Normalize the blob data by subtracting the mean and dividing by the standard deviation.
Definition: Blob.cs:2942
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
Definition: Blob.cs:648
void SetDiff(double dfVal, int nIdx=-1)
Either sets all of the diff items in the Blob to a given value, or alternatively only sets a single i...
Definition: Blob.cs:1981
int num
DEPRECIATED; legacy shape accessor num: use shape(0) instead.
Definition: Blob.cs:792
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1479
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1445
Layer< T > FindLayer(LayerParameter.LayerType? type, string strName)
Find the layer with the matching type, name and or both.
Definition: Net.cs:2748
BlobCollection< T > output_blobs
Returns the collection of output Blobs.
Definition: Net.cs:2209
void ClearParamDiffs()
Zero out the diffs of all netw parameters. This should be run before Backward.
Definition: Net.cs:1907
BlobCollection< T > learnable_parameters
Returns the learnable parameters.
Definition: Net.cs:2117
The ResultCollection contains the result of a given CaffeControl::Run.
void Backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Given the top Blob error gradients, compute the bottom Blob error gradients.
Definition: Layer.cs:815
double Forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Given the bottom (input) Blobs, this function computes the top (output) Blobs and the loss.
Definition: Layer.cs:728
void Setup(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Implements common Layer setup functionality.
Definition: Layer.cs:439
The MemoryDataLayer provides data to the Net from memory. This layer is initialized with the MyCaffe....
virtual void AddDatumVector(Datum[] rgData, Datum[] rgClip=null, int nLblAxis=1, bool bReset=false, bool bResizeBatch=false)
This method is used to add a list of Datums to the memory.
EventHandler< MemoryDataLayerPackDataArgs< T > > OnDataPack
The OnDataPack event fires from within the AddDatumVector method and is used to pack the data into a ...
The MemoryDataLayerPackDataArgs is passed to the OnDataPack event which fires each time the data rece...
Blob< T > Label
Returns the label data to fill with ordered label information.
Blob< T > Clip
Returns the clip data to fill with ordered data for clipping.
List< Datum > ClipItems
Returns the raw clip items to use to fill.
LayerParameter.LayerType LstmType
Returns the LSTM type.
Blob< T > Data
Returns the blob data to fill with ordered data.
List< Datum > DataItems
Returns the raw data items to use to fill.
The MemoryLossLayerGetLossArgs class is passed to the OnGetLoss event.
bool EnableLossUpdate
Get/set enabling the loss update within the backpropagation pass.
double Loss
Get/set the externally calculated total loss.
BlobCollection< T > Bottom
Specifies the bottom passed in during the forward pass.
The MemoryLossLayer provides a method of performing a custom loss functionality. Similar to the Memor...
EventHandler< MemoryLossLayerGetLossArgs< T > > OnGetLoss
The OnGetLoss event fires during each forward pass. The value returned is saved, and applied on the b...
The SoftmaxCrossEntropyLossLayer computes the cross-entropy (logistic) loss and is often used for pr...
The SoftmaxLayer computes the softmax function. This layer is initialized with the MyCaffe....
Definition: SoftmaxLayer.cs:24
Specifies the base parameter for all layers.
List< double > loss_weight
Specifies the loss weight.
LayerType
Specifies the layer type.
LossParameter loss_param
Returns the parameter set when initialized with LayerType.LOSS
Stores the parameters used by loss layers.
NormalizationMode
How to normalize the loss for loss layers that aggregate across batches, spatial dimensions,...
NormalizationMode? normalization
Specifies the normalization mode (default = VALID).
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
Definition: Solver.cs:818
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
The ApplyUpdateArgs is passed to the OnApplyUpdates event.
Definition: EventArgs.cs:19
int Iteration
Returns the iteration from which the gradients are to be applied.
Definition: EventArgs.cs:47
double LearningRate
Returns the learning rate at the time the gradients were applied.
Definition: EventArgs.cs:55
MyCaffeControl< T > MyCaffeWorker
Returns the MyCaffe worker instance whose gradients are to be applied.
Definition: EventArgs.cs:39
The ConvertOutputArgs is passed to the OnConvertOutput event.
Definition: EventArgs.cs:311
byte[] RawOutput
Specifies the raw output byte stream.
Definition: EventArgs.cs:356
string RawType
Specifies the type of the raw output byte stream.
Definition: EventArgs.cs:348
The GetDataArgs is passed to the OnGetData event to retrieve data.
Definition: EventArgs.cs:402
StateBase State
Specifies the state data of the observations.
Definition: EventArgs.cs:517
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
The StateBase is the base class for the state of each observation - this is defined by actual trainer...
Definition: StateBase.cs:16
bool Done
Get/set whether the state is done or not.
Definition: StateBase.cs:72
double Reward
Get/set the reward of the state.
Definition: StateBase.cs:63
SimpleDatum Data
Returns other data associated with the state.
Definition: StateBase.cs:98
int ActionCount
Returns the number of actions.
Definition: StateBase.cs:90
SimpleDatum Clip
Returns the clip data associated with the state.
Definition: StateBase.cs:116
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
The Agent both builds episodes from the environment and trains on them using the Brain.
Definition: TrainerPG.cs:490
void Dispose()
Release all resources used.
Definition: TrainerPG.cs:563
byte[] Run(int nIterations, out string type)
Run the action on a set number of iterations and return the results with no training.
Definition: TrainerPG.cs:670
override void doWork(object arg)
This is the main agent thread that runs the agent.
Definition: TrainerPG.cs:576
Agent(int nIdx, IxTrainerCallback icallback, MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, Phase phase, int nGpuID, int nThreadCount)
The constructor.
Definition: TrainerPG.cs:524
EventHandler< ApplyUpdateArgs< T > > OnApplyUpdates
The OnApplyUpdates event fires each time the Agent needs to apply its updates to the primary instance...
Definition: TrainerPG.cs:511
void Run(Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
The Run method provides the main 'actor' loop that performs the following steps: 1....
Definition: TrainerPG.cs:752
Tuple< int, int > Run(int nDelay=1000)
Run a single action on the model.
Definition: TrainerPG.cs:648
The Brain uses the instance of MyCaffe (e.g. the open project) to run new actions and train the netwo...
Definition: TrainerPG.cs:898
void SetDiscountedR(float[] rg)
Sets the discounted returns in the Discounted Returns Blob.
Definition: TrainerPG.cs:1142
EventHandler< ApplyUpdateArgs< T > > OnApplyUpdate
The OnApplyUpdate event fires when the Brain needs to apply its gradients to the primary instance of ...
Definition: TrainerPG.cs:936
double LearningRate
Return the learning rate used.
Definition: TrainerPG.cs:1232
double LastLoss
Return the last loss received.
Definition: TrainerPG.cs:1224
int Reshape(Memory mem)
Reshape all Blobs used based on the Memory specified.
Definition: TrainerPG.cs:1105
GetDataArgs getDataArgs(Phase phase, int nIdx, int nAction, bool? bResetOverride=null)
Returns the GetDataArgs used to retrieve new data from the environment implemented by derived parent ...
Definition: TrainerPG.cs:1214
void Train(int nIteration, TRAIN_STEP step, bool bApplyGradients=true)
Train the model at the current iteration.
Definition: TrainerPG.cs:1363
int act(SimpleDatum sd, SimpleDatum sdClip, out float[] rgfAprob)
Returns the action from running the model. The action returned is either randomly selected (when usin...
Definition: TrainerPG.cs:1283
void SetData(List< Datum > rgData, List< Datum > rgClip)
Add the data to the model by adding it to the MemoryData layer.
Definition: TrainerPG.cs:1191
CancelEvent Cancel
Returns the Cancel event used to cancel all MyCaffe tasks.
Definition: TrainerPG.cs:1248
bool? UsesSoftMax
Returns true if the current model uses a SoftMax, false otherwise.
Definition: TrainerPG.cs:1096
Brain(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, Phase phase, int nGpuID, int nThreadCount)
The constructor.
Definition: TrainerPG.cs:947
Log OutputLog
Returns the primary MyCaffe output log for writing output information.
Definition: TrainerPG.cs:1088
int RecurrentSequenceLength
Returns the recurrent sequence length detected when training a recurrent network, otherwise 0 is retu...
Definition: TrainerPG.cs:1080
void Dispose()
Release all resources used by the Brain.
Definition: TrainerPG.cs:1047
void SetActionProbabilities(float[] rg)
Set the action probabilities in the Policy Gradient Blob.
Definition: TrainerPG.cs:1172
void Create()
Create the Brain CUDA objects - this is called on the thread from which the Brain runs.
Definition: TrainerPG.cs:972
SimpleDatum Preprocess(StateBase s, bool bUseRawInput)
Preprocesses the data.
Definition: TrainerPG.cs:1258
void SetActionOneHotVectors(float[] rg)
Set the action one-hot vectors in the Action OneHot Vector Blob.
Definition: TrainerPG.cs:1181
Contains the best memory episodes (best by highest total rewards)
Definition: TrainerPG.cs:1639
void Clear()
Clear all items from the memory cache.
Definition: TrainerPG.cs:1688
bool Add(Memory mem)
Add a new episode to the memory cache.
Definition: TrainerPG.cs:1675
void PurgeNonElite(double dfElitePercent)
Purge all non elite episodes.
Definition: TrainerPG.cs:1697
int Count
Returns the number of items in the cache.
Definition: TrainerPG.cs:1656
MemoryCache(int nMax)
Constructor.
Definition: TrainerPG.cs:1647
IEnumerator< Memory > GetEnumerator()
Returns the enumerator.
Definition: TrainerPG.cs:1726
Specifies a single Memory (e.g. an episode).
Definition: TrainerPG.cs:1745
int EpisodeNumber
Get/set the episode number of this memory.
Definition: TrainerPG.cs:1802
void Add(MemoryItem item)
Add a new item to the memory.
Definition: TrainerPG.cs:1772
Memory()
The constructor.
Definition: TrainerPG.cs:1753
float[] GetActionProbabilities()
Retrieve the action probabilities of the episode.
Definition: TrainerPG.cs:1847
double RewardSum
Get/set the reward sum of this memory.
Definition: TrainerPG.cs:1811
float[] GetDiscountedRewards(float fGamma, bool bAllowReset)
Retrieve the discounted rewards for this episode.
Definition: TrainerPG.cs:1822
void Clear()
Remove all items in the list.
Definition: TrainerPG.cs:1781
int Count
Returns the number of memory items in the memory.
Definition: TrainerPG.cs:1761
List< Datum > GetClip()
Returns the clip data if it exists, or null.
Definition: TrainerPG.cs:1902
float[] GetActionOneHotVectors()
Retrieve the action one-hot vectors for the episode.
Definition: TrainerPG.cs:1863
List< Datum > GetData()
Retrieve the data of each step in the episode.
Definition: TrainerPG.cs:1886
The MemoryItem stores the information for one step in an episode.
Definition: TrainerPG.cs:1928
float[] Aprob
Returns the action probabilities which are only used with non-Softmax models.
Definition: TrainerPG.cs:1988
int Action
Returns the action of this episode step.
Definition: TrainerPG.cs:1972
float Reward
Returns the reward for taking the action in this episode step.
Definition: TrainerPG.cs:1980
SimpleDatum Data
Returns the pre-processed data (run through the model) of this episode step.
Definition: TrainerPG.cs:1964
MemoryItem(StateBase s, SimpleDatum x, int nAction, float[] rgfAprob, float fReward)
The constructor.
Definition: TrainerPG.cs:1943
StateBase State
Returns the state and data of this episode step.
Definition: TrainerPG.cs:1956
override string ToString()
Returns the string representation of this episode step.
Definition: TrainerPG.cs:1996
The Optimizer manages a single thread used to apply updates to the primary instance of MyCaffe....
Definition: TrainerPG.cs:392
Optimizer(MyCaffeControl< T > mycaffePrimary)
The constructor.
Definition: TrainerPG.cs:405
void Dispose()
Release all resources used.
Definition: TrainerPG.cs:414
double ApplyUpdates(MyCaffeControl< T > mycaffeWorker, int nIteration)
The ApplyUpdates function sets the parameters, signals the Apply Updates thread, blocks for the opera...
Definition: TrainerPG.cs:462
override void doWork(object arg)
This override is the thread used to apply all updates, its CUDA DeviceID is set to the same device ID...
Definition: TrainerPG.cs:423
The TrainerPG implements a simple Policy Gradient trainer inspired by Andrej Karpathy's blog post re...
Definition: TrainerPG.cs:28
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
Definition: TrainerPG.cs:195
void Dispose()
Releases all resources used.
Definition: TrainerPG.cs:76
TrainerPG(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
The constructor.
Definition: TrainerPG.cs:44
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the environment after the delay.
Definition: TrainerPG.cs:126
byte[] Run(int nN, PropertySet runProp, out string type)
Run a set of iterations and return the results.
Definition: TrainerPG.cs:153
bool Shutdown(int nWait)
Shutdown the trainer.
Definition: TrainerPG.cs:108
bool Initialize()
Initialize the trainer.
Definition: TrainerPG.cs:84
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
Definition: TrainerPG.cs:169
The Worker class provides the base class for both the Environment and Optimizer and provides the basi...
Definition: TrainerPG.cs:329
AutoResetEvent m_evtCancel
Specifies the cancel event used to cancel this worker.
Definition: TrainerPG.cs:337
ManualResetEvent m_evtDone
Specifies the done event set when this worker completes.
Definition: TrainerPG.cs:341
int m_nIndex
Specifies the index of this worker.
Definition: TrainerPG.cs:333
virtual void doWork(object arg)
This is the actual thread function that is overridden by each derivative class.
Definition: TrainerPG.cs:360
Task m_workTask
Specifies the worker task that runs the thread function.
Definition: TrainerPG.cs:345
void Start(WorkerStartArgs args)
Start running the thread.
Definition: TrainerPG.cs:368
void Stop(int nWait)
Stop running the thread.
Definition: TrainerPG.cs:378
Worker(int nIdx)
The constructor.
Definition: TrainerPG.cs:351
The WorkerStartArgs provides the arguments used when starting the agent thread.
Definition: TrainerPG.cs:260
WorkerStartArgs(int nCycleDelay, Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
The constructor.
Definition: TrainerPG.cs:275
TRAIN_STEP Step
Returns the training step to take (if any). This is used for debugging.
Definition: TrainerPG.cs:288
int CycleDelay
Returns the cycle delay which specifies the amount of time to wait for a cancel.
Definition: TrainerPG.cs:296
Phase Phase
Return the phase on which to run.
Definition: TrainerPG.cs:304
ITERATOR_TYPE IterationType
Returns the iteration type.
Definition: TrainerPG.cs:320
int N
Returns the maximum number of episodes to run.
Definition: TrainerPG.cs:312
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
Definition: Interfaces.cs:303
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
Definition: Interfaces.cs:348
void OnConvertOutput(ConvertOutputArgs e)
The OnConvertOutput callback fires from within the Run method and is used to convert the network's ou...
The IxTrainerRL interface is implemented by each RL Trainer.
Definition: Interfaces.cs:257
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12