MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
TrainerPG.cs
1using System;
2using System.Collections.Generic;
3using System.Drawing;
4using System.Linq;
5using System.Text;
6using System.Threading;
7using System.Threading.Tasks;
8using MyCaffe.basecode;
9using MyCaffe.common;
10using MyCaffe.fillers;
11using MyCaffe.layers;
12using MyCaffe.param;
13using MyCaffe.solvers;
14
16{
25 public class TrainerPG<T> : IxTrainerRL, IDisposable
26 {
27 IxTrainerCallback m_icallback;
28 CryptoRandom m_random = new CryptoRandom();
29 MyCaffeControl<T> m_mycaffe;
30 PropertySet m_properties;
31
39 public TrainerPG(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
40 {
41 m_icallback = icallback;
42 m_mycaffe = mycaffe;
43 m_properties = properties;
44 m_random = random;
45 }
46
50 public void Dispose()
51 {
52 }
53
58 public bool Initialize()
59 {
60 m_mycaffe.CancelEvent.Reset();
61 m_icallback.OnInitialize(new InitializeArgs(m_mycaffe));
62 return true;
63 }
64
65 private void wait(int nWait)
66 {
67 int nWaitInc = 250;
68 int nTotalWait = 0;
69
70 while (nTotalWait < nWait)
71 {
72 m_icallback.OnWait(new WaitArgs(nWaitInc));
73 nTotalWait += nWaitInc;
74 }
75 }
76
82 public bool Shutdown(int nWait)
83 {
84 if (m_mycaffe != null)
85 {
86 m_mycaffe.CancelEvent.Set();
87 wait(nWait);
88 }
89
90 m_icallback.OnShutdown();
91
92 return true;
93 }
94
100 public ResultCollection RunOne(int nDelay = 1000)
101 {
102 m_mycaffe.CancelEvent.Reset();
103 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);
104 agent.Run(Phase.TEST, 1, ITERATOR_TYPE.ITERATION);
105 agent.Dispose();
106 return null;
107 }
108
116 public byte[] Run(int nN, PropertySet runProp, out string type)
117 {
118 m_mycaffe.CancelEvent.Reset();
119 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.RUN);
120 byte[] rgResults = agent.Run(nN, out type);
121 agent.Dispose();
122
123 return rgResults;
124 }
125
132 public bool Test(int nN, ITERATOR_TYPE type)
133 {
134 int nDelay = 1000;
135 string strProp = m_properties.ToString();
136
137 // Turn off the num-skip to run at normal speed.
138 strProp += "EnableNumSkip=False;";
139 PropertySet properties = new PropertySet(strProp);
140
141 m_mycaffe.CancelEvent.Reset();
142 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, properties, m_random, Phase.TRAIN);
143 agent.Run(Phase.TEST, nN, type);
144
145 agent.Dispose();
146 Shutdown(nDelay);
147
148 return true;
149 }
150
158 public bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
159 {
160 if (step != TRAIN_STEP.NONE)
161 throw new Exception("The simple traininer does not support stepping - use the 'PG.MT' trainer instead.");
162
163 m_mycaffe.CancelEvent.Reset();
164 Agent<T> agent = new Agent<T>(m_icallback, m_mycaffe, m_properties, m_random, Phase.TRAIN);
165 agent.Run(Phase.TRAIN, nN, type);
166 agent.Dispose();
167
168 return false;
169 }
170 }
171
172 class Agent<T> : IDisposable
173 {
174 IxTrainerCallback m_icallback;
175 Brain<T> m_brain;
176 PropertySet m_properties;
177 CryptoRandom m_random;
178 float m_fGamma;
179 bool m_bAllowDiscountReset = false;
180 bool m_bUseRawInput = false;
181
182 public Agent(IxTrainerCallback icallback, MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
183 {
184 m_icallback = icallback;
185 m_brain = new Brain<T>(mycaffe, properties, random, phase);
186 m_properties = properties;
187 m_random = random;
188
189 m_fGamma = (float)properties.GetPropertyAsDouble("Gamma", 0.99);
190 m_bAllowDiscountReset = properties.GetPropertyAsBool("AllowDiscountReset", false);
191 m_bUseRawInput = properties.GetPropertyAsBool("UseRawInput", false);
192 }
193
194 public void Dispose()
195 {
196 if (m_brain != null)
197 {
198 m_brain.Dispose();
199 m_brain = null;
200 }
201 }
202
203 private StateBase getData(Phase phase, int nAction)
204 {
205 GetDataArgs args = m_brain.getDataArgs(phase, nAction);
206 m_icallback.OnGetData(args);
207 return args.State;
208 }
209
210 private void updateStatus(int nIteration, int nEpisodeCount, double dfRewardSum, double dfRunningReward)
211 {
212 GetStatusArgs args = new GetStatusArgs(0, nIteration, nEpisodeCount, 1000000, dfRunningReward, dfRewardSum, 0, 0, 0, 0);
213 m_icallback.OnUpdateStatus(args);
214 }
215
216 public byte[] Run(int nIterations, out string type)
217 {
218 IxTrainerCallbackRNN icallback = m_icallback as IxTrainerCallbackRNN;
219 if (icallback == null)
220 throw new Exception("The Run method requires an IxTrainerCallbackRNN interface to convert the results into the native format!");
221
222 StateBase s = getData(Phase.RUN, -1);
223 int nIteration = 0;
224 List<float> rgResults = new List<float>();
225
226 while (!m_brain.Cancel.WaitOne(0) && (nIterations == -1 || nIteration < nIterations))
227 {
228 // Preprocess the observation.
229 SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);
230
231 // Forward the policy network and sample an action.
232 float fAprob;
233 int action = m_brain.act(x, out fAprob);
234
235 rgResults.Add(s.Data.TimeStamp.ToFileTime());
236 rgResults.Add(s.Data.GetDataAtF(0));
237 rgResults.Add(action);
238
239 // Take the next step using the action
240 StateBase s_ = getData(Phase.RUN, action);
241 nIteration++;
242 }
243
244 ConvertOutputArgs args = new ConvertOutputArgs(nIterations, rgResults.ToArray());
245 icallback.OnConvertOutput(args);
246
247 type = args.RawType;
248 return args.RawOutput;
249 }
250
251 private bool isAtIteration(int nN, ITERATOR_TYPE type, int nIteration, int nEpisode)
252 {
253 if (nN == -1)
254 return false;
255
256 if (type == ITERATOR_TYPE.EPISODE)
257 {
258 if (nEpisode < nN)
259 return false;
260
261 return true;
262 }
263 else
264 {
265 if (nIteration < nN)
266 return false;
267
268 return true;
269 }
270 }
271
282 public void Run(Phase phase, int nN, ITERATOR_TYPE type)
283 {
284 MemoryCollection m_rgMemory = new MemoryCollection();
285 double? dfRunningReward = null;
286 double dfEpisodeReward = 0;
287 int nEpisode = 0;
288 int nIteration = 0;
289
290 StateBase s = getData(phase, -1);
291
292 if (s.Clip != null)
293 throw new Exception("The PG.SIMPLE trainer does not support recurrent layers or clip data, use the 'PG.ST' or 'PG.MT' trainer instead.");
294
295 while (!m_brain.Cancel.WaitOne(0) && !isAtIteration(nN, type, nIteration, nEpisode))
296 {
297 // Preprocess the observation.
298 SimpleDatum x = m_brain.Preprocess(s, m_bUseRawInput);
299
300 // Forward the policy network and sample an action.
301 float fAprob;
302 int action = m_brain.act(x, out fAprob);
303
304 // Take the next step using the action
305 StateBase s_ = getData(phase, action);
306 dfEpisodeReward += s_.Reward;
307
308 if (phase == Phase.TRAIN)
309 {
310 // Build up episode memory, using reward for taking the action.
311 m_rgMemory.Add(new MemoryItem(s, x, action, fAprob, (float)s_.Reward));
312
313 // An episode has finished.
314 if (s_.Done)
315 {
316 nEpisode++;
317 nIteration++;
318
319 m_brain.Reshape(m_rgMemory);
320
321 // Compute the discounted reward (backwards through time)
322 float[] rgDiscountedR = m_rgMemory.GetDiscountedRewards(m_fGamma, m_bAllowDiscountReset);
323 // Rewards are standardized when set to be unit normal (helps control the gradient estimator variance)
324 m_brain.SetDiscountedR(rgDiscountedR);
325
326 // Modulate the gradient with the advantage (PG magic happens right here.)
327 float[] rgDlogp = m_rgMemory.GetPolicyGradients();
328 // discounted R applied to policy gradient within loss function, just before the backward pass.
329 m_brain.SetPolicyGradients(rgDlogp);
330
331 // Train for one iteration, which triggers the loss function.
332 List<Datum> rgData = m_rgMemory.GetData();
333 m_brain.SetData(rgData);
334 m_brain.Train(nIteration);
335
336 // Update reward running
337 if (!dfRunningReward.HasValue)
338 dfRunningReward = dfEpisodeReward;
339 else
340 dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;
341
342 updateStatus(nIteration, nEpisode, dfEpisodeReward, dfRunningReward.Value);
343 dfEpisodeReward = 0;
344
345 s = getData(phase, -1);
346 m_rgMemory.Clear();
347 }
348 else
349 {
350 s = s_;
351 }
352 }
353 else
354 {
355 if (s_.Done)
356 {
357 nEpisode++;
358
359 // Update reward running
360 if (!dfRunningReward.HasValue)
361 dfRunningReward = dfEpisodeReward;
362 else
363 dfRunningReward = dfRunningReward * 0.99 + dfEpisodeReward * 0.01;
364
365 updateStatus(nIteration, nEpisode, dfEpisodeReward, dfRunningReward.Value);
366 dfEpisodeReward = 0;
367
368 s = getData(phase, -1);
369 }
370 else
371 {
372 s = s_;
373 }
374
375 nIteration++;
376 }
377 }
378 }
379 }
380
381 class Brain<T> : IDisposable
382 {
383 MyCaffeControl<T> m_mycaffe;
384 Net<T> m_net;
385 Solver<T> m_solver;
386 MemoryDataLayer<T> m_memData;
387 MemoryLossLayer<T> m_memLoss;
388 PropertySet m_properties;
389 CryptoRandom m_random;
390 Blob<T> m_blobDiscountedR;
391 Blob<T> m_blobPolicyGradient;
392 bool m_bSkipLoss;
393 int m_nMiniBatch = 10;
394 SimpleDatum m_sdLast = null;
395
396 public Brain(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, Phase phase)
397 {
398 m_mycaffe = mycaffe;
399 m_net = mycaffe.GetInternalNet(phase);
400 m_solver = mycaffe.GetInternalSolver();
401 m_properties = properties;
402 m_random = random;
403
404 m_memData = m_net.FindLayer(LayerParameter.LayerType.MEMORYDATA, null) as MemoryDataLayer<T>;
405 m_memLoss = m_net.FindLayer(LayerParameter.LayerType.MEMORY_LOSS, null) as MemoryLossLayer<T>;
406 SoftmaxLayer<T> softmax = m_net.FindLayer(LayerParameter.LayerType.SOFTMAX, null) as SoftmaxLayer<T>;
407
408 if (softmax != null)
409 throw new Exception("The PG.SIMPLE trainer does not support the Softmax layer, use the 'PG.ST' or 'PG.MT' trainer instead.");
410
411 if (m_memData == null)
412 throw new Exception("Could not find the MemoryData Layer!");
413
414 if (m_memLoss == null)
415 throw new Exception("Could not find the MemoryLoss Layer!");
416
417 m_memLoss.OnGetLoss += memLoss_OnGetLoss;
418
419 m_blobDiscountedR = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
420 m_blobPolicyGradient = new Blob<T>(mycaffe.Cuda, mycaffe.Log);
421
422 int nMiniBatch = mycaffe.CurrentProject.GetBatchSize(phase);
423 if (nMiniBatch != 0)
424 m_nMiniBatch = nMiniBatch;
425
426 m_nMiniBatch = m_properties.GetPropertyAsInt("MiniBatch", m_nMiniBatch);
427 }
428
429 private void dispose(ref Blob<T> b)
430 {
431 if (b != null)
432 {
433 b.Dispose();
434 b = null;
435 }
436 }
437
438 public void Dispose()
439 {
440 m_memLoss.OnGetLoss -= memLoss_OnGetLoss;
441 dispose(ref m_blobDiscountedR);
442 dispose(ref m_blobPolicyGradient);
443 }
444
445 public void Reshape(MemoryCollection col)
446 {
447 int nNum = col.Count;
448 int nChannels = col[0].Data.Channels;
449 int nHeight = col[0].Data.Height;
450 int nWidth = col[0].Data.Height;
451
452 m_blobDiscountedR.Reshape(nNum, 1, 1, 1);
453 m_blobPolicyGradient.Reshape(nNum, 1, 1, 1);
454 }
455
456 public void SetDiscountedR(float[] rg)
457 {
458 double dfMean = m_blobDiscountedR.mean(rg);
459 double dfStd = m_blobDiscountedR.std(dfMean, rg);
460 m_blobDiscountedR.SetData(Utility.ConvertVec<T>(rg));
461 m_blobDiscountedR.NormalizeData(dfMean, dfStd);
462 }
463
464 public void SetPolicyGradients(float[] rg)
465 {
466 m_blobPolicyGradient.SetData(Utility.ConvertVec<T>(rg));
467 }
468
469 public void SetData(List<Datum> rgData)
470 {
471 m_memData.AddDatumVector(rgData, null, 1, true, true);
472 }
473
474 public GetDataArgs getDataArgs(Phase phase, int nAction)
475 {
476 bool bReset = (nAction == -1) ? true : false;
477 return new GetDataArgs(phase, 0, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, bReset, nAction, false);
478 }
479
480 public Log Log
481 {
482 get { return m_mycaffe.Log; }
483 }
484
485 public CancelEvent Cancel
486 {
487 get { return m_mycaffe.CancelEvent; }
488 }
489
490 public SimpleDatum Preprocess(StateBase s, bool bUseRawInput)
491 {
492 SimpleDatum sd = new SimpleDatum(s.Data, true);
493
494 if (bUseRawInput)
495 return sd;
496
497 if (m_sdLast == null)
498 sd.Zero();
499 else
500 sd.Sub(m_sdLast);
501
502 m_sdLast = s.Data;
503
504 return sd;
505 }
506
507 public int act(SimpleDatum sd, out float fAprob)
508 {
509 List<Datum> rgData = new List<Datum>();
510 rgData.Add(new Datum(sd));
511 double dfLoss;
512
513 m_memData.AddDatumVector(rgData, null, 1, true, true);
514 m_bSkipLoss = true;
515 BlobCollection<T> res = m_net.Forward(out dfLoss);
516 m_bSkipLoss = false;
517 float[] rgfAprob = null;
518
519 for (int i = 0; i < res.Count; i++)
520 {
521 if (res[i].type != BLOB_TYPE.LOSS)
522 {
523 rgfAprob = Utility.ConvertVecF<T>(res[i].update_cpu_data());
524 break;
525 }
526 }
527
528 if (rgfAprob == null)
529 throw new Exception("Could not find a non-loss output! Your model should output the loss and the action probabilities.");
530
531 if (rgfAprob.Length != 1)
532 throw new Exception("The simple policy gradient only supports a single data output!");
533
534 fAprob = rgfAprob[0];
535
536 // Roll the dice!
537 if (m_random.NextDouble() < (double)fAprob)
538 return 0;
539 else
540 return 1;
541 }
542
543 public void Train(int nIteration)
544 {
545 m_mycaffe.Log.Enable = false;
546 m_solver.Step(1, TRAIN_STEP.NONE, false, false, true, true); // accumulate grad over batch
547
548 if (nIteration % m_nMiniBatch == 0)
549 {
550 m_solver.ApplyUpdate(nIteration);
551 m_net.ClearParamDiffs();
552 }
553
554 m_mycaffe.Log.Enable = true;
555 }
556
557 private void memLoss_OnGetLoss(object sender, MemoryLossLayerGetLossArgs<T> e)
558 {
559 if (m_bSkipLoss)
560 return;
561
562 int nCount = m_blobPolicyGradient.count();
563 long hPolicyGrad = m_blobPolicyGradient.mutable_gpu_data;
564 long hBottomDiff = e.Bottom[0].mutable_gpu_diff;
565 long hDiscountedR = m_blobDiscountedR.gpu_data;
566
567 // Calculate the actual loss.
568 double dfSumSq = Utility.ConvertVal<T>(m_blobPolicyGradient.sumsq_data());
569 double dfMean = dfSumSq;
570
571 e.Loss = dfMean;
572 e.EnableLossUpdate = false; // apply gradients to bottom directly.
573
574 // Modulate the gradient with the advantage (PG magic happens right here.)
575 m_mycaffe.Cuda.mul(nCount, hPolicyGrad, hDiscountedR, hPolicyGrad);
576 m_mycaffe.Cuda.copy(nCount, hPolicyGrad, hBottomDiff);
577 m_mycaffe.Cuda.mul_scalar(nCount, -1.0, hBottomDiff);
578 }
579 }
580
581 class MemoryCollection : GenericList<MemoryItem>
582 {
583 public MemoryCollection()
584 {
585 }
586
587 public float[] GetDiscountedRewards(float fGamma, bool bAllowReset)
588 {
589 float fRunningAdd = 0;
590 float[] rgR = m_rgItems.Select(p => p.Reward).ToArray();
591 float[] rgDiscountedR = new float[rgR.Length];
592
593 for (int t = Count - 1; t >= 0; t--)
594 {
595 if (bAllowReset && rgR[t] != 0)
596 fRunningAdd = 0;
597
598 fRunningAdd = fRunningAdd * fGamma + rgR[t];
599 rgDiscountedR[t] = fRunningAdd;
600 }
601
602 return rgDiscountedR;
603 }
604
605 public float[] GetPolicyGradients()
606 {
607 return m_rgItems.Select(p => p.dlogps).ToArray();
608 }
609
610 public List<Datum> GetData()
611 {
612 List<Datum> rgData = new List<Datum>();
613
614 for (int i = 0; i < m_rgItems.Count; i++)
615 {
616 rgData.Add(new Datum(m_rgItems[i].Data));
617 }
618
619 return rgData;
620 }
621
622 public List<Datum> GetClip()
623 {
624 return null;
625 }
626 }
627
628 class MemoryItem
629 {
630 StateBase m_state;
631 SimpleDatum m_x;
632 int m_nAction;
633 float m_fAprob;
634 float m_fReward;
635
636 public MemoryItem(StateBase s, SimpleDatum x, int nAction, float fAprob, float fReward)
637 {
638 m_state = s;
639 m_x = x;
640 m_nAction = nAction;
641 m_fAprob = fAprob;
642 m_fReward = fReward;
643 }
644
645 public StateBase State
646 {
647 get { return m_state; }
648 }
649
650 public SimpleDatum Data
651 {
652 get { return m_x; }
653 }
654
655 public int Action
656 {
657 get { return m_nAction; }
658 }
659
660 public float Reward
661 {
662 get { return m_fReward; }
663 }
664
671 public float dlogps
672 {
673 get
674 {
675 float fY = 0;
676
677 if (m_nAction == 0)
678 fY = 1;
679
680 return fY - m_fAprob;
681 }
682 }
683
684 public override string ToString()
685 {
686 return "action = " + m_nAction.ToString() + " reward = " + m_fReward.ToString("N2") + " aprob = " + m_fAprob.ToString("N5") + " dlogps = " + dlogps.ToString("N5");
687 }
688 }
689}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
void Reset()
Resets the event clearing any signaled state.
Definition: CancelEvent.cs:279
CancelEvent()
The CancelEvent constructor.
Definition: CancelEvent.cs:28
void Set()
Sets the event to the signaled state.
Definition: CancelEvent.cs:270
The CryptoRandom is a random number generator that can use either the standard .Net Random objec or t...
Definition: CryptoRandom.cs:14
double NextDouble()
Returns a random double within the range .
Definition: CryptoRandom.cs:83
The Datum class is a simple wrapper to the SimpleDatum class to ensure compatibility with the origina...
Definition: Datum.cs:12
The GenericList provides a base used to implement a generic list by only implementing the minimum amo...
Definition: GenericList.cs:15
List< T > m_rgItems
The actual list of items.
Definition: GenericList.cs:19
The Log class provides general output in text form.
Definition: Log.cs:13
Log(string strSrc)
The Log constructor.
Definition: Log.cs:33
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Definition: PropertySet.cs:267
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as an double value.
Definition: PropertySet.cs:307
override string ToString()
Returns the string representation of the properties.
Definition: PropertySet.cs:325
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
bool Sub(SimpleDatum sd, bool bSetNegativeToZero=false)
Subtract the data of another SimpleDatum from this one, so this = this - sd.
void Zero()
Zero out all data in the datum but keep the size and other settings.
SimpleDatum Add(SimpleDatum d)
Creates a new SimpleDatum and adds another SimpleDatum to it.
override string ToString()
Return a string representation of the SimpleDatum.
The Utility class provides general utility functions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
void SetData(T[] rgData, int nCount=-1, bool bSetCount=true)
Sets a number of items within the Blob's data.
Definition: Blob.cs:1922
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1487
void Reshape(int nNum, int nChannels, int nHeight, int nWidth, bool? bUseHalfSize=null)
DEPRECIATED; use
Definition: Blob.cs:442
double std(double? dfMean=null, float[] rgDf=null)
Calculate the standard deviation of the blob data.
Definition: Blob.cs:3007
double mean(float[] rgDf=null, bool bDiff=false)
Calculate the mean of the blob data.
Definition: Blob.cs:2965
T sumsq_data()
Calculate the sum of squares (L2 norm squared) of the data.
Definition: Blob.cs:1730
void NormalizeData(double? dfMean=null, double? dfStd=null)
Normalize the blob data by subtracting the mean and dividing by the standard deviation.
Definition: Blob.cs:2942
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1479
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1445
Layer< T > FindLayer(LayerParameter.LayerType? type, string strName)
Find the layer with the matching type, name and or both.
Definition: Net.cs:2748
void ClearParamDiffs()
Zero out the diffs of all net parameters. This should be run before Backward.
Definition: Net.cs:1907
The ResultCollection contains the result of a given CaffeControl::Run.
The MemoryDataLayer provides data to the Net from memory. This layer is initialized with the MyCaffe....
virtual void AddDatumVector(Datum[] rgData, Datum[] rgClip=null, int nLblAxis=1, bool bReset=false, bool bResizeBatch=false)
This method is used to add a list of Datums to the memory.
The MemoryLossLayerGetLossArgs class is passed to the OnGetLoss event.
bool EnableLossUpdate
Get/set enabling the loss update within the backpropagation pass.
double Loss
Get/set the externally calculated total loss.
BlobCollection< T > Bottom
Specifies the bottom passed in during the forward pass.
The MemoryLossLayer provides a method of performing a custom loss functionality. Similar to the Memor...
EventHandler< MemoryLossLayerGetLossArgs< T > > OnGetLoss
The OnGetLoss event fires during each forward pass. The value returned is saved, and applied on the b...
The SoftmaxLayer computes the softmax function. This layer is initialized with the MyCaffe....
Definition: SoftmaxLayer.cs:24
Specifies the base parameter for all layers.
LayerType
Specifies the layer type.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
Definition: Solver.cs:818
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
The TrainerPG implements a simple Policy Gradient trainer inspired by Andrej Karpathy's blog posed re...
Definition: TrainerPG.cs:26
byte[] Run(int nN, PropertySet runProp, out string type)
Run a set of iterations and return the resuts.
Definition: TrainerPG.cs:116
bool Initialize()
Initialize the trainer.
Definition: TrainerPG.cs:58
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - currently this is not implemented.
Definition: TrainerPG.cs:132
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
Definition: TrainerPG.cs:158
TrainerPG(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback)
The constructor.
Definition: TrainerPG.cs:39
bool Shutdown(int nWait)
Shutdown the trainer.
Definition: TrainerPG.cs:82
void Dispose()
Releases all resources used.
Definition: TrainerPG.cs:50
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the environment after the delay.
Definition: TrainerPG.cs:100
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
Definition: Interfaces.cs:303
The IxTrainerRL interface is implemented by each RL Trainer.
Definition: Interfaces.cs:257
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closesly tracks the C++ Caffe open-...
Definition: Annotation.cs:12