MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MyCaffeTrainerRL.cs
1using System;
2using System.Collections.Generic;
4using System.Diagnostics;
5using System.Linq;
6using System.Text;
7using System.Threading;
8using System.Threading.Tasks;
9using MyCaffe.basecode;
11using MyCaffe.common;
12using MyCaffe.gym;
13
14namespace MyCaffe.trainers
15{
48 {
56 protected PropertySet m_properties = null;
60 protected int m_nProjectID = 0;
64 protected ConnectInfo m_dsCi = null;
65 IxTrainerRL m_itrainer = null;
66 double m_dfExplorationRate = 0;
67 double m_dfOptimalSelectionRate = 0;
68 double m_dfGlobalRewards = 0;
69 double m_dfGlobalRewardsAve = 0;
70 double m_dfGlobalRewardsMax = -double.MaxValue;
71 int m_nGlobalEpisodeCount = 0;
72 int m_nGlobalEpisodeMax = 0;
73 double m_dfLoss = 0;
74 int m_nThreads = 1;
75 REWARD_TYPE m_rewardType = REWARD_TYPE.MAXIMUM;
76 TRAINER_TYPE m_trainerType = TRAINER_TYPE.PG_ST;
77 int m_nItertions = -1;
78 IXMyCaffeCustomTrainerCallback m_icallback = null;
79 int m_nSnapshot = 0;
80 bool m_bSnapshot = false;
81 object m_syncObj = new object();
82
/// <summary>
/// Identifies which policy-gradient trainer implementation is created for this component.
/// </summary>
enum TRAINER_TYPE
{
    /// <summary>Multi-threaded trainer, selected by the 'PG' or 'PG.MT' TrainerType property.</summary>
    PG_MT = 0,

    /// <summary>Single-threaded trainer, selected by the 'PG.ST' TrainerType property.</summary>
    PG_ST = 1,

    /// <summary>Bare-bones trainer, selected by the 'PG.SIMPLE' TrainerType property.</summary>
    PG_SIMPLE = 2
}
89
/// <summary>
/// Selects which reward statistic the GlobalRewards property reports.
/// </summary>
enum REWARD_TYPE
{
    /// <summary>Report the most recently received total reward.</summary>
    VALUE = 0,

    /// <summary>Report the running average of the total rewards.</summary>
    AVERAGE = 1,

    /// <summary>Report the largest total reward received so far.</summary>
    MAXIMUM = 2
}
96
101 {
102 InitializeComponent();
103 }
104
/// <summary>
/// The constructor.
/// </summary>
/// <param name="container">Specifies the container to which this component is added.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="container"/> is null.</exception>
public MyCaffeTrainerRL(IContainer container)
{
    // Fail with a descriptive exception instead of a NullReferenceException on container.Add.
    if (container == null)
        throw new ArgumentNullException("container");

    container.Add(this);

    InitializeComponent();
}
115
116 #region Overrides
117
/// <summary>
/// Overridden to give the actual name of the custom trainer.
/// </summary>
protected virtual string name
{
    get
    {
        const string strName = "MyCaffe RL Trainer";
        return strName;
    }
}
125
/// <summary>
/// Override when using a training method other than the REINFORCEMENT method (the default).
/// </summary>
protected virtual TRAINING_CATEGORY category
{
    get
    {
        TRAINING_CATEGORY cat = TRAINING_CATEGORY.REINFORCEMENT;
        return cat;
    }
}
133
/// <summary>
/// Returns a dataset override to use (if any) instead of the project's dataset.
/// </summary>
/// <param name="nProjectID">Specifies the project ID (unused by the base implementation).</param>
/// <param name="ci">Optionally, specifies the dataset connection info (unused by the base implementation).</param>
/// <returns>The base implementation always returns null.</returns>
protected virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci = null)
{
    // The base trainer supplies no override.
    DatasetDescriptor dsOverride = null;
    return dsOverride;
}
144
/// <summary>
/// Returns information describing the specific trainer, such as the gym used, if any.
/// </summary>
/// <returns>The base implementation returns an empty string.</returns>
protected virtual string get_information()
{
    return string.Empty;
}
153
162 protected virtual IxTrainerRL create_trainerD(Component caffe)
163 {
166 m_dsCi = mycaffe.DatasetConnectInfo;
167
168 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("max_iter"), out m_nItertions);
169 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("snapshot"), out m_nSnapshot);
170
171 switch (m_trainerType)
172 {
173 case TRAINER_TYPE.PG_SIMPLE:
174 return new pg.simple.TrainerPG<double>(mycaffe, m_properties, m_random, this);
175
176 case TRAINER_TYPE.PG_ST:
177 return new pg.st.TrainerPG<double>(mycaffe, m_properties, m_random, this);
178
179 case TRAINER_TYPE.PG_MT:
180 return new pg.mt.TrainerPG<double>(mycaffe, m_properties, m_random, this);
181
182 default:
183 throw new Exception("Unknown trainer type '" + m_trainerType.ToString() + "'!");
184 }
185 }
186
195 protected virtual IxTrainerRL create_trainerF(Component caffe)
196 {
199 m_dsCi = mycaffe.DatasetConnectInfo;
200
201 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("max_iter"), out m_nItertions);
202 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("snapshot"), out m_nSnapshot);
203
204 switch (m_trainerType)
205 {
206 case TRAINER_TYPE.PG_SIMPLE:
207 return new pg.simple.TrainerPG<float>(mycaffe, m_properties, m_random, this);
208
209 case TRAINER_TYPE.PG_ST:
210 return new pg.st.TrainerPG<float>(mycaffe, m_properties, m_random, this);
211
212 case TRAINER_TYPE.PG_MT:
213 return new pg.mt.TrainerPG<float>(mycaffe, m_properties, m_random, this);
214
215 default:
216 throw new Exception("Unknown trainer type '" + m_trainerType.ToString() + "'!");
217 }
218 }
219
/// <summary>
/// Override to dispose of resources used.
/// </summary>
protected virtual void dispose()
{
    // Intentionally empty in the base class.
}
226
/// <summary>
/// Override called by the Initialize method of the trainer.
/// </summary>
/// <param name="e">Specifies the initialization arguments (unused by the base implementation).</param>
protected virtual void initialize(InitializeArgs e)
{
    // Intentionally empty in the base class.
}
237
/// <summary>
/// Override called from within the CleanUp method and from the OnShutdown callback.
/// </summary>
protected virtual void shutdown()
{
    // Intentionally empty in the base class.
}
244
/// <summary>
/// Override called by the OnGetData callback to retrieve new observation data.
/// </summary>
/// <param name="e">Specifies the get-data arguments (unused by the base implementation).</param>
/// <returns>The base implementation always returns false.</returns>
protected virtual bool getData(GetDataArgs e)
{
    bool bResult = false;
    return bResult;
}
254
/// <summary>
/// Returns true when the training is ready for a snap-shot, false otherwise.
/// </summary>
/// <param name="nIteration">Receives the current global episode count.</param>
/// <param name="dfAccuracy">Receives the current global rewards.</param>
/// <returns>True when a snapshot has been flagged since the last call; the flag is cleared when it is reported.</returns>
protected virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
{
    nIteration = GlobalEpisodeCount;
    dfAccuracy = GlobalRewards;

    bool bSnapshotDue = m_bSnapshot;

    // Only write the flag when it was set, so an unset flag is never touched here.
    if (bSnapshotDue)
        m_bSnapshot = false;

    return bSnapshotDue;
}
273
/// <summary>
/// Called by OpenUi, override this when a UI (via WCF) should be displayed.
/// </summary>
protected virtual void openUi()
{
    // Intentionally empty in the base class.
}
280
281 #endregion
282
283 #region IXMyCaffeCustomTrainer Interface
284
289 {
290 get { return Stage.RL; }
291 }
292
/// <summary>
/// Returns the name of the custom trainer. This property calls the 'name' override.
/// </summary>
public string Name
{
    get
    {
        string strName = name;
        return strName;
    }
}
300
305 {
306 get { return category; }
307 }
308
/// <summary>
/// Returns true when the training is ready for a snap-shot, false otherwise.
/// </summary>
/// <param name="nIteration">Receives the iteration value supplied by the get_update_snapshot override.</param>
/// <param name="dfAccuracy">Receives the accuracy value supplied by the get_update_snapshot override.</param>
/// <returns>The value returned by the get_update_snapshot override.</returns>
public bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
{
    bool bSnapshot = get_update_snapshot(out nIteration, out dfAccuracy);
    return bSnapshot;
}
318
/// <summary>
/// Returns a dataset override to use (if any) instead of the project's dataset.
/// </summary>
/// <param name="nProjectID">Specifies the project ID forwarded to the get_dataset_override override.</param>
/// <param name="ci">Optionally, specifies the dataset connection info forwarded to the override.</param>
/// <returns>The value returned by the get_dataset_override override.</returns>
public DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci = null)
{
    DatasetDescriptor dsOverride = get_dataset_override(nProjectID, ci);
    return dsOverride;
}
329
334 {
335 get { return true; }
336 }
337
342 {
343 get { return true; }
344 }
345
350 {
351 get { return true; }
352 }
353
/// <summary>
/// Releases any resources used by the component: shuts down the active trainer (if any)
/// and then calls the shutdown override.
/// </summary>
public void CleanUp()
{
    const int nWait = 3000;
    cleanup(nWait, bCallShutdown: true);
}
361
/// <summary>
/// Shuts down and releases the active trainer, if there is one, while holding the sync lock.
/// </summary>
/// <param name="nWait">Specifies the wait value passed to the trainer's Shutdown method.</param>
/// <param name="bCallShutdown">When true, the shutdown override is called after the trainer is released.</param>
private void cleanup(int nWait, bool bCallShutdown = false)
{
    lock (m_syncObj)
    {
        IxTrainerRL itrainer = m_itrainer;

        if (itrainer != null)
        {
            itrainer.Shutdown(nWait);
            m_itrainer = null;
        }

        if (!bCallShutdown)
            return;

        shutdown();
    }
}
376
/// <summary>
/// Initializes a new custom trainer by loading the key-value pair of properties into the property set.
/// </summary>
/// <remarks>
/// Properties read: 'Threads' (default 1), 'RewardType' (optional; VAL/VALUE, AVE/AVERAGE, anything
/// else selects the maximum reward) and 'TrainerType' (PG.SIMPLE, PG.ST, PG or PG.MT).
/// </remarks>
/// <param name="strProperties">Specifies the key-value pair of properties.</param>
/// <param name="icallback">Specifies the parent callback, stored for later status updates.</param>
/// <exception cref="Exception">Thrown when the 'TrainerType' property holds an unknown value.</exception>
public void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
{
    m_icallback = icallback;
    m_properties = new PropertySet(strProperties);

    // The thread count is used as a divisor when averaging rewards, so keep it at 1 or more.
    m_nThreads = Math.Max(1, m_properties.GetPropertyAsInt("Threads", 1));

    string strRewardType = m_properties.GetProperty("RewardType", false);
    strRewardType = (strRewardType == null) ? "VAL" : strRewardType.ToUpperInvariant();

    // Always assign the reward type so that a repeated Initialize call never keeps a stale setting.
    switch (strRewardType)
    {
        case "VAL":
        case "VALUE":
            m_rewardType = REWARD_TYPE.VALUE;
            break;

        case "AVE":
        case "AVERAGE":
            m_rewardType = REWARD_TYPE.AVERAGE;
            break;

        default:
            // Matches the field's initial value, which was the effective result for these inputs.
            m_rewardType = REWARD_TYPE.MAXIMUM;
            break;
    }

    string strTrainerType = m_properties.GetProperty("TrainerType");

    switch (strTrainerType)
    {
        case "PG.SIMPLE":   // bare bones model (Sigmoid only)
            m_trainerType = TRAINER_TYPE.PG_SIMPLE;
            break;

        case "PG.ST":       // single thread (Sigmoid and Softmax)
            m_trainerType = TRAINER_TYPE.PG_ST;
            break;

        case "PG":
        case "PG.MT":       // multi-thread (Sigmoid and Softmax)
            m_trainerType = TRAINER_TYPE.PG_MT;
            break;

        default:
            throw new Exception("Unknown trainer type '" + strTrainerType + "'!");
    }
}
421
/// <summary>
/// Creates and initializes the trainer matching the base type of the MyCaffe component.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component; a MyCaffeControl&lt;double&gt; uses
/// create_trainerD, any other component uses create_trainerF.</param>
/// <returns>The newly created trainer, after its Initialize method has been called.</returns>
private IxTrainerRL createTrainer(Component mycaffe)
{
    IxTrainerRL itrainer = (mycaffe is MyCaffeControl<double>)
        ? create_trainerD(mycaffe)
        : create_trainerF(mycaffe);

    itrainer.Initialize();

    return itrainer;
}
435
/// <summary>
/// Create a new trainer (when none is active) and use it to run a single run cycle.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component used to create the trainer.</param>
/// <param name="nDelay">Specifies the delay value passed to the trainer's RunOne method (default = 1000).</param>
/// <returns>The results returned by the trainer.</returns>
public ResultCollection RunOne(Component mycaffe, int nDelay = 1000)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe);

    try
    {
        return m_itrainer.RunOne(nDelay);
    }
    finally
    {
        // Release the trainer even when the run throws, so no live trainer is left behind.
        cleanup(50);
    }
}
452
460 public byte[] Run(Component mycaffe, int nN, out string type)
461 {
462 if (m_itrainer == null)
463 m_itrainer = createTrainer(mycaffe);
464
465 PropertySet runProp = null;
467 if (icallback != null)
468 runProp = icallback.GetRunProperties();
469
470 byte[] rgResults = m_itrainer.Run(nN, runProp, out type);
471 cleanup(0);
472
473 return rgResults;
474 }
475
/// <summary>
/// Create a new trainer (when none is active) and use it to run a test cycle.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component used to create the trainer.</param>
/// <param name="nIterationOverride">Specifies the iteration count, or -1 to use the iteration
/// count stored by this component.</param>
/// <param name="type">Specifies the iterator type passed to the trainer.</param>
public void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type = ITERATOR_TYPE.ITERATION)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe);

    if (nIterationOverride == -1)
        nIterationOverride = m_nItertions;

    try
    {
        m_itrainer.Test(nIterationOverride, type);
    }
    finally
    {
        // Release the trainer even when testing throws, so no live trainer is left behind.
        cleanup(500);
    }
}
493
/// <summary>
/// Create a new trainer (when none is active) and use it to run a training cycle.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component used to create the trainer.</param>
/// <param name="nIterationOverride">Specifies the iteration count, or -1 to use the iteration
/// count stored by this component.</param>
/// <param name="type">Specifies the iterator type passed to the trainer.</param>
/// <param name="step">Specifies the training step method passed to the trainer.</param>
public void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type = ITERATOR_TYPE.ITERATION, TRAIN_STEP step = TRAIN_STEP.NONE)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe);

    if (nIterationOverride == -1)
        nIterationOverride = m_nItertions;

    try
    {
        m_itrainer.Train(nIterationOverride, type, step);
    }
    finally
    {
        // Release the trainer even when training throws, so no live trainer is left behind.
        cleanup(1000);
    }
}
512
513 #endregion
514
519 {
520 initialize(e);
521 }
522
/// <summary>
/// The OnShutdown callback fires when shutting down the trainer; it forwards to the shutdown override.
/// </summary>
public void OnShutdown()
{
    shutdown();
}
530
/// <summary>
/// The OnGetData callback fires from within the Train method and is used to get new observation data.
/// </summary>
/// <param name="e">Specifies the get-data arguments forwarded to the getData override.</param>
public void OnGetData(GetDataArgs e)
{
    // The boolean result of the override is not used here.
    getData(e);
}
538
543 {
544 m_dfGlobalRewards = e.TotalReward;
545 m_dfGlobalRewardsMax = Math.Max(m_dfGlobalRewardsMax, e.TotalReward);
546 m_dfGlobalRewardsAve = (1.0 / (double)m_nThreads) * e.TotalReward + ((m_nThreads - 1) / (double)m_nThreads) * m_dfGlobalRewardsAve;
547 m_dfExplorationRate = e.ExplorationRate;
548 m_dfOptimalSelectionRate = e.OptimalSelectionCoefficient;
549
550 if (m_nThreads > 1)
551 m_nGlobalEpisodeCount++;
552 else
553 m_nGlobalEpisodeCount = e.Frames;
554
555 m_nGlobalEpisodeMax = e.MaxFrames;
556 m_dfLoss = e.Loss;
557
558 if (m_icallback != null)
559 {
560 Dictionary<string, double> rgValues = new Dictionary<string, double>();
561 rgValues.Add("GlobalIteration", GlobalEpisodeCount);
562 rgValues.Add("GlobalLoss", GlobalLoss);
563 rgValues.Add("LearningRate", e.LearningRate);
564 rgValues.Add("GlobalAccuracy", GlobalRewards);
565 rgValues.Add("Threads", m_nThreads);
566 m_icallback.Update(TrainingCategory, rgValues);
567 }
568
569 e.NewFrameCount = m_nGlobalEpisodeCount;
570
571 if (e.Index == 0 && m_nSnapshot > 0 && m_nGlobalEpisodeCount > 0 && (m_nGlobalEpisodeCount % m_nSnapshot) == 0)
572 m_bSnapshot = true;
573 }
574
/// <summary>
/// The OnWait callback fires when waiting for a shutdown; it blocks the calling thread.
/// </summary>
/// <param name="e">Specifies the wait arguments; e.Wait is the amount of time to wait in milliseconds.</param>
public void OnWait(WaitArgs e)
{
    int nWaitMs = e.Wait;
    Thread.Sleep(nWaitMs);
}
582
/// <summary>
/// Return a property value from the trainer.
/// </summary>
/// <param name="strProp">Specifies the property name: 'GlobalLoss', 'GlobalRewards',
/// 'GlobalEpisodeCount' or 'ExplorationRate' (case-sensitive).</param>
/// <returns>The current value of the requested property.</returns>
/// <exception cref="Exception">Thrown when the property name is not supported.</exception>
public double GetProperty(string strProp)
{
    switch (strProp)
    {
        case "GlobalLoss":
            return GlobalLoss;

        case "GlobalRewards":
            return GlobalRewards;

        case "GlobalEpisodeCount":
            return GlobalEpisodeCount;

        case "ExplorationRate":
            return ExplorationRate;

        default:
            // The message previously named MyCaffeTrainerRNN, which is a different trainer.
            throw new Exception("The property '" + strProp + "' is not supported by the MyCaffeTrainerRL.");
    }
}
608
/// <summary>
/// Returns the global rewards based on the reward type specified by the 'RewardType' property.
/// </summary>
/// <remarks>
/// For the maximum reward type, 0 is returned while the maximum still holds its initial
/// sentinel value of -double.MaxValue.
/// </remarks>
public double GlobalRewards
{
    get
    {
        if (m_rewardType == REWARD_TYPE.VALUE)
            return m_dfGlobalRewards;

        if (m_rewardType == REWARD_TYPE.AVERAGE)
            return m_dfGlobalRewardsAve;

        if (m_dfGlobalRewardsMax == -double.MaxValue)
            return 0;

        return m_dfGlobalRewardsMax;
    }
}
635
/// <summary>
/// Return the global loss.
/// </summary>
public double GlobalLoss
{
    get
    {
        double dfLoss = m_dfLoss;
        return dfLoss;
    }
}
643
648 {
649 get { return m_nGlobalEpisodeCount; }
650 }
651
656 {
657 get { return m_nGlobalEpisodeMax; }
658 }
659
/// <summary>
/// Returns the current exploration rate.
/// </summary>
public double ExplorationRate
{
    get
    {
        double dfRate = m_dfExplorationRate;
        return dfRate;
    }
}
667
672 {
673 get { return m_dfOptimalSelectionRate; }
674 }
675
/// <summary>
/// Returns information describing the trainer, as supplied by the get_information override.
/// </summary>
public string Information
{
    get
    {
        string strInfo = get_information();
        return strInfo;
    }
}
683
/// <summary>
/// Open the user interface for the trainer, if one exists; forwards to the openUi override.
/// </summary>
public void OpenUi()
{
    openUi();
}
691 }
692}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
ProjectEx CurrentProject
Returns the name of the currently loaded project.
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
Definition: ConnectInfo.cs:14
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Definition: CryptoRandom.cs:14
string GetSolverSetting(string strParam)
Get a setting from the solver descriptor.
Definition: ProjectEx.cs:453
int OriginalID
Get/set the original project ID.
Definition: ProjectEx.cs:541
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
The ResultCollection contains the result of a given CaffeControl::Run.
The GetDataArgs is passed to the OnGetData event to retrieve data.
Definition: EventArgs.cs:402
The GetStatusArgs is passed to the OnGetStatus event.
Definition: EventArgs.cs:166
double Loss
Returns the loss value.
Definition: EventArgs.cs:262
double OptimalSelectionCoefficient
Returns the optimal selection coefficient.
Definition: EventArgs.cs:302
int MaxFrames
Returns the maximum frame count.
Definition: EventArgs.cs:246
int Frames
Returns the total frame count across all agents.
Definition: EventArgs.cs:238
int NewFrameCount
Get/set the new frame count.
Definition: EventArgs.cs:229
double ExplorationRate
Returns the current exploration rate.
Definition: EventArgs.cs:294
double TotalReward
Returns the total rewards.
Definition: EventArgs.cs:278
int Index
Returns the index of the caller.
Definition: EventArgs.cs:213
double LearningRate
Returns the current learning rate.
Definition: EventArgs.cs:270
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
(Deprecated - use MyCaffeTrainerDual instead.) The MyCaffeTrainerRL is used to perform reinforceme...
void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION)
Create a new trainer and use it to run a test cycle.
void OpenUi()
Open the user interface for the trainer, if one exists.
double OptimalSelectionRate
Returns the rate of selection from the optimal set with the highest reward (this setting is optional,...
void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION, TRAIN_STEP step=TRAIN_STEP.NONE)
Create a new trainer and use it to run a training cycle.
void OnShutdown()
The OnShutdown callback fires when shutting down the trainer.
void OnUpdateStatus(GetStatusArgs e)
The OnGetStatus callback fires on each iteration within the Train method.
double GetProperty(string strProp)
Return a property value from the trainer.
virtual void shutdown()
Override called from within the CleanUp method.
DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
ResultCollection RunOne(Component mycaffe, int nDelay=1000)
Create a new trainer and use it to run a single run cycle.
double GlobalLoss
Return the global loss.
PropertySet m_properties
Specifies the properties parsed from the key-value pair passed to the Initialize method.
double ExplorationRate
Returns the current exploration rate.
void CleanUp()
Releases any resources used by the component.
ConnectInfo m_dsCi
Optionally, specifies the dataset connection info, or null.
bool IsRunningSupported
Returns whether or not Running is supported.
virtual IxTrainerRL create_trainerF(Component caffe)
Optionally overridden to return a new type of trainer.
virtual TRAINING_CATEGORY category
Override when using a training method other than the REINFORCEMENT method (the default).
void OnWait(WaitArgs e)
The OnWait callback fires when waiting for a shutdown.
void OnGetData(GetDataArgs e)
The OnGetData callback fires from within the Train method and is used to get a new observation data.
string Information
Returns information describing the trainer.
virtual IxTrainerRL create_trainerD(Component caffe)
Optionally overridden to return a new type of trainer.
bool IsTrainingSupported
Returns whether or not Training is supported.
virtual string name
Overridden to give the actual name of the custom trainer.
double? GlobalRewards
Returns the global rewards based on the reward type specified by the 'RewardType' property.
virtual bool getData(GetDataArgs e)
Override called by the OnGetData event fired by the Trainer to retrieve a new set of observation coll...
virtual void initialize(InitializeArgs e)
Override called by the Initialize method of the trainer.
virtual void dispose()
Override to dispose of resources used.
int GlobalEpisodeCount
Returns the global episode count.
void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
Initializes a new custom trainer by loading the key-value pair of properties into the property set.
TRAINING_CATEGORY TrainingCategory
Returns the training category of the custom trainer (default = REINFORCEMENT).
CryptoRandom m_random
Random number generator used to get initial actions, etc.
int m_nProjectID
Specifies the project ID of the project held by the instance of MyCaffe.
void OnInitialize(InitializeArgs e)
The OnInitialize callback fires when initializing the trainer.
bool IsTestingSupported
Returns whether or not Testing is supported.
virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
MyCaffeTrainerRL(IContainer container)
The constructor.
bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
virtual void openUi()
Called by OpenUi, override this when a UI (via WCF) should be displayed.
string Name
Returns the name of the custom trainer. This method calls the 'name' override.
byte[] Run(Component mycaffe, int nN, out string type)
Run the network using the run technique implemented by this trainer.
virtual string get_information()
Returns information describing the specific trainer, such as the gym used, if any.
int GlobalEpisodeMax
Returns the maximum global episode count.
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
int Wait
Returns the amount of time to wait in milliseconds.
Definition: EventArgs.cs:81
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
Definition: Component.cs:18
The IXMyCaffeCustomTrainerCallback interface is used to call back to the parent running the custom tr...
Definition: Interfaces.cs:199
void Update(TRAINING_CATEGORY cat, Dictionary< string, double > rgValues)
The Update method updates the parent with the global iteration, reward and loss.
The IXMyCaffeCustomTrainerCallbackRNN interface is used to call back to the parent running the custom...
Definition: Interfaces.cs:212
PropertySet GetRunProperties()
The GetRunProperties method is used to query the properties used when Running, if any.
The IXMyCaffeCustomTrainer interface is used by the MyCaffeCustomTrainer components that provide va...
Definition: Interfaces.cs:135
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
Definition: Interfaces.cs:303
bool Initialize()
Initialize the trainer.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network.
bool Test(int nN, ITERATOR_TYPE type)
Test the network.
bool Shutdown(int nWait)
Shutdown the trainer.
The IxTrainerRL interface is implemented by each RL Trainer.
Definition: Interfaces.cs:257
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the trainer.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a number of 'nN' samples on the trainer.
The descriptors namespace contains all descriptor used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
TRAINING_CATEGORY
Defines the category of training.
Definition: Interfaces.cs:34
Stage
Specifies the stage under which to run a custom trainer.
Definition: Interfaces.cs:88
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.gym namespace contains all classes related to the Gym's supported by MyCaffe.
The MyCaffe.trainers namespace contains all reinforcement and recurrent learning trainers.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12