MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MyCaffeTrainerDual.cs
1using System;
2using System.Collections.Generic;
4using System.Diagnostics;
5using System.Linq;
6using System.Text;
7using System.Threading;
8using System.Threading.Tasks;
9using MyCaffe.basecode;
11using MyCaffe.common;
12using MyCaffe.gym;
13using MyCaffe.param;
14
15namespace MyCaffe.trainers
16{
58 {
// Properties parsed from the key-value string handed to Initialize.
66 protected PropertySet m_properties = null;
// NOTE(review): presumably the ID of the project loaded in MyCaffe - confirm where it is assigned.
70 protected int m_nProjectID = 0;
// Dataset connection info copied from the MyCaffe control by the create_trainer methods (may be null).
74 protected ConnectInfo m_dsCi = null;
// Active trainer; created on demand by Test/Train/Run and released by cleanup().
75 IxTrainer m_itrainer = null;
// Latest values cached by the status-update callback.
76 double m_dfExplorationRate = 0;
77 double m_dfOptimalSelectionRate = 0;
78 double m_dfImmediateRewards = 0;
79 double m_dfGlobalRewards = 0;
80 double m_dfGlobalRewardsAve = 0;
// -double.MaxValue acts as the 'no reward seen yet' sentinel; GlobalRewards reports 0 while it is unchanged.
81 double m_dfGlobalRewardsMax = -double.MaxValue;
82 int m_nGlobalEpisodeCount = 0;
83 int m_nGlobalEpisodeMax = 0;
84 double m_dfLoss = 0;
// Read from the 'Threads' property in Initialize (default 1).
85 int m_nThreads = 1;
// Selects which reward statistic GlobalRewards returns; set from the 'RewardType' property.
86 REWARD_TYPE m_rewardType = REWARD_TYPE.MAXIMUM;
// Selects the concrete trainer built by the create_trainer methods; set from the 'TrainerType' property.
87 TRAINER_TYPE m_trainerType = TRAINER_TYPE.PG_ST;
88 double m_dfAccuracy = 0;
89 int m_nIteration = 0;
// Default iteration count used when the caller passes -1; parsed from the 'max_iter' solver setting
// and later overwritten with the MaxFrames value reported by the status-update callback.
90 int m_nIterations = -1;
// Parent callback supplied to Initialize (may be null).
91 IXMyCaffeCustomTrainerCallback m_icallback = null;
// Snapshot interval parsed from the 'snapshot' solver setting (0 disables snapshot requests).
92 int m_nSnapshot = 0;
// Pending snapshot request; raised by the status-update callback and consumed by get_update_snapshot.
93 bool m_bSnapshot = false;
// Vocabulary stored by ResizeModel and handed to the RNN trainers.
94 BucketCollection m_rgVocabulary = null;
// Stage chosen in Initialize from the 'TrainerType' property.
95 Stage m_stage = Stage.RL;
// Written to the 'UsePreLoadData' property by the create_trainer methods.
96 bool m_bUsePreloadData = false;
// Lock object guarding trainer shutdown in cleanup().
97 object m_syncObj = new object();
98
// Identifies the concrete trainer implementation to construct.
// Initialize maps the 'TrainerType' property string onto these values.
99 enum TRAINER_TYPE
100 {
// 'PG' or 'PG.MT'
101 PG_MT,
// 'PG.ST'
102 PG_ST,
// 'PG.SIMPLE'
103 PG_SIMPLE,
// 'C51.ST'
104 C51_ST,
// NOTE(review): no 'TrainerType' string in Initialize selects this value and neither
// create_trainer switch handles it - confirm whether it is still needed.
105 C51b_ST,
// 'DQN.ST'
106 DQN_ST,
// 'DQN.SIMPLE'
107 DQN_SIMPLE,
// 'RNN.SIMPLE'
108 RNN_SIMPLE,
// 'RNN.SUPER.SIMPLE'
109 RNN_SUPER_SIMPLE
110 }
111
// Selects which cached reward statistic the GlobalRewards property returns.
112 enum REWARD_TYPE
113 {
// Most recent total reward ('VAL' / 'VALUE').
114 VALUE,
// Running average of the total reward ('AVE' / 'AVERAGE').
115 AVERAGE,
// Largest total reward seen so far.
116 MAXIMUM
117 }
118
123 {
124 InitializeComponent();
125 }
126
/// <summary>
/// Constructs the trainer and adds it to the supplied container.
/// </summary>
/// <param name="container">Specifies the container that receives this component.</param>
public MyCaffeTrainerDual(IContainer container)
{
    // Register with the container first, then run the component initialization.
    container.Add(this);
    InitializeComponent();
}
137
138 #region Overrides
139
/// <summary>
/// Gives the display name of this custom trainer; exposed publicly through <c>Name</c>.
/// </summary>
protected virtual string name
{
    get
    {
        return "MyCaffe RL/RNN Dual Trainer";
    }
}
147
/// <summary>
/// Gives the training category of this trainer, which is <c>TRAINING_CATEGORY.DUAL</c> unless overridden.
/// </summary>
protected virtual TRAINING_CATEGORY category
{
    get
    {
        return TRAINING_CATEGORY.DUAL;
    }
}
155
/// <summary>
/// Override hook for supplying a dataset override; this base implementation supplies none.
/// </summary>
/// <param name="nProjectID">Specifies the project ID (unused here).</param>
/// <param name="ci">Optionally, specifies the dataset connection info (unused here).</param>
/// <returns>Always <c>null</c> in this base implementation.</returns>
protected virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci = null)
{
    DatasetDescriptor dsOverride = null;
    return dsOverride;
}
166
/// <summary>
/// Override hook for describing the trainer; exposed publicly through <c>Information</c>.
/// </summary>
/// <returns>An empty string in this base implementation.</returns>
protected virtual string get_information()
{
    string strInfo = "";
    return strInfo;
}
175
/// <summary>
/// Creates the trainer that operates on a MyCaffeControl&lt;double&gt;.
/// </summary>
/// <remarks>
/// As a side effect this caches the dataset connection info, reads the 'max_iter' and 'snapshot'
/// solver settings and writes the 'UsePreLoadData' property before the trainer is constructed.
/// </remarks>
/// <param name="caffe">Specifies the MyCaffeControl&lt;double&gt; instance the trainer will drive.</param>
/// <param name="stage">Specifies the stage; <c>Stage.RNN</c> selects among the RNN trainers, any other value among the RL trainers.</param>
/// <returns>The newly constructed trainer (not yet initialized).</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="caffe"/> is not a MyCaffeControl&lt;double&gt;.</exception>
/// <exception cref="Exception">Thrown when the configured trainer type is not supported in the requested stage.</exception>
protected virtual IxTrainer create_trainerD(Component caffe, Stage stage)
{
    // The rest of the method works on the typed control; fail with a clear message instead of a
    // NullReferenceException when the wrong component type is passed in.
    MyCaffeControl<double> mycaffe = caffe as MyCaffeControl<double>;
    if (mycaffe == null)
        throw new ArgumentException("The 'caffe' component must be a MyCaffeControl<double>.", "caffe");

    // NOTE(review): project ID taken from the loaded project - confirm against the original source.
    m_nProjectID = mycaffe.CurrentProject.OriginalID;
    m_dsCi = mycaffe.DatasetConnectInfo;

    // int.TryParse writes 0 when the setting is missing or not a valid integer.
    int.TryParse(mycaffe.CurrentProject.GetSolverSetting("max_iter"), out m_nIterations);
    int.TryParse(mycaffe.CurrentProject.GetSolverSetting("snapshot"), out m_nSnapshot);

    m_properties.SetProperty("UsePreLoadData", m_bUsePreloadData.ToString());

    if (stage == Stage.RNN)
    {
        switch (m_trainerType)
        {
            case TRAINER_TYPE.RNN_SUPER_SIMPLE:
                return new rnn.simple.TrainerRNNSimple<double>(mycaffe, m_properties, m_random, this, m_rgVocabulary);

            case TRAINER_TYPE.RNN_SIMPLE:
                return new rnn.simple.TrainerRNN<double>(mycaffe, m_properties, m_random, this, m_rgVocabulary);

            default:
                throw new Exception("The trainer type '" + m_trainerType.ToString() + "' is not supported in the RNN stage!");
        }
    }

    switch (m_trainerType)
    {
        case TRAINER_TYPE.PG_SIMPLE:
            return new pg.simple.TrainerPG<double>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.PG_ST:
            return new pg.st.TrainerPG<double>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.PG_MT:
            return new pg.mt.TrainerPG<double>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.C51_ST:
            return new dqn.c51.st.TrainerC51<double>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.DQN_ST:
            return new dqn.noisy.st.TrainerNoisyDqn<double>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.DQN_SIMPLE:
            return new dqn.noisy.simple.TrainerNoisyDqn<double>(mycaffe, m_properties, m_random, this);

        default:
            throw new Exception("The trainer type '" + m_trainerType.ToString() + "' is not supported in the RL stage!");
    }
}
237
/// <summary>
/// Creates the trainer that operates on a MyCaffeControl&lt;float&gt;.
/// </summary>
/// <remarks>
/// As a side effect this caches the dataset connection info, reads the 'max_iter' and 'snapshot'
/// solver settings and writes the 'UsePreLoadData' property before the trainer is constructed.
/// </remarks>
/// <param name="caffe">Specifies the MyCaffeControl&lt;float&gt; instance the trainer will drive.</param>
/// <param name="stage">Specifies the stage; <c>Stage.RNN</c> selects among the RNN trainers, any other value among the RL trainers.</param>
/// <returns>The newly constructed trainer (not yet initialized).</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="caffe"/> is not a MyCaffeControl&lt;float&gt;.</exception>
/// <exception cref="Exception">Thrown when the configured trainer type is not supported in the requested stage.</exception>
protected virtual IxTrainer create_trainerF(Component caffe, Stage stage)
{
    // The rest of the method works on the typed control; fail with a clear message instead of a
    // NullReferenceException when the wrong component type is passed in.
    MyCaffeControl<float> mycaffe = caffe as MyCaffeControl<float>;
    if (mycaffe == null)
        throw new ArgumentException("The 'caffe' component must be a MyCaffeControl<float>.", "caffe");

    // NOTE(review): project ID taken from the loaded project - confirm against the original source.
    m_nProjectID = mycaffe.CurrentProject.OriginalID;
    m_dsCi = mycaffe.DatasetConnectInfo;

    // int.TryParse writes 0 when the setting is missing or not a valid integer.
    int.TryParse(mycaffe.CurrentProject.GetSolverSetting("max_iter"), out m_nIterations);
    int.TryParse(mycaffe.CurrentProject.GetSolverSetting("snapshot"), out m_nSnapshot);

    m_properties.SetProperty("UsePreLoadData", m_bUsePreloadData.ToString());

    if (stage == Stage.RNN)
    {
        switch (m_trainerType)
        {
            case TRAINER_TYPE.RNN_SUPER_SIMPLE:
                return new rnn.simple.TrainerRNNSimple<float>(mycaffe, m_properties, m_random, this, m_rgVocabulary);

            case TRAINER_TYPE.RNN_SIMPLE:
                return new rnn.simple.TrainerRNN<float>(mycaffe, m_properties, m_random, this, m_rgVocabulary);

            default:
                throw new Exception("The trainer type '" + m_trainerType.ToString() + "' is not supported in the RNN stage!");
        }
    }

    switch (m_trainerType)
    {
        case TRAINER_TYPE.PG_SIMPLE:
            return new pg.simple.TrainerPG<float>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.PG_ST:
            return new pg.st.TrainerPG<float>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.PG_MT:
            return new pg.mt.TrainerPG<float>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.C51_ST:
            return new dqn.c51.st.TrainerC51<float>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.DQN_ST:
            return new dqn.noisy.st.TrainerNoisyDqn<float>(mycaffe, m_properties, m_random, this);

        case TRAINER_TYPE.DQN_SIMPLE:
            return new dqn.noisy.simple.TrainerNoisyDqn<float>(mycaffe, m_properties, m_random, this);

        default:
            throw new Exception("The trainer type '" + m_trainerType.ToString() + "' is not supported in the RL stage!");
    }
}
299
/// <summary>
/// Override hook for releasing resources; this base implementation does nothing.
/// </summary>
protected virtual void dispose()
{
    // Intentionally empty.
}
306
/// <summary>
/// Override hook for trainer initialization; this base implementation does nothing.
/// </summary>
/// <param name="e">Specifies the initialization arguments (unused here).</param>
protected virtual void initialize(InitializeArgs e)
{
    // Intentionally empty.
}
317
/// <summary>
/// Override hook invoked by <c>OnShutdown</c> and, when requested, by <c>cleanup</c>;
/// this base implementation does nothing.
/// </summary>
protected virtual void shutdown()
{
    // Intentionally empty.
}
324
/// <summary>
/// Override hook invoked by <c>OnGetData</c> to supply data to the trainer.
/// </summary>
/// <param name="e">Specifies the get-data arguments (unused here).</param>
/// <returns>Always <c>false</c> in this base implementation.</returns>
protected virtual bool getData(GetDataArgs e)
{
    bool bHandled = false;
    return bHandled;
}
334
/// <summary>
/// Override hook for converting the network output.
/// </summary>
/// <param name="e">Specifies the convert-output arguments (unused here).</param>
/// <returns>Always <c>false</c> in this base implementation.</returns>
protected virtual bool convertOutput(ConvertOutputArgs e)
{
    bool bHandled = false;
    return bHandled;
}
344
350 {
351 }
352
/// <summary>
/// Reports the current episode count and reward, and whether a snapshot has been requested.
/// </summary>
/// <param name="nIteration">Receives the global episode count.</param>
/// <param name="dfAccuracy">Receives the global rewards value.</param>
/// <returns><c>true</c> when a snapshot request was pending (the request is cleared), <c>false</c> otherwise.</returns>
protected virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
{
    nIteration = GlobalEpisodeCount;
    dfAccuracy = GlobalRewards;

    if (!m_bSnapshot)
        return false;

    // Consume the pending request so that it is reported only once.
    m_bSnapshot = false;
    return true;
}
371
/// <summary>
/// Override hook invoked by <c>OpenUi</c>; this base implementation does nothing.
/// </summary>
protected virtual void openUi()
{
    // Intentionally empty.
}
378
/// <summary>
/// Override hook invoked by <c>IXMyCaffeCustomTrainerRNN.PreloadData</c> to pre-load data.
/// </summary>
/// <param name="log">Specifies the output log (unused here).</param>
/// <param name="evtCancel">Specifies the cancel event (unused here).</param>
/// <param name="nProjectID">Specifies the project ID (unused here).</param>
/// <param name="propertyOverride">Optionally, specifies property overrides (unused here).</param>
/// <param name="ci">Optionally, specifies the dataset connection info (unused here).</param>
/// <returns>Always <c>null</c> in this base implementation.</returns>
protected virtual BucketCollection preloaddata(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride = null, ConnectInfo ci = null)
{
    BucketCollection rgVocabulary = null;
    return rgVocabulary;
}
392
393 #endregion
394
395 #region IXMyCaffeCustomTrainer Interface
396
401 {
402 get { return m_stage; }
403 }
404
/// <summary>
/// Returns the name of the custom trainer by forwarding to the <c>name</c> override.
/// </summary>
public string Name
{
    get
    {
        return name;
    }
}
412
417 {
418 get { return category; }
419 }
420
/// <summary>
/// Public entry point that forwards to the <c>get_update_snapshot</c> override.
/// </summary>
/// <param name="nIteration">Receives the iteration reported by the override.</param>
/// <param name="dfAccuracy">Receives the accuracy reported by the override.</param>
/// <returns>The value returned by <c>get_update_snapshot</c>.</returns>
public bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
{
    bool bSnapshot = get_update_snapshot(out nIteration, out dfAccuracy);
    return bSnapshot;
}
430
/// <summary>
/// Public entry point that forwards to the <c>get_dataset_override</c> override.
/// </summary>
/// <param name="nProjectID">Specifies the project ID.</param>
/// <param name="ci">Optionally, specifies the dataset connection info.</param>
/// <returns>The value returned by <c>get_dataset_override</c>.</returns>
public DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci = null)
{
    DatasetDescriptor dsOverride = get_dataset_override(nProjectID, ci);
    return dsOverride;
}
441
446 {
447 get { return true; }
448 }
449
454 {
455 get { return true; }
456 }
457
462 {
463 get { return true; }
464 }
465
/// <summary>
/// Shuts down the active trainer (waiting up to 3000) and then invokes the <c>shutdown</c> override.
/// </summary>
public void CleanUp()
{
    const int nWait = 3000;
    cleanup(nWait, true);
}
473
/// <summary>
/// Shuts down and releases the active trainer, if any, while holding <c>m_syncObj</c>.
/// </summary>
/// <param name="nWait">Specifies the wait value passed to the trainer's Shutdown.</param>
/// <param name="bCallShutdown">When <c>true</c>, the <c>shutdown</c> override is invoked afterwards (still inside the lock).</param>
private void cleanup(int nWait, bool bCallShutdown)
{
    lock (m_syncObj)
    {
        IxTrainer itrainer = m_itrainer;

        if (itrainer != null)
        {
            itrainer.Shutdown(nWait);
            m_itrainer = null;
        }

        if (!bCallShutdown)
            return;

        shutdown();
    }
}
488
/// <summary>
/// Initializes the custom trainer from a key-value property string.
/// </summary>
/// <remarks>
/// Properties read: 'Threads' (default 1), 'RewardType' (optional) and 'TrainerType' (required).
/// 'RewardType' is matched case-insensitively: missing, 'VAL' or 'VALUE' selects VALUE; 'AVE' or
/// 'AVERAGE' selects AVERAGE; any other text selects MAXIMUM. 'TrainerType' is matched exactly and
/// also determines the stage (RL or RNN).
/// </remarks>
/// <param name="strProperties">Specifies the key-value properties.</param>
/// <param name="icallback">Specifies the parent callback (may be null).</param>
/// <exception cref="Exception">Thrown when the 'TrainerType' value is not recognized.</exception>
public void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
{
    m_icallback = icallback;
    m_properties = new PropertySet(strProperties);
    m_nThreads = m_properties.GetPropertyAsInt("Threads", 1);

    // Invariant upper-casing keeps the keyword match independent of the current culture.
    string strRewardType = m_properties.GetProperty("RewardType", false);
    strRewardType = (strRewardType == null) ? "VAL" : strRewardType.ToUpperInvariant();

    switch (strRewardType)
    {
        case "VAL":
        case "VALUE":
            m_rewardType = REWARD_TYPE.VALUE;
            break;

        case "AVE":
        case "AVERAGE":
            m_rewardType = REWARD_TYPE.AVERAGE;
            break;

        default:
            // Assign explicitly so a value left over from an earlier Initialize call cannot
            // survive; MAXIMUM is the field's initial value, so a first call behaves as before.
            m_rewardType = REWARD_TYPE.MAXIMUM;
            break;
    }

    string strTrainerType = m_properties.GetProperty("TrainerType");

    switch (strTrainerType)
    {
        case "PG.SIMPLE":   // bare bones model (Sigmoid only)
            m_trainerType = TRAINER_TYPE.PG_SIMPLE;
            m_stage = Stage.RL;
            break;

        case "PG.ST":       // single thread (Sigmoid and Softmax)
            m_trainerType = TRAINER_TYPE.PG_ST;
            m_stage = Stage.RL;
            break;

        case "PG":
        case "PG.MT":       // multi-thread (Sigmoid and Softmax)
            m_trainerType = TRAINER_TYPE.PG_MT;
            m_stage = Stage.RL;
            break;

        case "C51.ST":      // single threaded C51
            m_trainerType = TRAINER_TYPE.C51_ST;
            m_stage = Stage.RL;
            break;

        case "DQN.ST":      // single threaded Noisy DQN
            m_trainerType = TRAINER_TYPE.DQN_ST;
            m_stage = Stage.RL;
            break;

        case "DQN.SIMPLE":  // single threaded Noisy DQN
            m_trainerType = TRAINER_TYPE.DQN_SIMPLE;
            m_stage = Stage.RL;
            break;

        case "RNN.SIMPLE":
            m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
            m_stage = Stage.RNN;
            break;

        case "RNN.SUPER.SIMPLE":
            m_trainerType = TRAINER_TYPE.RNN_SUPER_SIMPLE;
            m_stage = Stage.RNN;
            break;

        default:
            throw new Exception("Unknown trainer type '" + strTrainerType + "'!");
    }
}
561
/// <summary>
/// Derives the stage from the configured trainer type.
/// </summary>
/// <returns><c>Stage.RNN</c> for the two RNN trainer types, <c>Stage.RL</c> for every other type.</returns>
private Stage getStage()
{
    switch (m_trainerType)
    {
        case TRAINER_TYPE.RNN_SIMPLE:
        case TRAINER_TYPE.RNN_SUPER_SIMPLE:
            return Stage.RNN;

        default:
            return Stage.RL;
    }
}
569
/// <summary>
/// Builds and initializes a trainer for the given MyCaffe component.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component; a MyCaffeControl&lt;double&gt; is routed to
/// <c>create_trainerD</c>, anything else to <c>create_trainerF</c>.</param>
/// <param name="stage">Specifies the stage passed on to the factory method.</param>
/// <returns>The trainer after its <c>Initialize</c> method has been called.</returns>
private IxTrainer createTrainer(Component mycaffe, Stage stage)
{
    bool bDouble = mycaffe is MyCaffeControl<double>;

    IxTrainer itrainer = bDouble
        ? create_trainerD(mycaffe, stage)
        : create_trainerF(mycaffe, stage);

    itrainer.Initialize();

    return itrainer;
}
583
/// <summary>
/// Runs a test cycle, creating the trainer first when none is active. The stage is derived from
/// the configured trainer type.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component to test.</param>
/// <param name="nIterationOverride">Specifies the iteration count; -1 selects the cached <c>m_nIterations</c> value.</param>
/// <param name="type">Specifies the iterator type passed to the trainer.</param>
public void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type = ITERATOR_TYPE.ITERATION)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe, getStage());

    try
    {
        // Resolved after trainer creation because creating the trainer refreshes m_nIterations.
        if (nIterationOverride == -1)
            nIterationOverride = m_nIterations;

        m_itrainer.Test(nIterationOverride, type);
    }
    finally
    {
        // Always release the trainer, even when the test throws, so that a failed
        // trainer is not left cached in m_itrainer and reused by the next call.
        cleanup(500, false);
    }
}
601
/// <summary>
/// Runs a training cycle, creating the trainer first when none is active. The stage is derived
/// from the configured trainer type.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component to train.</param>
/// <param name="nIterationOverride">Specifies the iteration count; -1 selects the cached <c>m_nIterations</c> value.</param>
/// <param name="type">Specifies the iterator type passed to the trainer.</param>
/// <param name="step">Specifies the stepping mode passed to the trainer.</param>
public void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type = ITERATOR_TYPE.ITERATION, TRAIN_STEP step = TRAIN_STEP.NONE)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe, getStage());

    try
    {
        // Resolved after trainer creation because creating the trainer refreshes m_nIterations.
        if (nIterationOverride == -1)
            nIterationOverride = m_nIterations;

        m_itrainer.Train(nIterationOverride, type, step);
    }
    finally
    {
        // Always release the trainer, even when training throws, so that a failed
        // trainer is not left cached in m_itrainer and reused by the next call.
        cleanup(1000, false);
    }
}
620
621 #endregion
622
627 {
628 initialize(e);
629 }
630
/// <summary>
/// Shutdown callback; forwards to the <c>shutdown</c> override.
/// </summary>
public void OnShutdown()
{
    // No work of its own - the override decides what shutting down means.
    shutdown();
}
638
/// <summary>
/// Get-data callback; forwards to the <c>getData</c> override.
/// </summary>
/// <param name="e">Specifies the get-data arguments handed to the override.</param>
public void OnGetData(GetDataArgs e)
{
    // The boolean result of the override is intentionally not used.
    getData(e);
}
646
652 {
653 convertOutput(e);
654 }
655
661 {
663 }
664
// NOTE(review): the signature line of this member is not visible in this listing; the body reads
// status values from 'e', so it is presumably the status-update callback - confirm.
// Purpose: cache the latest training statistics, notify the parent callback and raise the
// snapshot flag when the snapshot interval is reached.
669 {
670 m_nIteration = e.Iteration;
671 m_dfAccuracy = e.TotalReward;
672 m_nIterations = e.MaxFrames;
673 m_dfImmediateRewards = e.Reward;
674 m_dfGlobalRewards = e.TotalReward;
675 m_dfGlobalRewardsMax = Math.Max(m_dfGlobalRewardsMax, e.TotalReward);
// Running blend: the new total reward gets weight 1/m_nThreads and the previous average keeps the
// rest, so with a single thread the average simply equals the latest total reward.
676 m_dfGlobalRewardsAve = (1.0 / (double)m_nThreads) * e.TotalReward + ((m_nThreads - 1) / (double)m_nThreads) * m_dfGlobalRewardsAve;
677 m_dfExplorationRate = e.ExplorationRate;
678 m_dfOptimalSelectionRate = e.OptimalSelectionCoefficient;
679
// With several threads the episode count is advanced once per call; otherwise it is taken
// directly from the frame count reported in 'e'.
680 if (m_nThreads > 1)
681 m_nGlobalEpisodeCount++;
682 else
683 m_nGlobalEpisodeCount = e.Frames;
684
685 m_nGlobalEpisodeMax = e.MaxFrames;
686 m_dfLoss = e.Loss;
687
// Forward the current statistics to the parent, when one was supplied to Initialize.
688 if (m_icallback != null)
689 {
690 Dictionary<string, double> rgValues = new Dictionary<string, double>();
691 rgValues.Add("GlobalIteration", GlobalEpisodeCount);
692 rgValues.Add("GlobalLoss", GlobalLoss);
693 rgValues.Add("LearningRate", e.LearningRate);
694 rgValues.Add("GlobalAccuracy", GlobalRewards);
695 rgValues.Add("Threads", m_nThreads);
696
697 m_icallback.Update(TrainingCategory, rgValues);
698 }
699
// Hand the (possibly locally advanced) episode count back to the caller.
700 e.NewFrameCount = m_nGlobalEpisodeCount;
701
// Only the caller with index 0 may request a snapshot, and only when snapshots are enabled
// and the episode count is a positive multiple of the snapshot interval.
702 if (e.Index == 0 && m_nSnapshot > 0 && m_nGlobalEpisodeCount > 0 && (m_nGlobalEpisodeCount % m_nSnapshot) == 0)
703 m_bSnapshot = true;
704 }
705
/// <summary>
/// Wait callback; blocks the calling thread for the duration given by <c>e.Wait</c>.
/// </summary>
/// <param name="e">Specifies the wait arguments.</param>
public void OnWait(WaitArgs e)
{
    int nWait = e.Wait;
    Thread.Sleep(nWait);
}
713
/// <summary>
/// Returns a cached trainer statistic by name.
/// </summary>
/// <param name="strProp">Specifies the property name: 'GlobalLoss', 'GlobalAccuracy', 'GlobalIteration',
/// 'GlobalMaxIterations', 'GlobalRewards', 'GlobalEpisodeCount' or 'ExplorationRate' (case-sensitive).</param>
/// <returns>The current value of the requested property.</returns>
/// <exception cref="Exception">Thrown when the property name is not supported.</exception>
public double GetProperty(string strProp)
{
    switch (strProp)
    {
        case "GlobalLoss":
            return GlobalLoss;

        case "GlobalAccuracy":  // RNN
            return m_dfAccuracy;

        case "GlobalIteration":  // RNN
            return m_nIteration;

        case "GlobalMaxIterations":
            return m_nIterations;

        case "GlobalRewards":
            return GlobalRewards;

        case "GlobalEpisodeCount":
            return GlobalEpisodeCount;

        case "ExplorationRate":
            return ExplorationRate;

        default:
            // The message previously named 'MyCaffeTrainerRNN', which is not this class.
            throw new Exception("The property '" + strProp + "' is not supported by the MyCaffeTrainerDual.");
    }
}
748
/// <summary>
/// Returns the global rewards statistic selected by the configured reward type: the latest value,
/// the running average, or the maximum (reported as 0 until a reward has been recorded).
/// </summary>
public double GlobalRewards
{
    get
    {
        if (m_rewardType == REWARD_TYPE.VALUE)
            return m_dfGlobalRewards;

        if (m_rewardType == REWARD_TYPE.AVERAGE)
            return m_dfGlobalRewardsAve;

        // -double.MaxValue is the initial sentinel of the maximum.
        if (m_dfGlobalRewardsMax == -double.MaxValue)
            return 0;

        return m_dfGlobalRewardsMax;
    }
}
775
/// <summary>
/// Returns the most recently cached immediate reward.
/// </summary>
public double ImmediateRewards
{
    get
    {
        return m_dfImmediateRewards;
    }
}
783
/// <summary>
/// Returns the most recently cached loss.
/// </summary>
public double GlobalLoss
{
    get
    {
        return m_dfLoss;
    }
}
791
796 {
797 get { return m_nIteration; }
798 }
799
804 {
805 get { return m_nGlobalEpisodeCount; }
806 }
807
812 {
813 get { return m_nGlobalEpisodeMax; }
814 }
815
/// <summary>
/// Returns the most recently cached exploration rate.
/// </summary>
public double ExplorationRate
{
    get
    {
        return m_dfExplorationRate;
    }
}
823
828 {
829 get { return m_dfOptimalSelectionRate; }
830 }
831
/// <summary>
/// Returns the trainer description by forwarding to the <c>get_information</c> override.
/// </summary>
public string Information
{
    get
    {
        return get_information();
    }
}
839
/// <summary>
/// Public entry point that forwards to the <c>openUi</c> override.
/// </summary>
public void OpenUi()
{
    // No work of its own - the override decides whether a UI is shown.
    openUi();
}
847
848 #region IXMyCaffeCustomTrainerRL Methods
849
// NOTE(review): the signature line of this member is not visible in this listing; the body uses
// 'mycaffe' and 'nDelay', so it is presumably RunOne(Component mycaffe, int nDelay) - confirm.
// Purpose: run a single cycle on an RL trainer and return its results.
857 {
// Create an RL-stage trainer on demand.
858 if (m_itrainer == null)
859 m_itrainer = createTrainer(mycaffe, Stage.RL);
860
// The active trainer must implement the RL interface, otherwise the call is rejected.
861 IxTrainerRL itrainer = m_itrainer as IxTrainerRL;
862 if (itrainer == null)
863 throw new Exception("The trainer must be set to to 'C51.ST', PG.SIMPLE', 'PG.ST' or 'PG.MT' to run in reinforcement learning mode.");
864
865 ResultCollection res = itrainer.RunOne(nDelay);
// Release the trainer after the run; the shutdown override is not invoked here.
866 cleanup(50, false);
867
868 return res;
869 }
870
/// <summary>
/// Runs the network on an RL trainer, creating the trainer first when none is active.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component to run.</param>
/// <param name="nN">Specifies the count passed to the trainer's Run method.</param>
/// <param name="type">Receives the output type reported by the trainer.</param>
/// <returns>The results returned by the trainer.</returns>
/// <exception cref="Exception">Thrown when the active trainer does not implement IxTrainerRL.</exception>
public byte[] Run(Component mycaffe, int nN, out string type)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe, Stage.RL);

    // Run properties are optional: they are only available when the parent callback
    // implements the RNN callback interface.
    PropertySet runProp = null;
    IXMyCaffeCustomTrainerCallbackRNN icallback = m_icallback as IXMyCaffeCustomTrainerCallbackRNN;
    if (icallback != null)
        runProp = icallback.GetRunProperties();

    IxTrainerRL itrainer = m_itrainer as IxTrainerRL;
    if (itrainer == null)
        throw new Exception("The IxTrainerRL interface must be implemented.");

    byte[] rgResults = itrainer.Run(nN, runProp, out type);
    cleanup(0, false);

    return rgResults;
}
897
898 #endregion
899
900 #region IXMyCaffeCustomTrainerRNN
901
/// <summary>
/// Runs the network on an RNN trainer, creating the trainer first when none is active.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component to run.</param>
/// <param name="nN">Specifies the count passed to the trainer's Run method.</param>
/// <returns>The results returned by the trainer.</returns>
/// <exception cref="Exception">Thrown when the active trainer does not implement IxTrainerRNN.</exception>
float[] IXMyCaffeCustomTrainerRNN.Run(Component mycaffe, int nN)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe, Stage.RNN);

    IxTrainerRNN itrainer = m_itrainer as IxTrainerRNN;
    if (itrainer == null)
        throw new Exception("The trainer must be set to to 'RNN.SIMPLE' to run in recurrent learning mode.");

    // Run properties are optional: they are only available when the parent callback
    // implements the RNN callback interface.
    PropertySet runProp = null;
    IXMyCaffeCustomTrainerCallbackRNN icallback = m_icallback as IXMyCaffeCustomTrainerCallbackRNN;
    if (icallback != null)
        runProp = icallback.GetRunProperties();

    float[] rgResults = itrainer.Run(nN, runProp);
    cleanup(0, false);

    return rgResults;
}
927
/// <summary>
/// Runs the network on an RNN trainer and returns the results as bytes, creating the trainer
/// first when none is active.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe component to run.</param>
/// <param name="nN">Specifies the count passed to the trainer's Run method.</param>
/// <param name="type">Receives the output type reported by the trainer.</param>
/// <returns>The results returned by the trainer.</returns>
/// <exception cref="Exception">Thrown when the active trainer does not implement IxTrainerRNN.</exception>
byte[] IXMyCaffeCustomTrainerRNN.Run(Component mycaffe, int nN, out string type)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe, Stage.RNN);

    IxTrainerRNN itrainer = m_itrainer as IxTrainerRNN;
    if (itrainer == null)
        throw new Exception("The trainer must be set to to 'RNN.SIMPLE' to run in recurrent learning mode.");

    // Run properties are optional: they are only available when the parent callback
    // implements the RNN callback interface.
    PropertySet runProp = null;
    IXMyCaffeCustomTrainerCallbackRNN icallback = m_icallback as IXMyCaffeCustomTrainerCallbackRNN;
    if (icallback != null)
        runProp = icallback.GetRunProperties();

    byte[] rgResults = itrainer.Run(nN, runProp, out type);

    // Release the trainer through cleanup() like the other run methods, so that the shutdown
    // happens under the m_syncObj lock instead of bypassing it.
    cleanup(0, false);

    return rgResults;
}
955
/// <summary>
/// Interface entry point that forwards to the <c>preloaddata</c> override.
/// </summary>
/// <param name="log">Specifies the output log.</param>
/// <param name="evtCancel">Specifies the cancel event.</param>
/// <param name="nProjectID">Specifies the project ID.</param>
/// <param name="propertyOverride">Optionally, specifies property overrides.</param>
/// <param name="ci">Optionally, specifies the dataset connection info.</param>
/// <returns>The value returned by <c>preloaddata</c>.</returns>
BucketCollection IXMyCaffeCustomTrainerRNN.PreloadData(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride = null, ConnectInfo ci = null)
{
    BucketCollection rgVocabulary = preloaddata(log, evtCancel, nProjectID, propertyOverride, ci);
    return rgVocabulary;
}
969
/// <summary>
/// Resizes the model so that its EMBED input dimension and INNERPRODUCT output count match the
/// vocabulary size, and stores the vocabulary for later trainer creation.
/// </summary>
/// <remarks>
/// When the model contains several EMBED or INNERPRODUCT layers, only the last one of each type
/// is adjusted. A null or empty vocabulary returns the model unchanged and stores nothing.
/// </remarks>
/// <param name="log">Specifies the log that receives a warning for each changed layer.</param>
/// <param name="strModel">Specifies the model descriptor text.</param>
/// <param name="rgVocabulary">Specifies the vocabulary whose count determines the new sizes.</param>
/// <returns>The original text when there is no vocabulary, otherwise the re-serialized model.</returns>
string IXMyCaffeCustomTrainerRNN.ResizeModel(Log log, string strModel, BucketCollection rgVocabulary)
{
    if (rgVocabulary == null || rgVocabulary.Count == 0)
        return strModel;

    int nVocabCount = rgVocabulary.Count;

    // Parse the model text into a net parameter so that its layers can be inspected and edited.
    NetParameter p = NetParameter.FromProto(RawProto.Parse(strModel));

    string strEmbedName = "";
    EmbedParameter embed = null;
    string strIpName = "";
    InnerProductParameter ip = null;

    // Later layers overwrite earlier ones, so the last layer of each type is the one kept.
    foreach (LayerParameter layer in p.layer)
    {
        if (layer.type == LayerParameter.LayerType.EMBED)
        {
            strEmbedName = layer.name;
            embed = layer.embed_param;
        }
        else if (layer.type == LayerParameter.LayerType.INNERPRODUCT)
        {
            strIpName = layer.name;
            ip = layer.inner_product_param;
        }
    }

    if (embed != null && embed.input_dim != (uint)nVocabCount)
    {
        log.WriteLine("WARNING: Embed layer '" + strEmbedName + "' input dim changed from " + embed.input_dim.ToString() + " to " + nVocabCount.ToString() + " to accomodate for the vocabulary count.");
        embed.input_dim = (uint)nVocabCount;
    }

    if (ip != null && ip.num_output != (uint)nVocabCount)
    {
        log.WriteLine("WARNING: InnerProduct layer '" + strIpName + "' num_output changed from " + ip.num_output.ToString() + " to " + nVocabCount.ToString() + " to accomodate for the vocabulary count.");
        ip.num_output = (uint)nVocabCount;
    }

    m_rgVocabulary = rgVocabulary;

    RawProto proto = p.ToProto("root");
    return proto.ToString();
}
1024
1025 #endregion
1026 }
1027}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
ProjectEx CurrentProject
Returns the name of the currently loaded project.
The BucketCollection contains a set of Buckets.
int Count
Returns the number of Buckets.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
Definition: ConnectInfo.cs:14
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Definition: CryptoRandom.cs:14
The Log class provides general output in text form.
Definition: Log.cs:13
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
string GetSolverSetting(string strParam)
Get a setting from the solver descriptor.
Definition: ProjectEx.cs:453
int OriginalID
Get/set the original project ID.
Definition: ProjectEx.cs:541
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
Definition: PropertySet.cs:211
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
override string ToString()
Returns the RawProto as its full prototxt string.
Definition: RawProto.cs:681
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
The ResultCollection contains the result of a given CaffeControl::Run.
Specifies the parameters used by the EmbedLayer.
uint input_dim
Specifies the input given as integers to be interpreted as one-hot vector indices with dimension num_...
Specifies the parameters for the InnerProductLayer.
uint num_output
The number of outputs for the layer.
Specifies the base parameter for all layers.
string name
Specifies the name of this LayerParameter.
LayerType type
Specifies the type of this LayerParameter.
EmbedParameter embed_param
Returns the parameter set when initialized with LayerType.EMBED
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
Specifies the parameters use to create a Net
Definition: NetParameter.cs:18
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
override RawProto ToProto(string strName)
Constructor for the parameter.
List< LayerParameter > layer
The layers that make up the net. Each of their configurations, including connectivity and behavior,...
The ConvertOutputArgs is passed to the OnConvertOutput event.
Definition: EventArgs.cs:311
The GetDataArgs is passed to the OnGetData event to retrieve data.
Definition: EventArgs.cs:402
The GetStatusArgs is passed to the OnGetStatus event.
Definition: EventArgs.cs:166
double Loss
Returns the loss value.
Definition: EventArgs.cs:262
double OptimalSelectionCoefficient
Returns the optimal selection coefficient.
Definition: EventArgs.cs:302
int MaxFrames
Returns the maximum frame count.
Definition: EventArgs.cs:246
int Iteration
Returns the number of iterations (steps) run.
Definition: EventArgs.cs:221
int Frames
Returns the total frame count across all agents.
Definition: EventArgs.cs:238
int NewFrameCount
Get/set the new frame count.
Definition: EventArgs.cs:229
double ExplorationRate
Returns the current exploration rate.
Definition: EventArgs.cs:294
double TotalReward
Returns the total rewards.
Definition: EventArgs.cs:278
int Index
Returns the index of the caller.
Definition: EventArgs.cs:213
double Reward
Returns the immediate reward for the current episode.
Definition: EventArgs.cs:286
double LearningRate
Returns the current learning rate.
Definition: EventArgs.cs:270
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
The MyCaffeTrainerDual is used to perform both reinforcement and recurrent learning training tasks ...
virtual TRAINING_CATEGORY category
Override when using a training method other than the REINFORCEMENT method (the default).
void OnUpdateStatus(GetStatusArgs e)
The OnGetStatus callback fires on each iteration within the Train method.
ConnectInfo m_dsCi
Optionally, specifies the dataset connection info, or null.
virtual BucketCollection preloaddata(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride=null, ConnectInfo ci=null)
The preloaddata method gives the custom trainer an opportunity to pre-load any data.
virtual string get_information()
Returns information describing the specific trainer, such as the gym used, if any.
double ImmediateRewards
Returns the immediate rewards for the current training cycle as opposed to the averaged rewards.
bool IsRunningSupported
Returns whether or not Running is supported.
void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
Initializes a new custom trainer by loading the key-value pair of properties into the property set.
virtual bool getData(GetDataArgs e)
Override called by the OnGetData event fired by the Trainer to retrieve a new set of observation coll...
void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION)
Create a new trainer and use it to run a test cycle using the current 'stage' = RNN or RL.
int GlobalEpisodeMax
Returns the maximum global episode count.
virtual void testAccuracyUpdate(TestAccuracyUpdateArgs e)
Override called by the OnTestAccuracyUpdate event fired from within the Run method and is used to giv...
int GlobalEpisodeCount
Returns the global episode count.
bool IsTrainingSupported
Returns whether or not Training is supported.
byte[] Run(Component mycaffe, int nN, out string type)
Run the network using the run technique implemented by this trainer.
virtual void openUi()
Called by OpenUi, override this when a UI (via WCF) should be displayed.
double ExplorationRate
Returns the current exploration rate.
int m_nProjectID
Specifies the project ID of the project held by the instance of MyCaffe.
MyCaffeTrainerDual(IContainer container)
The constructor.
void OnWait(WaitArgs e)
The OnWait callback fires when waiting for a shutdown.
string Information
Returns information describing the trainer.
void OnTestAccuracyUpdate(TestAccuracyUpdateArgs e)
The OnTestAccuracyUpdate callback fires from within the Run method and is used to give the recipient ...
void OnInitialize(InitializeArgs e)
The OnInitialize callback fires when initializing the trainer.
bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
double OptimalSelectionRate
Returns the rate of selection from the optimal set with the highest reward (this setting is optional,...
void OpenUi()
Open the user interface for the trainer, if one exists.
virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
TRAINING_CATEGORY TrainingCategory
Returns the training category of the custom trainer (default = REINFORCEMENT).
CryptoRandom m_random
Random number generator used to get initial actions, etc.
int GlobalIteration
Returns the global iteration.
virtual IxTrainer create_trainerF(Component caffe, Stage stage)
Optionally overridden to return a new type of trainer.
virtual string name
Overridden to give the actual name of the custom trainer.
DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
double GetProperty(string strProp)
Return a property value from the trainer.
bool IsTestingSupported
Returns whether or not Testing is supported.
virtual void initialize(InitializeArgs e)
Override called by the Initialize method of the trainer.
virtual bool convertOutput(ConvertOutputArgs e)
Override called by the OnConvertOutput event fired by the Trainer to convert the network output into ...
virtual void dispose()
Override to dispose of resources used.
double GlobalLoss
Return the global loss.
double? GlobalRewards
Returns the global rewards based on the reward type specified by the 'RewardType' property.
PropertySet m_properties
Specifies the properties parsed from the key-value pair passed to the Initialize method.
string Name
Returns the name of the custom trainer. This method calls the 'name' override.
void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION, TRAIN_STEP step=TRAIN_STEP.NONE)
Create a new trainer and use it to run a training cycle using the current 'stage' = RNN or RL.
void OnConvertOutput(ConvertOutputArgs e)
The OnConvertOutput callback fires from within the Run method and is used to convert the network outp...
virtual void shutdown()
Override called from within the CleanUp method.
void CleanUp()
Releases any resources used by the component.
void OnShutdown()
The OnShutdown callback fires when shutting down the trainer.
virtual IxTrainer create_trainerD(Component caffe, Stage stage)
Optionally overridden to return a new type of trainer.
void OnGetData(GetDataArgs e)
The OnGetData callback fires from within the Train method and is used to get a new observation data.
The TestAccuracyUpdateArgs are passed to the OnTestAccuracyUpdate event.
Definition: EventArgs.cs:553
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
int Wait
Returns the amount of time to wait in milliseconds.
Definition: EventArgs.cs:81
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
Definition: Component.cs:18
The IXMyCaffeCustomTrainerCallback interface is used to call back to the parent running the custom tr...
Definition: Interfaces.cs:199
void Update(TRAINING_CATEGORY cat, Dictionary< string, double > rgValues)
The Update method updates the parent with the global iteration, reward and loss.
The IXMyCaffeCustomTrainerCallbackRNN interface is used to call back to the parent running the custom...
Definition: Interfaces.cs:212
PropertySet GetRunProperties()
The GetRunProperties method is used to query the properties used when Running, if any.
The IXMyCaffeCustomTrainer interface is used by the MyCaffeCustomTrainer components that provide va...
Definition: Interfaces.cs:135
ResultCollection RunOne(Component mycaffe, int nDelay)
Run the network using the run technique implemented by this trainer.
The IXMyCaffeCustomTrainer interface is used by the MyCaffeCustomTrainer components that provide va...
Definition: Interfaces.cs:158
float[] Run(Component mycaffe, int nN)
Run the network using the run technique implemented by this trainer.
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
Definition: Interfaces.cs:348
The IxTrainer interface is implemented by each Trainer.
Definition: Interfaces.cs:224
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network.
bool Test(int nN, ITERATOR_TYPE type)
Test the network.
bool Shutdown(int nWait)
Shutdown the trainer.
The IxTrainerRL interface is implemented by each RL Trainer.
Definition: Interfaces.cs:257
ResultCollection RunOne(int nDelay=1000)
Run a single cycle on the trainer.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a number of 'nN' samples on the trainer.
The IxTrainerRNN interface is implemented by each RNN Trainer.
Definition: Interfaces.cs:279
float[] Run(int nN, PropertySet runProp)
Run a number of 'nN' samples on the trainer.
The descriptors namespace contains all descriptor used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
TRAINING_CATEGORY
Defines the category of training.
Definition: Interfaces.cs:34
Stage
Specifies the stage under which to run a custom trainer.
Definition: Interfaces.cs:88
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.gym namespace contains all classes related to the Gym's supported by MyCaffe.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.trainers namespace contains all reinforcement and recurrent learning trainers.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12