MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
TrainerC51.cs
1using MyCaffe.basecode;
2using MyCaffe.common;
3using MyCaffe.data;
4using MyCaffe.layers;
5using MyCaffe.param;
6using MyCaffe.solvers;
7using System;
8using System.Collections;
9using System.Collections.Generic;
10using System.Diagnostics;
11using System.Drawing;
12using System.Linq;
13using System.Text;
14using System.Threading.Tasks;
15
17{
31 public class TrainerC51<T> : IxTrainerRL, IDisposable
32 {
33 IxTrainerCallback m_icallback;
34 CryptoRandom m_random = new CryptoRandom();
35 MyCaffeControl<T> m_mycaffe;
36 PropertySet m_properties;
37
45 public TrainerC51(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
46 {
47 m_icallback = icallback;
48 m_mycaffe = mycaffe;
49 m_properties = properties;
50 m_random = random;
51 }
52
56 public void Dispose()
57 {
58 }
59
64 public bool Initialize()
65 {
66 m_mycaffe.CancelEvent.Reset();
67 m_icallback.OnInitialize(new InitializeArgs(m_mycaffe));
68 return true;
69 }
70
76 public bool Shutdown(int nWait)
77 {
78 if (m_mycaffe != null)
79 {
80 m_mycaffe.CancelEvent.Set();
81 wait(nWait);
82 }
83
84 m_icallback.OnShutdown();
85
86 return true;
87 }
88
89 private void wait(int nWait)
90 {
91 int nWaitInc = 250;
92 int nTotalWait = 0;
93
94 while (nTotalWait < nWait)
95 {
96 m_icallback.OnWait(new WaitArgs(nWaitInc));
97 nTotalWait += nWaitInc;
98 }
99 }
100
106 public ResultCollection RunOne(int nDelay = 1000)
107 {
108 m_mycaffe.CancelEvent.Reset();
109 DqnAgent<T> agent = new DqnAgent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);
110 agent.Run(Phase.TEST, 1, ITERATOR_TYPE.ITERATION, TRAIN_STEP.NONE);
111 agent.Dispose();
112 return null;
113 }
114
122 public byte[] Run(int nN, PropertySet runProp, out string type)
123 {
124 m_mycaffe.CancelEvent.Reset();
125 DqnAgent<T> agent = new DqnAgent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.RUN);
126 byte[] rgResults = agent.Run(nN, out type);
127 agent.Dispose();
128
129 return rgResults;
130 }
131
138 public bool Test(int nN, ITERATOR_TYPE type)
139 {
140 int nDelay = 1000;
141 string strProp = m_properties.ToString();
142
143 // Turn off the num-skip to run at normal speed.
144 strProp += "EnableNumSkip=False;";
145 PropertySet properties = new PropertySet(strProp);
146
147 m_mycaffe.CancelEvent.Reset();
148 DqnAgent<T> agent = new DqnAgent<T>(m_icallback, m_mycaffe, properties, m_random, Phase.TRAIN);
149 agent.Run(Phase.TEST, nN, type, TRAIN_STEP.NONE);
150
151 agent.Dispose();
152 Shutdown(nDelay);
153
154 return true;
155 }
156
164 public bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
165 {
166 m_mycaffe.CancelEvent.Reset();
167 DqnAgent<T> agent = new DqnAgent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);
168 agent.Run(Phase.TRAIN, nN, type, step);
169 agent.Dispose();
170
171 return false;
172 }
173 }
174
175
180 class DqnAgent<T> : IDisposable
181 {
182 IxTrainerCallback m_icallback;
183 Brain<T> m_brain;
184 PropertySet m_properties;
185 CryptoRandom m_random;
186 float m_fGamma = 0.95f;
187 bool m_bUseRawInput = true;
188 int m_nMaxMemory = 50000;
189 int m_nTrainingUpdateFreq = 5000;
190 int m_nExplorationNum = 50000;
191 int m_nEpsSteps = 0;
192 double m_dfEpsStart = 0;
193 double m_dfEpsEnd = 0;
194 double m_dfEpsDelta = 0;
195 double m_dfExplorationRate = 0;
196 STATE m_state = STATE.EXPLORING;
197
198 enum STATE
199 {
200 EXPLORING,
201 TRAINING
202 }
203
204
213 public DqnAgent(IxTrainerCallback icallback, MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
214 {
215 m_icallback = icallback;
216 m_brain = new Brain<T>(mycaffe, properties, random, phase);
217 m_properties = properties;
218 m_random = random;
219
220 m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", m_fGamma);
221 m_bUseRawInput = properties.GetPropertyAsBool("UseRawInput", m_bUseRawInput);
222 m_nMaxMemory = properties.GetPropertyAsInt("MaxMemory", m_nMaxMemory);
223 m_nTrainingUpdateFreq = properties.GetPropertyAsInt("TrainingUpdateFreq", m_nTrainingUpdateFreq);
224 m_nExplorationNum = properties.GetPropertyAsInt("ExplorationNum", m_nExplorationNum);
225 m_nEpsSteps = properties.GetPropertyAsInt("EpsSteps", m_nEpsSteps);
226 m_dfEpsStart = properties.GetPropertyAsDouble("EpsStart", m_dfEpsStart);
227 m_dfEpsEnd = properties.GetPropertyAsDouble("EpsEnd", m_dfEpsEnd);
228 m_dfEpsDelta = (m_dfEpsStart - m_dfEpsEnd) / m_nEpsSteps;
229 m_dfExplorationRate = m_dfEpsStart;
230
231 if (m_dfEpsStart < 0 || m_dfEpsStart > 1)
232 throw new Exception("The 'EpsStart' is out of range - please specify a real number in the range [0,1]");
233
234 if (m_dfEpsEnd < 0 || m_dfEpsEnd > 1)
235 throw new Exception("The 'EpsEnd' is out of range - please specify a real number in the range [0,1]");
236
237 if (m_dfEpsEnd > m_dfEpsStart)
238 throw new Exception("The 'EpsEnd' must be less than the 'EpsStart' value.");
239 }
240
244 public void Dispose()
245 {
246 if (m_brain != null)
247 {
248 m_brain.Dispose();
249 m_brain = null;
250 }
251 }
252
253 private StateBase getData(Phase phase, int nAction, int nIdx)
254 {
255 GetDataArgs args = m_brain.getDataArgs(phase, nAction);
256 m_icallback.OnGetData(args);
257 args.State.Data.Index = nIdx;
258 return args.State;
259 }
260
261
262 private int getAction(int nIteration, SimpleDatum sd, SimpleDatum sdClip, int nActionCount, TRAIN_STEP step)
263 {
264 if (step == TRAIN_STEP.NONE)
265 {
266 switch (m_state)
267 {
268 case STATE.EXPLORING:
269 return m_random.Next(nActionCount);
270
271 case STATE.TRAINING:
272 if (m_dfExplorationRate > m_dfEpsEnd)
273 m_dfExplorationRate -= m_dfEpsDelta;
274
275 if (m_random.NextDouble() < m_dfExplorationRate)
276 return m_random.Next(nActionCount);
277 break;
278 }
279 }
280
281 return m_brain.act(sd, sdClip, nActionCount);
282 }
283
284 private void updateStatus(int nIteration, int nEpisodeCount, double dfRewardSum, double dfRunningReward, double dfLoss, double dfLearningRate, bool bModelUpdated)
285 {
286 GetStatusArgs args = new GetStatusArgs(0, nIteration, nEpisodeCount, 1000000, dfRunningReward, dfRewardSum, m_dfExplorationRate, 0, dfLoss, dfLearningRate, bModelUpdated);
287 m_icallback.OnUpdateStatus(args);
288 }
289
296 public byte[] Run(int nIterations, out string type)
297 {
298 IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
299 if (icallback == null)
300 throw new Exception("The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");
301
302 StateBase s = getData(Phase.RUN, -1, 0);
303 int nIteration = 0;
304 List<float> rgResults = new List<float>();
305 bool bDifferent;
306
307 while (!m_brain.Cancel.WaitOne(0) && (nIterations == -1 || nIteration < nIterations))
308 {
309 // Preprocess the observation.
310 SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput, out bDifferent);
311
312 // Forward the policy network and sample an action.
313 int action = m_brain.act(x, s.Clip, s.ActionCount);
314
315 rgResults.Add(s.Data.TimeStamp.ToFileTime());
316 rgResults.Add(s.Data.GetDataAtF(0));
317 rgResults.Add(action);
318
319 nIteration++;
320
321 // Take the next step using the action
322 s = getData(Phase.RUN, action, nIteration);
323 }
324
325 ConvertOutputArgs args = new ConvertOutputArgs(nIterations, rgResults.ToArray());
326 icallback.OnConvertOutput(args);
327
328 type = args.RawType;
329 return args.RawOutput;
330 }
331
332 private bool isAtIteration(int nN, ITERATOR_TYPE type, int nIteration, int nEpisode)
333 {
334 if (nN == -1)
335 return false;
336
337 if (type == ITERATOR_TYPE.EPISODE)
338 {
339 if (nEpisode < nN)
340 return false;
341
342 return true;
343 }
344 else
345 {
346 if (nIteration < nN)
347 return false;
348
349 return true;
350 }
351 }
352
364 public void Run(Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
365 {
366 MemoryEpisodeCollection rgMemory = new MemoryEpisodeCollection(m_nMaxMemory);
367 int nIteration = 0;
368 double? dfRunningReward = null;
369 double dfRewardSum = 0;
370 int nEpisode = 0;
371 bool bDifferent = false;
372
373 StateBase s = getData(phase, -1, -1);
374 // Preprocess the observation.
375 SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput, out bDifferent, true);
376
377 while (!m_brain.Cancel.WaitOne(0) && !isAtIteration(nN, type, nIteration, nEpisode))
378 {
379 if (nIteration > m_nExplorationNum && rgMemory.Count > m_brain.BatchSize)
380 m_state = STATE.TRAINING;
381
382 // Forward the policy network and sample an action.
383 int action = getAction(nIteration, x, s.Clip, s.ActionCount, step);
384
385 // Take the next step using the action
386 StateBase s_ = getData(phase, action, nIteration);
387
388 // Preprocess the next observation.
389 SimpleDatum x_ = m_brain.Preprocess(s_, m_bUseRawInput, out bDifferent);
390 if (!bDifferent)
391 m_brain.Log.WriteLine("WARNING: The current state is the same as the previous state!");
392
393 dfRewardSum += s_.Reward;
394
395 // Build up episode memory, using reward for taking the action.
396 rgMemory.Add(new MemoryItem(s, x, action, s_, x_, s_.Reward, s_.Done, nIteration, nEpisode));
397
398 // Do the training
399 if (m_state == STATE.TRAINING)
400 {
401 MemoryCollection rgRandomSamples = rgMemory.GetRandomSamples(m_random, m_brain.BatchSize);
402 m_brain.Train(nIteration, rgRandomSamples, s.ActionCount);
403
404 if (nIteration % m_nTrainingUpdateFreq == 0)
405 m_brain.UpdateTargetModel();
406 }
407
408 if (s_.Done)
409 {
410 // Update reward running
411 if (!dfRunningReward.HasValue)
412 dfRunningReward = dfRewardSum;
413 else
414 dfRunningReward = dfRunningReward.Value * 0.99 + dfRewardSum * 0.01;
415
416 nEpisode++;
417 updateStatus(nIteration, nEpisode, dfRewardSum, dfRunningReward.Value, 0, 0, m_brain.GetModelUpdated());
418
419 s = getData(phase, -1, -1);
420 x = m_brain.Preprocess(s, m_bUseRawInput, out bDifferent, true);
421 dfRewardSum = 0;
422 }
423 else
424 {
425 s = s_;
426 x = x_;
427 }
428
429 nIteration++;
430 }
431 }
432 }
433
438 class Brain<T> : IDisposable, IxTrainerGetDataCallback
439 {
440 MyCaffeControl<T> m_mycaffe;
441 Solver<T> m_solver;
442 Net<T> m_net;
443 Net<T> m_netTarget;
444 PropertySet m_properties;
445 CryptoRandom m_random;
446 SimpleDatum m_sdLast = null;
447 DataTransformer<T> m_transformer;
448 MemoryLossLayer<T> m_memLoss;
449 SoftmaxCrossEntropyLossLayer<T> m_softmaxLoss = null;
450 SoftmaxLayer<T> m_softmax;
451 Blob<T> m_blobZ = null;
452 Blob<T> m_blobZ1 = null;
453 Blob<T> m_blobQ = null;
454 Blob<T> m_blobMLoss = null;
455 Blob<T> m_blobPLoss = null;
456 Blob<T> m_blobLoss = null;
457 Blob<T> m_blobActionBinaryLoss = null;
458 Blob<T> m_blobActionTarget = null;
459 Blob<T> m_blobAction = null;
460 Blob<T> m_blobLabel = null;
461 float m_fDeltaZ = 0;
462 float[] m_rgfZ = null;
463 float m_fGamma = 0.99f;
464 int m_nAtoms = 51;
465 double m_dfVMax = 10; // Max possible score for Pong per action is 1
466 double m_dfVMin = -10; // Min possible score for Pong per action is -1
467 int m_nFramesPerX = 4;
468 int m_nStackPerX = 4;
469 int m_nBatchSize = 32;
470 int m_nMiniBatch = 1;
471 BlobCollection<T> m_colAccumulatedGradients = new BlobCollection<T>();
472 bool m_bUseAcceleratedTraining = false;
473 double m_dfLearningRate;
474 MemoryCollection m_rgSamples;
475 int m_nActionCount = 3;
476 bool m_bModelUpdated = false;
477 Font m_font = null;
478 Dictionary<Color, Tuple<Brush, Brush, Pen, Brush>> m_rgStyle = new Dictionary<Color, Tuple<Brush, Brush, Pen, Brush>>();
479 List<SimpleDatum> m_rgX = new List<SimpleDatum>();
480 bool m_bNormalizeOverlay = true;
481 List<List<float>> m_rgOverlay = null;
482
483
        /// <summary>
        /// The constructor: wires up the solver, primary and target networks, the data
        /// transformer, the C51 working blobs and the MEMORY_LOSS layer hook.
        /// </summary>
        /// <param name="mycaffe">Specifies the MyCaffeControl managing the solver and networks.</param>
        /// <param name="properties">Specifies the property set containing the training settings.</param>
        /// <param name="random">Specifies the random number generator to use.</param>
        /// <param name="phase">Specifies the phase under which to run.</param>
        public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
        {
            m_mycaffe = mycaffe;
            m_solver = mycaffe.GetInternalSolver();
            m_net = mycaffe.GetInternalNet(phase);
            // The target network is a second copy of the primary network used by the DQN-style update.
            m_netTarget = new Net<T>(m_mycaffe.Cuda, m_mycaffe.Log, m_net.net_param, m_mycaffe.CancelEvent, null, phase);
            m_properties = properties;
            m_random = random;

            m_transformer = m_mycaffe.DataTransformer;
            if (m_transformer == null)
            {
                // NOTE(review): 'trans_param' is declared on a line not visible in this
                // excerpt - confirm against the full source.
                int nC = m_mycaffe.CurrentProject.Dataset.TrainingSource.Channels;
                int nH = m_mycaffe.CurrentProject.Dataset.TrainingSource.Height;
                int nW = m_mycaffe.CurrentProject.Dataset.TrainingSource.Width;
                m_transformer = new DataTransformer<T>(m_mycaffe.Cuda, m_mycaffe.Log, trans_param, phase, nC, nH, nW);
            }
            // Center the 0-255 pixel inputs and scale them into roughly [-0.5, 0.5].
            m_transformer.param.mean_value.Add(255 / 2); // center
            m_transformer.param.mean_value.Add(255 / 2);
            m_transformer.param.mean_value.Add(255 / 2);
            m_transformer.param.mean_value.Add(255 / 2);
            m_transformer.param.scale = 1.0 / 255; // normalize
            m_transformer.Update();

            // C51 settings: discount, number of atoms and the support range [VMin, VMax].
            m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", m_fGamma);
            m_nAtoms = properties.GetPropertyAsInt("Atoms", m_nAtoms);
            m_dfVMin = properties.GetPropertyAsDouble("VMin", m_dfVMin);
            m_dfVMax = properties.GetPropertyAsDouble("VMax", m_dfVMax);

            // Working blobs used during action selection and the loss computation.
            m_blobZ = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobZ1 = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobQ = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
            m_blobMLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
            m_blobPLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
            m_blobLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);
            m_blobActionBinaryLoss = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobActionTarget = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobAction = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, false);
            m_blobLabel = new Blob<T>(m_mycaffe.Cuda, m_mycaffe.Log, true);

            // The loss is computed manually through the MEMORY_LOSS layer's OnGetLoss event.
            m_memLoss = m_net.FindLastLayer(LayerParameter.LayerType.MEMORY_LOSS) as MemoryLossLayer<T>;
            if (m_memLoss == null)
                m_mycaffe.Log.FAIL("Missing the expected MEMORY_LOSS layer!");

            m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);
            m_bUseAcceleratedTraining = properties.GetPropertyAsBool("UseAcceleratedTraining", false);

            Blob<T> data = m_net.blob_by_name("data");
            if (data == null)
                m_mycaffe.Log.FAIL("Missing the expected input 'data' blob!");

            // Derive the frame count and batch size from the input blob's shape.
            m_nFramesPerX = data.channels;
            m_nBatchSize = data.num;

            m_solver.parameter.delta = 0.01 / (double)m_nBatchSize;

            // When accumulating gradients over multiple mini-batches, clone the learnable
            // parameters to hold the accumulated diffs.
            if (m_nMiniBatch > 1)
            {
                m_colAccumulatedGradients = m_net.learnable_parameters.Clone();
                m_colAccumulatedGradients.SetDiff(0);
            }
        }
554
555 private void dispose(ref Blob<T> b)
556 {
557 if (b != null)
558 {
559 b.Dispose();
560 b = null;
561 }
562 }
563
567 public void Dispose()
568 {
569 dispose(ref m_blobZ);
570 dispose(ref m_blobZ1);
571 dispose(ref m_blobQ);
572 dispose(ref m_blobMLoss);
573 dispose(ref m_blobPLoss);
574 dispose(ref m_blobActionBinaryLoss);
575 dispose(ref m_blobActionTarget);
576 dispose(ref m_blobAction);
577 dispose(ref m_blobLabel);
578
579 if (m_colAccumulatedGradients != null)
580 {
581 m_colAccumulatedGradients.Dispose();
582 m_colAccumulatedGradients = null;
583 }
584
585 if (m_softmax != null)
586 {
587 m_softmax.Dispose();
588 m_softmax = null;
589 }
590
591 if (m_netTarget != null)
592 {
593 m_netTarget.Dispose();
594 m_netTarget = null;
595 }
596
597 if (m_font != null)
598 {
599 m_font.Dispose();
600 m_font = null;
601 }
602
603 foreach (KeyValuePair<Color, Tuple<Brush, Brush, Pen, Brush>> kv in m_rgStyle)
604 {
605 kv.Value.Item1.Dispose();
606 kv.Value.Item2.Dispose();
607 kv.Value.Item3.Dispose();
608 kv.Value.Item4.Dispose();
609 }
610
611 m_rgStyle.Clear();
612 }
613
620 public GetDataArgs getDataArgs(Phase phase, int nAction)
621 {
622 bool bReset = (nAction == -1) ? true : false;
623 return new GetDataArgs(phase, 0, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction, true, false, false, this);
624 }
625
        /// <summary>
        /// Returns the number of frames per input 'x' (taken from the data blob's channel count).
        /// </summary>
        public int FrameStack
        {
            get { return m_nFramesPerX; }
        }
633
        /// <summary>
        /// Returns the batch size (taken from the data blob's 'num' dimension).
        /// </summary>
        public int BatchSize
        {
            get { return m_nBatchSize; }
        }
641
        /// <summary>
        /// Returns the output log used.
        /// </summary>
        public Log Log
        {
            get { return m_mycaffe.Log; }
        }
649
654 {
655 get { return m_mycaffe.CancelEvent; }
656 }
657
666 public SimpleDatum Preprocess(StateBase s, bool bUseRawInput, out bool bDifferent, bool bReset = false)
667 {
668 bDifferent = false;
669
670 SimpleDatum sd = new SimpleDatum(s.Data, true);
671
672 if (!bUseRawInput)
673 {
674 if (bReset)
675 m_sdLast = null;
676
677 if (m_sdLast == null)
678 sd.Zero();
679 else
680 bDifferent = sd.Sub(m_sdLast);
681
682 m_sdLast = new SimpleDatum(s.Data, true);
683 }
684 else
685 {
686 bDifferent = true;
687 }
688
689 sd.Tag = bReset;
690
691 if (bReset)
692 {
693 m_rgX = new List<SimpleDatum>();
694
695 for (int i = 0; i < m_nFramesPerX * m_nStackPerX; i++)
696 {
697 m_rgX.Add(sd);
698 }
699 }
700 else
701 {
702 m_rgX.Add(sd);
703 m_rgX.RemoveAt(0);
704 }
705
706 SimpleDatum[] rgSd = new SimpleDatum[m_nStackPerX];
707
708 for (int i=0; i<m_nStackPerX; i++)
709 {
710 int nIdx = ((m_nStackPerX - i) * m_nFramesPerX) - 1;
711 rgSd[i] = m_rgX[nIdx];
712 }
713
714 return new SimpleDatum(rgSd.ToList(), true);
715 }
716
717 private float[] createZArray(double dfVMin, double dfVMax, int nAtoms, out float fDeltaZ)
718 {
719 float[] rgZ = new float[nAtoms];
720 float fZ = (float)dfVMin;
721 fDeltaZ = (float)((dfVMax - dfVMin) / (nAtoms - 1));
722
723 for (int i = 0; i < nAtoms; i++)
724 {
725 rgZ[i] = fZ;
726 fZ += fDeltaZ;
727 }
728
729 return rgZ;
730 }
731
        /// <summary>
        /// Lazily create and tile the atom support blob 'z' so that each (action, sample)
        /// pair has its own copy of the support, then reshape it for the current sample count.
        /// </summary>
        /// <param name="nNumSamples">Specifies the current number of samples.</param>
        /// <param name="nActions">Specifies the number of actions.</param>
        /// <param name="nAtoms">Specifies the number of atoms in the support.</param>
        private void createZ(int nNumSamples, int nActions, int nAtoms)
        {
            int nOffset = 0;

            // Build the support once; subsequent calls only reshape.
            if (m_rgfZ == null)
            {
                m_blobZ1.Reshape(1, nAtoms, 1, 1);

                m_rgfZ = createZArray(m_dfVMin, m_dfVMax, m_nAtoms, out m_fDeltaZ);
                T[] rgfZ0 = Utility.ConvertVec<T>(m_rgfZ);

                m_blobZ1.mutable_cpu_data = rgfZ0;
                nOffset = 0;

                // Tile the single support row once per (action, batch-item) pair.
                m_blobZ.Reshape(nActions, m_nBatchSize, nAtoms, 1);

                for (int i = 0; i < nActions; i++)
                {
                    for (int j = 0; j < m_nBatchSize; j++)
                    {
                        m_mycaffe.Cuda.copy(m_blobZ1.count(), m_blobZ1.gpu_data, m_blobZ.mutable_gpu_data, 0, nOffset);
                        nOffset += m_blobZ1.count();
                    }
                }
            }

            // Trim the tiled support down to the current number of samples
            // (assumes nNumSamples <= m_nBatchSize used to fill the blob above).
            m_blobZ.Reshape(nActions, nNumSamples, nAtoms, 1);
        }
760
769 public int act(SimpleDatum sd, SimpleDatum sdClip, int nActionCount)
770 {
771 setData(m_net, sd, sdClip);
772 m_net.ForwardFromTo(0, m_net.layers.Count - 2);
773
774 Blob<T> logits = m_net.blob_by_name("logits");
775 if (logits == null)
776 throw new Exception("Missing expected 'logits' blob!");
777
778 Blob<T> actions = softmax_forward(logits, m_blobAction);
779
780 createZ(1, nActionCount, m_nAtoms);
781 m_blobQ.ReshapeLike(actions);
782
783 m_mycaffe.Cuda.mul(actions.count(), actions.gpu_data, m_blobZ.gpu_data, m_blobQ.mutable_gpu_data);
784 reduce_sum_axis2(m_blobQ);
785
786 return argmax(Utility.ConvertVecF<T>(m_blobQ.mutable_cpu_data), nActionCount, 0);
787 }
788
793 public bool GetModelUpdated()
794 {
795 bool bModelUpdated = m_bModelUpdated;
796 m_bModelUpdated = false;
797 return bModelUpdated;
798 }
799
        /// <summary>
        /// Copy the primary network's trained weights into the target network and
        /// mark the model as updated (logging is silenced during the copy).
        /// </summary>
        public void UpdateTargetModel()
        {
            m_mycaffe.Log.Enable = false;
            m_net.CopyTrainedLayersTo(m_netTarget);
            m_mycaffe.Log.Enable = true;
            m_bModelUpdated = true;
        }
810
        /// <summary>
        /// Train the model on a batch of replay samples: run the target network on the
        /// 'next state' data, compute the projected C51 target via the OnGetLoss hook,
        /// then step the solver on the 'current state' data with the projection loss hook.
        /// </summary>
        /// <param name="nIteration">Specifies the current iteration.</param>
        /// <param name="rgSamples">Specifies the replay-memory samples to train on.</param>
        /// <param name="nActionCount">Specifies the number of actions available.</param>
        public void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
        {
            m_rgSamples = rgSamples;
            m_nActionCount = nActionCount;

            m_mycaffe.Log.Enable = false;
            // Run the target network on the 'next state' data (stop before the loss layer).
            setData1(m_netTarget, rgSamples);
            m_netTarget.ForwardFromTo(0, m_netTarget.layers.Count - 2);

            // Run the primary network on the 'next state' data; the OnGetLoss hook
            // computes the projected target distribution 'm' used by the loss below.
            setData1(m_net, rgSamples);
            m_memLoss.OnGetLoss += m_memLoss_OnGetLoss;
            m_net.ForwardFromTo();
            m_memLoss.OnGetLoss -= m_memLoss_OnGetLoss;

            // Step the solver on the 'current state' data with the projection-loss hook attached.
            setData0(m_net, rgSamples);
            m_memLoss.OnGetLoss += m_memLoss_ProjectDistribution;

            if (m_nMiniBatch == 1)
            {
                m_solver.Step(1);
            }
            else
            {
                // Accumulate gradients over m_nMiniBatch steps before applying the update.
                m_solver.Step(1, TRAIN_STEP.NONE, true, m_bUseAcceleratedTraining, true, true);
                m_colAccumulatedGradients.Accumulate(m_mycaffe.Cuda, m_net.learnable_parameters, true);

                if (nIteration % m_nMiniBatch == 0)
                {
                    m_net.learnable_parameters.CopyFrom(m_colAccumulatedGradients, true);
                    m_colAccumulatedGradients.SetDiff(0);
                    m_dfLearningRate = m_solver.ApplyUpdate(nIteration);
                    m_net.ClearParamDiffs();
                }
            }

            m_memLoss.OnGetLoss -= m_memLoss_ProjectDistribution;
            m_mycaffe.Log.Enable = true;
        }
855
        /// <summary>
        /// Compute the final C51 loss: mask the logits down to the actions actually
        /// taken, expand the projected target distribution 'm' across the action slots,
        /// then apply the softmax cross-entropy loss and write masked gradients back
        /// into the network's bottom blob.
        /// </summary>
        /// <param name="sender">Specifies the MemoryLossLayer firing the event.</param>
        /// <param name="e">Specifies the event args carrying the bottom blobs and loss.</param>
        private void m_memLoss_ProjectDistribution(object sender, MemoryLossLayerGetLossArgs<T> e)
        {
            int nNumSamples = m_rgSamples.Count;

            Blob<T> logits = m_net.blob_by_name("logits");
            if (logits == null)
                throw new Exception("Missing expected 'logits' blob!");

            //-------------------------------------------------------
            // Loss function
            //-------------------------------------------------------

            m_blobPLoss.ReshapeLike(logits);
            m_blobLabel.ReshapeLike(logits);

            // Zero the logits of all actions other than the one taken
            // (m_blobActionBinaryLoss is a 0/1 mask built by m_memLoss_OnGetLoss).
            m_mycaffe.Cuda.mul(logits.count(), logits.gpu_data, m_blobActionBinaryLoss.mutable_gpu_data, m_blobPLoss.mutable_gpu_data); // Logits valid
            m_blobPLoss.Reshape(m_blobPLoss.num, m_nActionCount, m_nAtoms, 1);
            m_blobLabel.Reshape(m_blobLabel.num, m_nActionCount, m_nAtoms, 1);

            int nDstOffset = 0;
            int nSrcOffset = 0;

            // Copy each sample's projected distribution 'm' into every action slot of
            // the label blob; the mask zeroes all but the action actually taken.
            for (int i = 0; i < nNumSamples; i++)
            {
                for (int j = 0; j < m_nActionCount; j++)
                {
                    m_mycaffe.Cuda.mul(m_nAtoms, m_blobMLoss.gpu_data, m_blobActionBinaryLoss.gpu_data, m_blobLabel.mutable_gpu_data, nSrcOffset, nDstOffset, nDstOffset);
                    nDstOffset += m_nAtoms;
                }

                nSrcOffset += m_nAtoms;
            }

            e.Loss = softmaxLoss_forward(m_blobPLoss, m_blobLabel, m_blobLoss);
            softmaxLoss_backward(m_blobPLoss, m_blobLabel, m_blobLoss);

            // The gradients are written directly below, so disable the layer's own update.
            e.EnableLossUpdate = false;
            m_mycaffe.Cuda.mul(m_blobPLoss.count(), m_blobPLoss.gpu_diff, m_blobActionBinaryLoss.gpu_data, e.Bottom[0].mutable_gpu_diff);
        }
900
906 private void m_memLoss_OnGetLoss(object sender, MemoryLossLayerGetLossArgs<T> e)
907 {
908 Blob<T> logits = m_net.blob_by_name("logits");
909 if (logits == null)
910 throw new Exception("Missing expected 'logits' blob!");
911
912 Blob<T> actions = softmax_forward(logits, m_blobAction);
913
914 Blob<T> p_logits = m_netTarget.blob_by_name("logits");
915 if (p_logits == null)
916 throw new Exception("Missing expected 'logits' blob!");
917
918 Blob<T> p_actions = softmax_forward(p_logits, m_blobActionTarget);
919
920 int nNumSamples = m_rgSamples.Count;
921 createZ(nNumSamples, m_nActionCount, m_nAtoms);
922
923 m_blobQ.ReshapeLike(actions);
924
925 m_mycaffe.Log.CHECK_EQ(m_blobQ.shape(0), nNumSamples, "The result should have shape(0) = NumSamples which is " + nNumSamples.ToString());
926 m_mycaffe.Log.CHECK_EQ(m_blobQ.shape(1), m_nActionCount, "The result should have shape(1) = Actions which is " + m_nActionCount.ToString());
927 m_mycaffe.Log.CHECK_EQ(m_blobQ.shape(2), m_nAtoms, "The result should have shape(2) = Atoms which is " + m_nAtoms.ToString());
928
929 // Get Optimal Actions for the next states (for distribution z)
930 m_mycaffe.Cuda.mul(actions.count(), actions.gpu_data, m_blobZ.gpu_data, m_blobQ.mutable_gpu_data);
931 reduce_sum_axis2(m_blobQ);
932 m_blobQ.Reshape(nNumSamples, m_nActionCount, 1, 1);
933
934 float[] rgQbatch = Utility.ConvertVecF<T>(m_blobQ.mutable_cpu_data);
935 float[] rgPbatch = Utility.ConvertVecF<T>(p_actions.mutable_cpu_data);
936 float[] rgMBatch = new float[nNumSamples * m_nAtoms];
937
938 for (int i = 0; i < nNumSamples; i++)
939 {
940 int nActionMax = argmax(rgQbatch, m_nActionCount, i);
941
942 if (m_rgSamples[i].IsTerminated)
943 {
944 double dfTz = m_rgSamples[i].Reward;
945
946 // Bounding Tz
947 dfTz = setBounds(dfTz, m_dfVMin, m_dfVMax);
948
949 double dfB = (dfTz - m_dfVMin) / m_fDeltaZ;
950 int nL = (int)Math.Floor(dfB);
951 int nU = (int)Math.Ceiling(dfB);
952 int nIdx = i * m_nAtoms;
953
954 rgMBatch[nIdx + nL] += (float)(nU - dfB);
955 rgMBatch[nIdx + nU] += (float)(dfB - nL);
956 }
957 else
958 {
959 for (int j = 0; j < m_nAtoms; j++)
960 {
961 double dfTz = m_rgSamples[i].Reward + m_fGamma * m_rgfZ[j];
962
963 // Bounding Tz
964 dfTz = setBounds(dfTz, m_dfVMin, m_dfVMax);
965
966 double dfB = (dfTz - m_dfVMin) / m_fDeltaZ;
967 int nL = (int)Math.Floor(dfB);
968 int nU = (int)Math.Ceiling(dfB);
969 int nIdx = i * m_nAtoms;
970 int nIdxT = (i * m_nActionCount * m_nAtoms) + (nActionMax * m_nAtoms);
971
972 rgMBatch[nIdx + nL] += rgPbatch[nIdxT + j] * (float)(nU - dfB);
973 rgMBatch[nIdx + nU] += rgPbatch[nIdxT + j] * (float)(dfB - nL);
974 }
975 }
976
977 // Normalize the atom values to range [0,1]
978 float fSum = 0;
979 for (int j = 0; j < m_nAtoms; j++)
980 {
981 fSum += rgMBatch[(i * m_nAtoms) + j];
982 }
983
984 if (fSum != 0)
985 {
986 for (int j = 0; j < m_nAtoms; j++)
987 {
988 rgMBatch[(i * m_nAtoms) + j] /= fSum;
989 }
990 }
991 }
992
993 m_blobMLoss.Reshape(nNumSamples, m_nAtoms, 1, 1);
994 m_blobMLoss.mutable_cpu_data = Utility.ConvertVec<T>(rgMBatch);
995
996 m_blobActionBinaryLoss.Reshape(nNumSamples, m_nActionCount, m_nAtoms, 1);
997 m_blobActionBinaryLoss.SetData(0.0);
998
999 for (int i = 0; i < m_rgSamples.Count; i++)
1000 {
1001 int nAction = m_rgSamples[i].Action;
1002 int nIdx = (i * m_nActionCount * m_nAtoms) + (nAction * m_nAtoms);
1003
1004 m_blobActionBinaryLoss.SetData(1.0, nIdx, m_nAtoms);
1005 }
1006 }
1007
1008 private float reduce_mean(Blob<T> b)
1009 {
1010 float[] rg = Utility.ConvertVecF<T>(b.mutable_cpu_data);
1011 float fSum = rg.Sum(p => p);
1012 return fSum / rg.Length;
1013 }
1014
1015 private void reduce_sum_axis1(Blob<T> b)
1016 {
1017 int nNum = b.shape(0);
1018 int nActions = b.shape(1);
1019 int nAtoms = b.shape(2);
1020 float[] rg = Utility.ConvertVecF<T>(b.mutable_cpu_data);
1021 float[] rgSum = new float[nNum * nAtoms];
1022
1023 for (int i = 0; i < nNum; i++)
1024 {
1025 for (int j = 0; j < nAtoms; j++)
1026 {
1027 float fSum = 0;
1028
1029 for (int k = 0; k < nActions; k++)
1030 {
1031 int nIdx = (i * nActions * nAtoms) + (k * nAtoms);
1032 fSum += rg[nIdx + j];
1033 }
1034
1035 int nIdxR = i * nAtoms;
1036 rgSum[nIdxR + j] = fSum;
1037 }
1038 }
1039
1040 b.Reshape(nNum, nAtoms, 1, 1);
1041 b.mutable_cpu_data = Utility.ConvertVec<T>(rgSum);
1042 }
1043
1044 private void reduce_sum_axis2(Blob<T> b)
1045 {
1046 int nNum = b.shape(0);
1047 int nActions = b.shape(1);
1048 int nAtoms = b.shape(2);
1049 float[] rg = Utility.ConvertVecF<T>(b.mutable_cpu_data);
1050 float[] rgSum = new float[nNum * nActions];
1051
1052 for (int i = 0; i < nNum; i++)
1053 {
1054 for (int j = 0; j < nActions; j++)
1055 {
1056 int nIdx = (i * nActions * nAtoms) + (j * nAtoms);
1057 float fSum = 0;
1058
1059 for (int k = 0; k < nAtoms; k++)
1060 {
1061 fSum += rg[nIdx + k];
1062 }
1063
1064 int nIdxR = i * nActions;
1065 rgSum[nIdxR + j] = fSum;
1066 }
1067 }
1068
1069 b.Reshape(nNum, nAtoms, 1, 1);
1070 b.mutable_cpu_data = Utility.ConvertVec<T>(rgSum);
1071 }
1072
        /// <summary>
        /// Run the softmax cross-entropy loss forward pass, creating the layer on first use.
        /// </summary>
        /// <param name="actual">Specifies the predicted distribution logits.</param>
        /// <param name="target">Specifies the target distribution.</param>
        /// <param name="loss">Specifies the blob receiving the loss output.</param>
        /// <returns>The loss value is returned.</returns>
        private double softmaxLoss_forward(Blob<T> actual, Blob<T> target, Blob<T> loss)
        {
            BlobCollection<T> colBottom = new BlobCollection<T>();
            colBottom.Add(actual);
            colBottom.Add(target);

            BlobCollection<T> colTop = new BlobCollection<T>();
            colTop.Add(loss);

            // Lazily create the loss layer, running softmax over the atom axis (axis 2).
            if (m_softmaxLoss == null)
            {
                LayerParameter p = new LayerParameter(LayerParameter.LayerType.SOFTMAXCROSSENTROPY_LOSS);
                p.softmax_param.axis = 2;
                m_softmaxLoss = new SoftmaxCrossEntropyLossLayer<T>(m_mycaffe.Cuda, m_mycaffe.Log, p);
                m_softmaxLoss.Setup(colBottom, colTop);
            }

            return m_softmaxLoss.Forward(colBottom, colTop);
        }
1093
1094 private void softmaxLoss_backward(Blob<T> actual, Blob<T> target, Blob<T> loss)
1095 {
1096 BlobCollection<T> colBottom = new BlobCollection<T>();
1097 colBottom.Add(actual);
1098 colBottom.Add(target);
1099
1100 BlobCollection<T> colTop = new BlobCollection<T>();
1101 colTop.Add(loss);
1102
1103 m_softmaxLoss.Backward(colTop, new List<bool>() { true, false }, colBottom);
1104 }
1105
        /// <summary>
        /// Run a softmax over axis 2 (the atom axis) of the bottom blob, creating the
        /// softmax layer on first use, and return the top blob of probabilities.
        /// </summary>
        /// <param name="bBottom">Specifies the input logits blob.</param>
        /// <param name="bTop">Specifies the output probability blob.</param>
        /// <returns>The top blob containing the probabilities is returned.</returns>
        private Blob<T> softmax_forward(Blob<T> bBottom, Blob<T> bTop)
        {
            BlobCollection<T> colBottom = new BlobCollection<T>();
            colBottom.Add(bBottom);

            BlobCollection<T> colTop = new BlobCollection<T>();
            colTop.Add(bTop);

            if (m_softmax == null)
            {
                // NOTE(review): the declaration of 'p' (a SOFTMAX LayerParameter) is on a
                // line not visible in this excerpt - confirm against the full source.
                p.softmax_param.axis = 2;
                m_softmax = new SoftmaxLayer<T>(m_mycaffe.Cuda, m_mycaffe.Log, p);
                m_softmax.Setup(colBottom, colTop);
            }

            m_softmax.Reshape(colBottom, colTop);
            m_softmax.Forward(colBottom, colTop);

            return colTop[0];
        }
1127
1128 private double setBounds(double z, double dfMin, double dfMax)
1129 {
1130 if (z > dfMax)
1131 return dfMax;
1132
1133 if (z < dfMin)
1134 return dfMin;
1135
1136 return z;
1137 }
1138
1139 private int argmax(float[] rgProb, int nActionCount, int nSampleIdx)
1140 {
1141 float[] rgfProb = new float[nActionCount];
1142
1143 for (int j = 0; j < nActionCount; j++)
1144 {
1145 int nIdx = (nSampleIdx * nActionCount) + j;
1146 rgfProb[j] = rgProb[nIdx];
1147 }
1148
1149 return argmax(rgfProb);
1150 }
1151
1152 private int argmax(float[] rgfAprob)
1153 {
1154 float fMax = -float.MaxValue;
1155 int nIdx = 0;
1156
1157 for (int i = 0; i < rgfAprob.Length; i++)
1158 {
1159 if (rgfAprob[i] == fMax)
1160 {
1161 if (m_random.NextDouble() > 0.5)
1162 nIdx = i;
1163 }
1164 else if (fMax < rgfAprob[i])
1165 {
1166 fMax = rgfAprob[i];
1167 nIdx = i;
1168 }
1169 }
1170
1171 return nIdx;
1172 }
1173
1174 private void setData(Net<T> net, SimpleDatum sdData, SimpleDatum sdClip)
1175 {
1176 SimpleDatum[] rgData = new SimpleDatum[] { sdData };
1177 SimpleDatum[] rgClip = null;
1178
1179 if (sdClip != null)
1180 rgClip = new SimpleDatum[] { sdClip };
1181
1182 setData(net, rgData, rgClip);
1183 }
1184
1185 private void setData0(Net<T> net, MemoryCollection rgSamples)
1186 {
1187 List<SimpleDatum> rgData0 = rgSamples.GetData0();
1188 List<SimpleDatum> rgClip0 = rgSamples.GetClip0();
1189
1190 SimpleDatum[] rgData = rgData0.ToArray();
1191 SimpleDatum[] rgClip = (rgClip0 != null) ? rgClip0.ToArray() : null;
1192
1193 setData(net, rgData, rgClip);
1194 }
1195
1196 private void setData1(Net<T> net, MemoryCollection rgSamples)
1197 {
1198 List<SimpleDatum> rgData1 = rgSamples.GetData1();
1199 List<SimpleDatum> rgClip1 = rgSamples.GetClip1();
1200
1201 SimpleDatum[] rgData = rgData1.ToArray();
1202 SimpleDatum[] rgClip = (rgClip1 != null) ? rgClip1.ToArray() : null;
1203
1204 setData(net, rgData, rgClip);
1205 }
1206
1207 private void setData(Net<T> net, SimpleDatum[] rgData, SimpleDatum[] rgClip)
1208 {
1209 Blob<T> data = net.blob_by_name("data");
1210
1211 data.Reshape(rgData.Length, data.channels, data.height, data.width);
1212 m_transformer.Transform(rgData, data, m_mycaffe.Cuda, m_mycaffe.Log);
1213
1214 if (rgClip != null)
1215 {
1216 Blob<T> clip = net.blob_by_name("clip");
1217
1218 if (clip != null)
1219 {
1220 clip.Reshape(rgClip.Length, rgClip[0].Channels, rgClip[0].Height, rgClip[0].Width);
1221 m_transformer.Transform(rgClip, clip, m_mycaffe.Cuda, m_mycaffe.Log, true);
1222 }
1223 }
1224 }
1225
        /// <summary>
        /// Callback fired on each overlay render pass - draws the per-action atom
        /// probability distributions over the lower portion of the display image.
        /// </summary>
        /// <param name="e">Specifies the overlay arguments containing the display image.</param>
        public void OnOverlay(OverlayArgs e)
        {
            Blob<T> logits = m_net.blob_by_name("logits");
            if (logits == null)
                return;

            // Only refresh the overlay data when the net was run on a single item.
            if (logits.num == 1)
            {
                Blob<T> actions = softmax_forward(logits, m_blobAction);

                float[] rgActions = Utility.ConvertVecF<T>(actions.mutable_cpu_data);

                // Collect the atom probabilities for each action.
                List<List<float>> rgData = new List<List<float>>();
                for (int i = 0; i < m_nActionCount; i++)
                {
                    List<float> rgProb = new List<float>();

                    for (int j = 0; j < m_nAtoms; j++)
                    {
                        int nIdx = (i * m_nAtoms) + j;
                        rgProb.Add(rgActions[nIdx]);
                    }

                    rgData.Add(rgProb);
                }

                m_rgOverlay = rgData;
            }

            if (m_rgOverlay == null)
                return;

            using (Graphics g = Graphics.FromImage(e.DisplayImage))
            {
                // Lay the distributions out side-by-side across the lower 30% of the image.
                int nBorder = 30;
                int nWid = e.DisplayImage.Width - (nBorder * 2);
                int nWid1 = nWid / m_rgOverlay.Count;
                int nHt1 = (int)(e.DisplayImage.Height * 0.3);
                int nX = nBorder;
                int nY = e.DisplayImage.Height - nHt1;
                ColorMapper clrMap = new ColorMapper(0, m_rgOverlay.Count + 1, Color.Black, Color.Red);
                float[] rgfMin = new float[m_rgOverlay.Count];
                float[] rgfMax = new float[m_rgOverlay.Count];
                float fMax = -float.MaxValue;
                float fMaxMax = -float.MaxValue;
                int nMaxIdx = 0;

                // Find each action's min/max and the action holding the overall peak.
                for (int i=0; i<m_rgOverlay.Count; i++)
                {
                    rgfMin[i] = m_rgOverlay[i].Min(p => p);
                    rgfMax[i] = m_rgOverlay[i].Max(p => p);

                    if (rgfMax[i] > fMax)
                    {
                        fMax = rgfMax[i];
                        nMaxIdx = i;
                    }

                    fMaxMax = Math.Max(fMax, fMaxMax);
                }

                // Once any distribution peaks above 0.2, stop normalizing the overlay;
                // note this latches off permanently (m_bNormalizeOverlay is never reset).
                if (fMaxMax > 0.2f)
                    m_bNormalizeOverlay = false;

                // Draw each action's distribution, highlighting the peak action.
                for (int i = 0; i < m_rgOverlay.Count; i++)
                {
                    drawProbabilities(g, nX, nY, nWid1, nHt1, i, m_rgOverlay[i], clrMap.GetColor(i + 1), rgfMin.Min(p => p), rgfMax.Max(p => p), (i == nMaxIdx) ? true : false, m_bNormalizeOverlay);
                    nX += nWid1;
                }
            }
        }
1301
1302 private void drawProbabilities(Graphics g, int nX, int nY, int nWid, int nHt, int nAction, List<float> rgProb, Color clr, float fMin, float fMax, bool bMax, bool bNormalize)
1303 {
1304 string str = "";
1305
1306 if (m_font == null)
1307 m_font = new Font("Century Gothic", 9.0f);
1308
1309 if (!m_rgStyle.ContainsKey(clr))
1310 {
1311 Color clr1 = Color.FromArgb(128, clr);
1312 Brush br1 = new SolidBrush(clr1);
1313 Color clr2 = Color.FromArgb(64, clr);
1314 Pen pen = new Pen(clr2, 1.0f);
1315 Brush br2 = new SolidBrush(clr2);
1316 Brush brBright = new SolidBrush(clr);
1317 m_rgStyle.Add(clr, new Tuple<Brush, Brush, Pen, Brush>(br1, br2, pen, brBright));
1318 }
1319
1320 Brush brBack = m_rgStyle[clr].Item1;
1321 Brush brFront = m_rgStyle[clr].Item2;
1322 Brush brTop = m_rgStyle[clr].Item4;
1323 Pen penLine = m_rgStyle[clr].Item3;
1324
1325 if (fMin != 0 || fMax != 0)
1326 {
1327 str = "Action " + nAction.ToString() + " (" + (fMax - fMin).ToString("N7") + ")";
1328 }
1329 else
1330 {
1331 str = "Action " + nAction.ToString() + " - No Probabilities";
1332 }
1333
1334 SizeF sz = g.MeasureString(str, m_font);
1335
1336 int nY1 = (int)(nY + (nHt - sz.Height));
1337 int nX1 = (int)(nX + (nWid / 2) - (sz.Width / 2));
1338 g.DrawString(str, m_font, (bMax) ? brTop : brFront, new Point(nX1, nY1));
1339
1340 if (fMin != 0 || fMax != 0)
1341 {
1342 float fX = nX;
1343 float fWid = nWid / (float)rgProb.Count;
1344 nHt -= (int)sz.Height;
1345
1346 for (int i = 0; i < rgProb.Count; i++)
1347 {
1348 float fProb = rgProb[i];
1349
1350 if (bNormalize)
1351 fProb = (fProb - fMin) / (fMax - fMin);
1352
1353 float fHt = nHt * fProb;
1354 float fHt1 = nHt - fHt;
1355 RectangleF rc1 = new RectangleF(fX, nY + fHt1, fWid, fHt);
1356 g.FillRectangle(brBack, rc1);
1357 g.DrawRectangle(penLine, rc1.X, rc1.Y, rc1.Width, rc1.Height);
1358 fX += fWid;
1359 }
1360 }
1361 }
1362 }
1363
1364 class MemoryEpisodeCollection
1365 {
1366 int m_nTotalCount = 0;
1367 int m_nMax;
1368 List<MemoryCollection> m_rgItems = new List<MemoryCollection>();
1369
1370 public enum ITEM
1371 {
1372 DATA0,
1373 DATA1,
1374 CLIP0,
1375 CLIP1
1376 }
1377
1378 public MemoryEpisodeCollection(int nMax)
1379 {
1380 m_nMax = nMax;
1381 }
1382
1383 public int Count
1384 {
1385 get { return m_rgItems.Count; }
1386 }
1387
1388 public void Clear()
1389 {
1390 m_nTotalCount = 0;
1391 m_rgItems.Clear();
1392 }
1393
1394 public void Add(MemoryItem item)
1395 {
1396 m_nTotalCount++;
1397
1398 if (m_rgItems.Count == 0 || m_rgItems[m_rgItems.Count - 1].Episode != item.Episode)
1399 {
1400 MemoryCollection col = new MemoryCollection(int.MaxValue);
1401 col.Add(item);
1402 m_rgItems.Add(col);
1403 }
1404 else
1405 {
1406 m_rgItems[m_rgItems.Count - 1].Add(item);
1407 }
1408
1409 if (m_nTotalCount > m_nMax)
1410 {
1411 List<MemoryCollection> rgItems = m_rgItems.OrderBy(p => p.TotalReward).ToList();
1412 m_nTotalCount -= rgItems[0].Count;
1413 m_rgItems.Remove(rgItems[0]);
1414 }
1415 }
1416
1417 public MemoryCollection GetRandomSamples(CryptoRandom random, int nCount)
1418 {
1419 MemoryCollection col = new MemoryCollection(nCount);
1420 List<string> rgItems = new List<string>();
1421
1422 for (int i = 0; i < nCount; i++)
1423 {
1424 int nEpisode = random.Next(m_rgItems.Count);
1425 int nItem = random.Next(m_rgItems[nEpisode].Count);
1426 string strItem = nEpisode.ToString() + "_" + nItem.ToString();
1427
1428 if (!rgItems.Contains(strItem))
1429 {
1430 col.Add(m_rgItems[nEpisode][nItem]);
1431 rgItems.Add(strItem);
1432 }
1433 }
1434
1435 return col;
1436 }
1437
1438 List<StateBase> GetState1()
1439 {
1440 List<StateBase> rgItems = new List<StateBase>();
1441
1442 for (int i = 0; i < m_rgItems.Count; i++)
1443 {
1444 for (int j = 0; j < m_rgItems[i].Count; j++)
1445 {
1446 rgItems.Add(m_rgItems[i][j].State1);
1447 }
1448 }
1449
1450 return rgItems;
1451 }
1452
1453 List<SimpleDatum> GetItem(ITEM item)
1454 {
1455 List<SimpleDatum> rgItems = new List<SimpleDatum>();
1456
1457 for (int i = 0; i < m_rgItems.Count; i++)
1458 {
1459 switch (item)
1460 {
1461 case ITEM.DATA0:
1462 rgItems.AddRange(m_rgItems[i].GetData0());
1463 break;
1464
1465 case ITEM.DATA1:
1466 rgItems.AddRange(m_rgItems[i].GetData1());
1467 break;
1468
1469 case ITEM.CLIP0:
1470 rgItems.AddRange(m_rgItems[i].GetClip0());
1471 break;
1472
1473 case ITEM.CLIP1:
1474 rgItems.AddRange(m_rgItems[i].GetClip1());
1475 break;
1476 }
1477 }
1478
1479 return rgItems;
1480 }
1481 }
1482
1483 class MemoryCollection : IEnumerable<MemoryItem>
1484 {
1485 double m_dfTotalReward = 0;
1486 int m_nEpisode;
1487 int m_nMax;
1488 List<MemoryItem> m_rgItems = new List<MemoryItem>();
1489
1490 public MemoryCollection(int nMax)
1491 {
1492 m_nMax = nMax;
1493 }
1494
1495 public int Count
1496 {
1497 get { return m_rgItems.Count; }
1498 }
1499
1500 public MemoryItem this[int nIdx]
1501 {
1502 get { return m_rgItems[nIdx]; }
1503 }
1504
1505 public void Add(MemoryItem item)
1506 {
1507 m_nEpisode = item.Episode;
1508 m_dfTotalReward += item.Reward;
1509
1510 m_rgItems.Add(item);
1511
1512 if (m_rgItems.Count > m_nMax)
1513 m_rgItems.RemoveAt(0);
1514 }
1515
1516 public void Clear()
1517 {
1518 m_nEpisode = 0;
1519 m_dfTotalReward = 0;
1520 m_rgItems.Clear();
1521 }
1522
1523 public int Episode
1524 {
1525 get { return m_nEpisode; }
1526 }
1527
1528 public double TotalReward
1529 {
1530 get { return m_dfTotalReward; }
1531 }
1532
1533 public MemoryCollection GetRandomSamples(CryptoRandom random, int nCount)
1534 {
1535 MemoryCollection col = new MemoryCollection(m_nMax);
1536 List<int> rgIdx = new List<int>();
1537
1538 while (col.Count < nCount)
1539 {
1540 int nIdx = random.Next(m_rgItems.Count);
1541 if (!rgIdx.Contains(nIdx))
1542 {
1543 col.Add(m_rgItems[nIdx]);
1544 rgIdx.Add(nIdx);
1545 }
1546 }
1547
1548 return col;
1549 }
1550
1551 public List<StateBase> GetState1()
1552 {
1553 return m_rgItems.Select(p => p.State1).ToList();
1554 }
1555
1556 public List<SimpleDatum> GetData1()
1557 {
1558 return m_rgItems.Select(p => p.Data1).ToList();
1559 }
1560
1561 public List<SimpleDatum> GetClip1()
1562 {
1563 if (m_rgItems[0].State1.Clip != null)
1564 return m_rgItems.Select(p => p.State1.Clip).ToList();
1565
1566 return null;
1567 }
1568
1569 public List<SimpleDatum> GetData0()
1570 {
1571 return m_rgItems.Select(p => p.Data0).ToList();
1572 }
1573
1574 public List<SimpleDatum> GetClip0()
1575 {
1576 if (m_rgItems[0].State0.Clip != null)
1577 return m_rgItems.Select(p => p.State0.Clip).ToList();
1578
1579 return null;
1580 }
1581
1582 public IEnumerator<MemoryItem> GetEnumerator()
1583 {
1584 return m_rgItems.GetEnumerator();
1585 }
1586
1587 IEnumerator IEnumerable.GetEnumerator()
1588 {
1589 return m_rgItems.GetEnumerator();
1590 }
1591
1592 public override string ToString()
1593 {
1594 return "Episode #" + m_nEpisode.ToString() + " (" + m_rgItems.Count.ToString() + ") => " + m_dfTotalReward.ToString();
1595 }
1596 }
1597
1598 class MemoryItem
1599 {
1600 StateBase m_state0;
1601 StateBase m_state1;
1602 SimpleDatum m_x0;
1603 SimpleDatum m_x1;
1604 int m_nAction;
1605 int m_nIteration;
1606 int m_nEpisode;
1607 bool m_bTerminated;
1608 double m_dfReward;
1609
1610 public MemoryItem(StateBase s, SimpleDatum x, int nAction, StateBase s_, SimpleDatum x_, double dfReward, bool bTerminated, int nIteration, int nEpisode)
1611 {
1612 m_state0 = s;
1613 m_state1 = s_;
1614 m_x0 = x;
1615 m_x1 = x_;
1616 m_nAction = nAction;
1617 m_bTerminated = bTerminated;
1618 m_dfReward = dfReward;
1619 m_nIteration = nIteration;
1620 m_nEpisode = nEpisode;
1621 }
1622
1623 public bool IsTerminated
1624 {
1625 get { return m_bTerminated; }
1626 }
1627
1628 public double Reward
1629 {
1630 get { return m_dfReward; }
1631 set { m_dfReward = value; }
1632 }
1633
1634 public StateBase State0
1635 {
1636 get { return m_state0; }
1637 }
1638
1639 public StateBase State1
1640 {
1641 get { return m_state1; }
1642 }
1643
1644 public SimpleDatum Data0
1645 {
1646 get { return m_x0; }
1647 }
1648
1649 public SimpleDatum Data1
1650 {
1651 get { return m_x1; }
1652 }
1653
1654 public int Action
1655 {
1656 get { return m_nAction; }
1657 }
1658
1659 public int Iteration
1660 {
1661 get { return m_nIteration; }
1662 }
1663
1664 public int Episode
1665 {
1666 get { return m_nEpisode; }
1667 }
1668
1669 public override string ToString()
1670 {
1671 return "episode = " + m_nEpisode.ToString() + " action = " + m_nAction.ToString() + " reward = " + m_dfReward.ToString("N2");
1672 }
1673
1674 private string tostring(float[] rg)
1675 {
1676 string str = "{";
1677
1678 for (int i = 0; i < rg.Length; i++)
1679 {
1680 str += rg[i].ToString("N5");
1681 str += ",";
1682 }
1683
1684 str = str.TrimEnd(',');
1685 str += "}";
1686
1687 return str;
1688 }
1689 }
1690}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
Solver< T > GetInternalSolver()
Get the internal solver.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
void Reset()
Resets the event clearing any signaled state.
Definition: CancelEvent.cs:279
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
CancelEvent()
The CancelEvent constructor.
Definition: CancelEvent.cs:28
void Set()
Sets the event to the signaled state.
Definition: CancelEvent.cs:270
The ColorMapper maps a value within a number range, to a Color within a color scheme.
Definition: ColorMapper.cs:14
Color GetColor(double dfVal)
Find the color using a binary search algorithm.
Definition: ColorMapper.cs:350
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Definition: CryptoRandom.cs:14
int Next(int nMinVal, int nMaxVal, bool bMaxInclusive=true)
Returns a random int within the range
double NextDouble()
Returns a random double within the range .
Definition: CryptoRandom.cs:83
The Log class provides general output in text form.
Definition: Log.cs:13
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
Log(string strSrc)
The Log constructor.
Definition: Log.cs:33
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Definition: PropertySet.cs:267
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as an double value.
Definition: PropertySet.cs:307
override string ToString()
Returns the string representation of the properties.
Definition: PropertySet.cs:325
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
float GetDataAtF(int nIdx)
Returns the item at a specified index in the float type.
bool Sub(SimpleDatum sd, bool bSetNegativeToZero=false)
Subtract the data of another SimpleDatum from this one, so this = this - sd.
void Zero()
Zero out all data in the datum but keep the size and other settings.
DateTime TimeStamp
Get/set the Timestamp.
object Tag
Specifies user data associated with the SimpleDatum.
Definition: SimpleDatum.cs:901
override string ToString()
Return a string representation of the SimpleDatum.
int Channels
Return the number of channels of the data.
int Index
Returns the index of the SimpleDatum.
The Utility class provides general utility functions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
The BlobCollection contains a list of Blobs.
void Dispose()
Release all resource used by the collection and its Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
void Accumulate(CudaDnn< T > cuda, BlobCollection< T > src, bool bAccumulateDiff)
Accumulate the diffs from one BlobCollection into another.
void SetDiff(double df)
Set all blob diff to the value specified.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
int channels
DEPRECIATED; legacy shape accessor channels: use shape(1) instead.
Definition: Blob.cs:800
void SetData(T[] rgData, int nCount=-1, bool bSetCount=true)
Sets a number of items within the Blob's data.
Definition: Blob.cs:1922
int height
DEPRECIATED; legacy shape accessor height: use shape(2) instead.
Definition: Blob.cs:808
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1487
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
Definition: Blob.cs:1461
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECIATED; use
Definition: Blob.cs:442
int width
DEPRECIATED; legacy shape accessor width: use shape(3) instead.
Definition: Blob.cs:816
List< int > shape()
Returns an array where each element contains the shape of an axis of the Blob.
Definition: Blob.cs:684
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
Definition: Blob.cs:648
long gpu_diff
Returns the diff GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1541
int num
DEPRECIATED; legacy shape accessor num: use shape(0) instead.
Definition: Blob.cs:792
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1479
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
List< Layer< T > > layers
Returns the layers.
Definition: Net.cs:2003
double ForwardFromTo(int nStart=0, int nEnd=int.MaxValue)
The FromTo variant of forward and backward operate on the (topological) ordering by which the net is ...
Definition: Net.cs:1402
void CopyTrainedLayersTo(Net< T > dstNet)
Copies the trained layer of this Net to another Net.
Definition: Net.cs:1714
Layer< T > FindLastLayer(LayerParameter.LayerType type)
Find the last layer with the matching type.
Definition: Net.cs:2806
virtual void Dispose(bool bDisposing)
Releases all resources (GPU and Host) used by the Net.
Definition: Net.cs:184
void ClearParamDiffs()
Zero out the diffs of all net parameters. This should be run before Backward.
Definition: Net.cs:1907
BlobCollection< T > learnable_parameters
Returns the learnable parameters.
Definition: Net.cs:2117
NetParameter net_param
Returns the net parameter.
Definition: Net.cs:1857
Blob< T > blob_by_name(string strName, bool bThrowExceptionOnError=true)
Returns a blob given its name.
Definition: Net.cs:2245
The ResultCollection contains the result of a given CaffeControl::Run.
Applies common transformations to the input data, such as scaling, mirroring, subtracting the image m...
DataTransformer(CudaDnn< T > cuda, Log log, TransformationParameter p, Phase phase, int nC, int nH, int nW, SimpleDatum imgMean=null)
The DataTransformer constructor.
void Update(int nDataSize=0, SimpleDatum imgMean=null)
Resync the transformer with changes in its parameter.
void Transform(List< Datum > rgDatum, Blob< T > blobTransformed, CudaDnn< T > cuda, Log log)
Transforms a list of Datum and places the transformed data into a Blob.
TransformationParameter param
Returns the TransformationParameter used.
void Backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Given the top Blob error gradients, compute the bottom Blob error gradients.
Definition: Layer.cs:815
double Forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Given the bottom (input) Blobs, this function computes the top (output) Blobs and the loss.
Definition: Layer.cs:728
void Dispose()
Releases all GPU and host resources used by the Layer.
Definition: Layer.cs:180
void Setup(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Implements common Layer setup functionality.
Definition: Layer.cs:439
The MemoryLossLayerGetLossArgs class is passed to the OnGetLoss event.
bool EnableLossUpdate
Get/set enabling the loss update within the backpropagation pass.
double Loss
Get/set the externally calculated total loss.
BlobCollection< T > Bottom
Specifies the bottom passed in during the forward pass.
The MemoryLossLayer provides a method of performing a custom loss functionality. Similar to the Memor...
EventHandler< MemoryLossLayerGetLossArgs< T > > OnGetLoss
The OnGetLoss event fires during each forward pass. The value returned is saved, and applied on the b...
The SoftmaxCrossEntropyLossLayer computes the cross-entropy (logisitic) loss and is often used for pr...
The SoftmaxLayer computes the softmax function. This layer is initialized with the MyCaffe....
Definition: SoftmaxLayer.cs:24
override void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Reshape the bottom (input) and top (output) blobs.
Specifies the base parameter for all layers.
SoftmaxParameter softmax_param
Returns the parameter set when initialized with LayerType.SOFTMAX
LayerType
Specifies the layer type.
LossParameter loss_param
Returns the parameter set when initialized with LayerType.LOSS
Stores the parameters used by loss layers.
NormalizationMode
How to normalize the loss for loss layers that aggregate across batches, spatial dimensions,...
NormalizationMode? normalization
Specifies the normalization mode (default = VALID).
int axis
The axis along which to perform the softmax – may be negative to index from the end (e....
double delta
Numerical stability for RMSProp, AdaGrad, AdaDelta, Adam and AdamW solvers (default = 1e-08).
Stores parameters used to apply transformation to the data layer's data.
List< double > mean_value
If specified can be repeated once (would subtract it from all the channels or can be repeated the sam...
double scale
For data pre-processing, we can do simple scaling and subtracting the data mean, if provided....
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
SolverParameter parameter
Returns the SolverParameter used.
Definition: Solver.cs:1221
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
Definition: Solver.cs:818
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
The ConvertOutputArgs is passed to the OnConvertOutput event.
Definition: EventArgs.cs:311
byte[] RawOutput
Specifies the raw output byte stream.
Definition: EventArgs.cs:356
string RawType
Specifies the type of the raw output byte stream.
Definition: EventArgs.cs:348
The GetDataArgs is passed to the OnGetData event to retrieve data.
Definition: EventArgs.cs:402
StateBase State
Specifies the state data of the observations.
Definition: EventArgs.cs:517
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
The OverlayArgs is passed ot the OnOverlay event, optionally fired just before displaying a gym image...
Definition: EventArgs.cs:376
Bitmap DisplayImage
Get/set the display image.
Definition: EventArgs.cs:392
The StateBase is the base class for the state of each observation - this is defined by actual trainer...
Definition: StateBase.cs:16
bool Done
Get/set whether the state is done or not.
Definition: StateBase.cs:72
double Reward
Get/set the reward of the state.
Definition: StateBase.cs:63
SimpleDatum Data
Returns other data associated with the state.
Definition: StateBase.cs:98
int ActionCount
Returns the number of actions.
Definition: StateBase.cs:90
SimpleDatum Clip
Returns the clip data associated with the state.
Definition: StateBase.cs:116
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
The Brain uses the instance of MyCaffe (e.g. the open project) to run new actions and train the netwo...
Definition: TrainerC51.cs:439
Log Log
Returns the output log.
Definition: TrainerC51.cs:646
GetDataArgs getDataArgs(Phase phase, int nAction)
Returns the GetDataArgs used to retrieve new data from the environment implemented by derived parent ...
Definition: TrainerC51.cs:620
int FrameStack
Specifies the number of frames per X value.
Definition: TrainerC51.cs:630
Brain(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
The constructor.
Definition: TrainerC51.cs:491
void Train(int nIteration, MemoryCollection rgSamples, int nActionCount)
Train the model at the current iteration.
Definition: TrainerC51.cs:817
int BatchSize
Returns the batch size defined by the model.
Definition: TrainerC51.cs:638
CancelEvent Cancel
Returns the Cancel event used to cancel all MyCaffe tasks.
Definition: TrainerC51.cs:654
void OnOverlay(OverlayArgs e)
The OnOverlay callback is called just before displaying the gym image, thus allowing for an overlay t...
Definition: TrainerC51.cs:1230
int act(SimpleDatum sd, SimpleDatum sdClip, int nActionCount)
Returns the action from running the model. The action returned is either randomly selected (when usin...
Definition: TrainerC51.cs:769
void UpdateTargetModel()
The UpdateTargetModel transfers the trained layers from the active Net to the target Net.
Definition: TrainerC51.cs:803
SimpleDatum Preprocess(StateBase s, bool bUseRawInput, out bool bDifferent, bool bReset=false)
Preprocesses the data.
Definition: TrainerC51.cs:666
void Dispose()
Release all resources used by the Brain.
Definition: TrainerC51.cs:567
bool GetModelUpdated()
Get whether or not the model has been updated.
Definition: TrainerC51.cs:793
The DqnAgent both builds episodes from the environment and trains on them using the Brain.
Definition: TrainerC51.cs:181
void Dispose()
Release all resources used.
Definition: TrainerC51.cs:244
byte[] Run(int nIterations, out string type)
Run the action on a set number of iterations and return the results with no training.
Definition: TrainerC51.cs:296
DqnAgent(IxTrainerCallback icallback, MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
The constructor.
Definition: TrainerC51.cs:213
void Run(Phase phase, int nN, ITERATOR_TYPE type, TRAIN_STEP step)
The Run method provides the main loop that performs the following steps: 1.) get state 2....
Definition: TrainerC51.cs:364
The TrainerC51 implements the C51-DQN algorithm as described by Bellemare et al., Google Dopamine Rai...
Definition: TrainerC51.cs:32
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
Definition: TrainerC51.cs:138
byte[] Run(int nN, PropertySet runProp, out string type)
Run a set of iterations and return the results.
Definition: TrainerC51.cs:122
TrainerC51(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
The constructor.
Definition: TrainerC51.cs:45
bool Shutdown(int nWait)
Shutdown the trainer.
Definition: TrainerC51.cs:76
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the environment after the delay.
Definition: TrainerC51.cs:106
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
Definition: TrainerC51.cs:164
void Dispose()
Release all resources used.
Definition: TrainerC51.cs:56
bool Initialize()
Initialize the trainer.
Definition: TrainerC51.cs:64
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
Definition: Interfaces.cs:303
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
Definition: Interfaces.cs:348
void OnConvertOutput(ConvertOutputArgs e)
The OnConvertOutput callback fires from within the Run method and is used to convert the network's ou...
The IxTrainerGetDataCallback interface is called right after rendering the output image and just befo...
Definition: Interfaces.cs:335
The IxTrainerRL interface is implemented by each RL Trainer.
Definition: Interfaces.cs:257
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.data namespace contains dataset creators used to create common testing datasets such as M...
Definition: BinaryFile.cs:16
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12