MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MyCaffeTrainerRNN.cs
1using System;
2using System.Collections.Generic;
4using System.Diagnostics;
5using System.Linq;
6using System.Text;
7using System.Threading;
8using System.Threading.Tasks;
9using MyCaffe.basecode;
11using MyCaffe.common;
12using MyCaffe.gym;
13using MyCaffe.param;
14
15namespace MyCaffe.trainers
16{
33 {
/// <summary>
/// Specifies the properties parsed from the key-value pair passed to the Initialize method.
/// </summary>
protected PropertySet m_properties;
/// <summary>
/// Specifies the project ID of the project held by the instance of MyCaffe.
/// </summary>
protected int m_nProjectID;
/// <summary>
/// Optionally, specifies the dataset connection info, or null.
/// </summary>
protected ConnectInfo m_dsCi;
// Active trainer built by createTrainer and released (set back to null) by cleanup.
private IxTrainerRNN m_itrainer;
// Trainer flavor chosen in Initialize and used by the create_trainer overrides.
private TRAINER_TYPE m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
// Callback supplied to Initialize; used to report status values back to the parent.
private IXMyCaffeCustomTrainerCallback m_icallback;
// Random number generator handed to every trainer that is created.
private CryptoRandom m_random = new CryptoRandom();
// 'snapshot' solver setting, read when a trainer is created.
private int m_nSnapshot;
// When set, get_update_snapshot reports true once and then clears it.
private bool m_bSnapshot;
// Most recent status values, returned by GetProperty.
private double m_dfLoss;
private double m_dfAccuracy;
private int m_nIteration;
// Maximum iterations: read from 'max_iter' and refreshed from status updates; -1 until known.
private int m_nIterations = -1;
// Vocabulary stored by ResizeModel and passed to each trainer created.
private BucketCollection m_rgVocabulary;
// Guards the trainer shutdown performed in cleanup.
private object m_syncObj = new object();

/// <summary>
/// Identifies which RNN trainer implementation to create.
/// </summary>
private enum TRAINER_TYPE
{
    RNN_SIMPLE,
    RNN_SUPER_SIMPLE
}
64
69 {
70 InitializeComponent();
71 }
72
/// <summary>
/// The constructor used when the trainer component is hosted in a container.
/// </summary>
/// <param name="container">Specifies the container to which this component is added.</param>
public MyCaffeTrainerRNN(IContainer container)
{
    // Register with the container before building the designer-generated members.
    container.Add(this);
    this.InitializeComponent();
}
83
84 #region Overrides
85
/// <summary>
/// Overridden to give the actual name of the custom trainer.
/// </summary>
protected virtual string name
{
    get
    {
        return "MyCaffe RNN Trainer";
    }
}

/// <summary>
/// Override when using a training method other than the RECURRENT method (the default).
/// </summary>
protected virtual TRAINING_CATEGORY category
{
    get
    {
        return TRAINING_CATEGORY.RECURRENT;
    }
}

/// <summary>
/// Returns a dataset override to use instead of the project's dataset.
/// The default implementation supplies no override and returns null.
/// </summary>
/// <param name="nProjectID">Specifies the project ID.</param>
/// <param name="ci">Optionally, specifies the dataset connection info.</param>
/// <returns>The dataset descriptor override, or null.</returns>
protected virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci = null)
{
    DatasetDescriptor dsOverride = null;
    return dsOverride;
}

/// <summary>
/// Returns information describing the specific trainer; empty by default.
/// </summary>
/// <returns>The information string.</returns>
protected virtual string get_information()
{
    return string.Empty;
}
121
130 protected virtual IxTrainerRNN create_trainerD(Component caffe)
131 {
134 m_dsCi = mycaffe.DatasetConnectInfo;
135
136 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("max_iter"), out m_nIterations);
137 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("snapshot"), out m_nSnapshot);
138
139 switch (m_trainerType)
140 {
141 case TRAINER_TYPE.RNN_SUPER_SIMPLE:
142 return new rnn.simple.TrainerRNNSimple<double>(mycaffe, m_properties, m_random, this, m_rgVocabulary);
143
144 case TRAINER_TYPE.RNN_SIMPLE:
145 return new rnn.simple.TrainerRNN<double>(mycaffe, m_properties, m_random, this, m_rgVocabulary);
146
147 default:
148 throw new Exception("Unknown trainer type '" + m_trainerType.ToString() + "'!");
149 }
150 }
151
160 protected virtual IxTrainerRNN create_trainerF(Component caffe)
161 {
164 m_dsCi = mycaffe.DatasetConnectInfo;
165
166 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("max_iter"), out m_nIterations);
167 int.TryParse(mycaffe.CurrentProject.GetSolverSetting("snapshot"), out m_nSnapshot);
168
169 switch (m_trainerType)
170 {
171 case TRAINER_TYPE.RNN_SUPER_SIMPLE:
172 return new rnn.simple.TrainerRNNSimple<float>(mycaffe, m_properties, m_random, this, m_rgVocabulary);
173
174 case TRAINER_TYPE.RNN_SIMPLE:
175 return new rnn.simple.TrainerRNN<float>(mycaffe, m_properties, m_random, this, m_rgVocabulary);
176
177 default:
178 throw new Exception("Unknown trainer type '" + m_trainerType.ToString() + "'!");
179 }
180 }
181
/// <summary>
/// Override to dispose of resources used.  The default does nothing.
/// </summary>
protected virtual void dispose()
{
    // Intentionally empty.
}

/// <summary>
/// Override called by the OnInitialize callback.  The default does nothing.
/// </summary>
/// <param name="e">Specifies the initialization arguments.</param>
protected virtual void initialize(InitializeArgs e)
{
    // Intentionally empty.
}

/// <summary>
/// Override called from within the CleanUp method and the OnShutdown callback.  The default does nothing.
/// </summary>
protected virtual void shutdown()
{
    // Intentionally empty.
}

/// <summary>
/// Override called by the OnGetData callback to retrieve new observation data.
/// </summary>
/// <param name="e">Specifies the get data arguments.</param>
/// <returns>The default implementation handles nothing and returns false.</returns>
protected virtual bool getData(GetDataArgs e)
{
    const bool bHandled = false;
    return bHandled;
}

/// <summary>
/// Override called by the OnConvertOutput callback to convert the network output.
/// </summary>
/// <param name="e">Specifies the conversion arguments.</param>
/// <returns>The default implementation handles nothing and returns false.</returns>
protected virtual bool convertOutput(ConvertOutputArgs e)
{
    const bool bHandled = false;
    return bHandled;
}
226
232 {
233 }
234
/// <summary>
/// Returns true when the training is ready for a snapshot, false otherwise.
/// A pending snapshot request is consumed, so it is reported only once.
/// </summary>
/// <param name="nIteration">Returns the current global iteration.</param>
/// <param name="dfAccuracy">Returns the current global accuracy.</param>
/// <returns>True when a snapshot is pending.</returns>
protected virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
{
    double dfIteration = GetProperty("GlobalIteration");
    nIteration = (int)dfIteration;
    dfAccuracy = GetProperty("GlobalAccuracy");

    bool bPending = m_bSnapshot;
    if (bPending)
        m_bSnapshot = false;

    return bPending;
}
253
/// <summary>
/// Called by OpenUi; override this when a user interface should be displayed.  The default does nothing.
/// </summary>
protected virtual void openUi()
{
    // Intentionally empty.
}

/// <summary>
/// Gives the custom trainer an opportunity to pre-load any data.
/// </summary>
/// <param name="log">Specifies the output log.</param>
/// <param name="evtCancel">Specifies the cancel event.</param>
/// <param name="nProjectID">Specifies the project ID.</param>
/// <param name="propertyOverride">Optionally, specifies property overrides.</param>
/// <param name="ci">Optionally, specifies the dataset connection info.</param>
/// <returns>The default implementation loads nothing and returns null.</returns>
protected virtual BucketCollection preloaddata(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride = null, ConnectInfo ci = null)
{
    BucketCollection rgNone = null;
    return rgNone;
}
274
275 #endregion
276
277 #region IXMyCaffeCustomTrainer Interface
278
283 {
284 get { return Stage.RNN; }
285 }
286
/// <summary>
/// Returns the name of the custom trainer, as supplied by the 'name' override.
/// </summary>
public string Name
{
    get
    {
        return this.name;
    }
}
294
299 {
300 get { return category; }
301 }
302
/// <summary>
/// Returns true when the training is ready for a snapshot, false otherwise.
/// Delegates to the get_update_snapshot override.
/// </summary>
/// <param name="nIteration">Returns the current iteration.</param>
/// <param name="dfAccuracy">Returns the current accuracy.</param>
/// <returns>True when a snapshot is pending.</returns>
public bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
{
    bool bReady = get_update_snapshot(out nIteration, out dfAccuracy);
    return bReady;
}

/// <summary>
/// Returns a dataset override to use (if any) instead of the project's dataset.
/// Delegates to the get_dataset_override override.
/// </summary>
/// <param name="nProjectID">Specifies the project ID.</param>
/// <param name="ci">Optionally, specifies the dataset connection info.</param>
/// <returns>The dataset override, or null when there is none.</returns>
public DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci = null)
{
    DatasetDescriptor ds = get_dataset_override(nProjectID, ci);
    return ds;
}
323
328 {
329 get { return true; }
330 }
331
336 {
337 get { return true; }
338 }
339
344 {
345 get { return true; }
346 }
347
/// <summary>
/// Releases the active trainer (waiting up to 3000 ms for it to shut down)
/// and then calls the shutdown override.
/// </summary>
public void CleanUp()
{
    cleanup(nWait: 3000, bCallShutdown: true);
}
355
/// <summary>
/// Shuts down and releases the active trainer, if any, under the sync lock.
/// </summary>
/// <param name="nWait">Specifies the wait passed to the trainer's Shutdown.</param>
/// <param name="bCallShutdown">When true, the shutdown override is also called.</param>
private void cleanup(int nWait, bool bCallShutdown = false)
{
    lock (m_syncObj)
    {
        IxTrainerRNN itrainerActive = m_itrainer;

        if (itrainerActive != null)
        {
            itrainerActive.Shutdown(nWait);
            m_itrainer = null;
        }

        if (!bCallShutdown)
            return;

        shutdown();
    }
}
370
/// <summary>
/// Initializes the custom trainer by loading the key-value pair of properties into the property set
/// and selecting the trainer type from the 'TrainerType' property.
/// </summary>
/// <param name="strProperties">Specifies the key-value pair properties.</param>
/// <param name="icallback">Specifies the parent callback.</param>
/// <exception cref="Exception">Thrown when the 'TrainerType' value is not recognized.</exception>
public void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
{
    m_icallback = icallback;
    m_properties = new PropertySet(strProperties);

    string strType = m_properties.GetProperty("TrainerType");

    // Only the bare bones model is selectable from the properties.
    if (strType == "RNN.SIMPLE")
    {
        m_trainerType = TRAINER_TYPE.RNN_SIMPLE;
        return;
    }

    throw new Exception("Unknown trainer type '" + strType + "'!");
}
394
/// <summary>
/// Creates and initializes a trainer that matches the base type of the MyCaffe instance.
/// Anything other than MyCaffeControl&lt;double&gt; is treated as the float variant.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe instance.</param>
/// <returns>The initialized trainer.</returns>
private IxTrainerRNN createTrainer(Component mycaffe)
{
    bool bUseDouble = mycaffe is MyCaffeControl<double>;
    IxTrainerRNN itrainerNew = bUseDouble ? create_trainerD(mycaffe) : create_trainerF(mycaffe);

    itrainerNew.Initialize();
    return itrainerNew;
}
408
415 public float[] Run(Component mycaffe, int nN)
416 {
417 if (m_itrainer == null)
418 m_itrainer = createTrainer(mycaffe);
419
420 PropertySet runProp = null;
422 if (icallback != null)
423 runProp = icallback.GetRunProperties();
424
425 float[] rgResults = m_itrainer.Run(nN, runProp);
426 cleanup(0);
427
428 return rgResults;
429 }
430
438 public byte[] Run(Component mycaffe, int nN, out string type)
439 {
440 if (m_itrainer == null)
441 m_itrainer = createTrainer(mycaffe);
442
443 PropertySet runProp = null;
445 if (icallback != null)
446 runProp = icallback.GetRunProperties();
447
448 byte[] rgResults = m_itrainer.Run(nN, runProp, out type);
449 cleanup(0);
450
451 return rgResults;
452 }
453
/// <summary>
/// Creates a trainer (when none is active) and uses it to run a test cycle.
/// The trainer is always released afterwards, even when testing fails.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe instance used to create the trainer.</param>
/// <param name="nIterationOverride">Specifies the iterations to run, or -1 to use the stored maximum iteration count.</param>
/// <param name="type">Specifies the iterator type.</param>
public void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type = ITERATOR_TYPE.ITERATION)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe);

    if (nIterationOverride == -1)
        nIterationOverride = m_nIterations;

    try
    {
        m_itrainer.Test(nIterationOverride, type);
    }
    finally
    {
        // Release the trainer on failure too, so a broken trainer is not
        // silently reused by the next Train/Test/Run call.
        cleanup(0);
    }
}
471
/// <summary>
/// Creates a trainer (when none is active) and uses it to run a training cycle.
/// The trainer is always released afterwards, even when training fails.
/// </summary>
/// <param name="mycaffe">Specifies the MyCaffe instance used to create the trainer.</param>
/// <param name="nIterationOverride">Specifies the iterations to run, or -1 to use the stored maximum iteration count.</param>
/// <param name="type">Specifies the iterator type.</param>
/// <param name="step">Specifies the training step mode.</param>
public void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type = ITERATOR_TYPE.ITERATION, TRAIN_STEP step = TRAIN_STEP.NONE)
{
    if (m_itrainer == null)
        m_itrainer = createTrainer(mycaffe);

    if (nIterationOverride == -1)
        nIterationOverride = m_nIterations;

    try
    {
        m_itrainer.Train(nIterationOverride, type, step);
    }
    finally
    {
        // Release the trainer on failure too, so a broken trainer is not
        // silently reused by the next Train/Test/Run call.
        cleanup(0);
    }
}
490
491 #endregion
492
497 {
498 initialize(e);
499 }
500
/// <summary>
/// The OnShutdown callback fires when shutting down the trainer; forwards to the shutdown override.
/// </summary>
public void OnShutdown()
{
    this.shutdown();
}

/// <summary>
/// The OnGetData callback fires to get new observation data; forwards to the getData override.
/// The override's return value is ignored.
/// </summary>
/// <param name="e">Specifies the get data arguments.</param>
public void OnGetData(GetDataArgs e)
{
    this.getData(e);
}
516
522 {
523 convertOutput(e);
524 }
525
531 {
533 }
534
539 {
540 if (m_icallback != null)
541 {
542 m_dfLoss = e.Loss;
543 m_nIteration = e.Frames;
544 m_nIterations = e.MaxFrames;
545 m_dfAccuracy = e.TotalReward;
546
547 Dictionary<string, double> rgValues = new Dictionary<string, double>();
548 rgValues.Add("GlobalIteration", e.Frames);
549 rgValues.Add("GlobalLoss", e.Loss);
550 rgValues.Add("LearningRate", e.LearningRate);
551 rgValues.Add("GlobalAccuracy", e.TotalReward);
552 m_icallback.Update(TrainingCategory, rgValues);
553 }
554 }
555
/// <summary>
/// The OnWait callback blocks the calling thread for the requested number of milliseconds.
/// </summary>
/// <param name="e">Specifies the wait arguments.</param>
public void OnWait(WaitArgs e)
{
    int nWaitMs = e.Wait;
    Thread.Sleep(nWaitMs);
}
563
/// <summary>
/// Returns a specific status property value.
/// </summary>
/// <param name="strProp">One of 'GlobalLoss', 'GlobalAccuracy', 'GlobalIteration' or 'GlobalMaxIterations'.</param>
/// <returns>The property value.</returns>
/// <exception cref="Exception">Thrown when the property name is not supported.</exception>
public double GetProperty(string strProp)
{
    if (strProp == "GlobalLoss")
        return m_dfLoss;

    if (strProp == "GlobalAccuracy")
        return m_dfAccuracy;

    if (strProp == "GlobalIteration")
        return m_nIteration;

    if (strProp == "GlobalMaxIterations")
        return m_nIterations;

    throw new Exception("The property '" + strProp + "' is not supported by the MyCaffeTrainerRNN.");
}
593
/// <summary>
/// Returns information describing the trainer, as supplied by the get_information override.
/// </summary>
public string Information
{
    get
    {
        string strInfo = get_information();
        return strInfo;
    }
}

/// <summary>
/// Opens the user interface for the trainer, if one exists, by calling the openUi override.
/// </summary>
public void OpenUi()
{
    this.openUi();
}

/// <summary>
/// Gives the custom trainer an opportunity to pre-load any data; forwards to the preloaddata override.
/// </summary>
/// <param name="log">Specifies the output log.</param>
/// <param name="evtCancel">Specifies the cancel event.</param>
/// <param name="nProjectID">Specifies the project ID.</param>
/// <param name="propertyOverride">Optionally, specifies property overrides.</param>
/// <param name="ci">Optionally, specifies the dataset connection info.</param>
/// <returns>The pre-loaded buckets, or whatever the override returns.</returns>
public BucketCollection PreloadData(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride = null, ConnectInfo ci = null)
{
    BucketCollection rgData = preloaddata(log, evtCancel, nProjectID, propertyOverride, ci);
    return rgData;
}
623
/// <summary>
/// Gives the custom trainer the opportunity to resize the model to match the vocabulary.
/// The input dim of the last EMBED layer and the num_output of the last INNERPRODUCT
/// layer are set to the vocabulary count.  A layer kind that is absent from the model
/// is left alone.
/// </summary>
/// <param name="log">Specifies the output log used to report each resize.</param>
/// <param name="strModel">Specifies the model description (prototxt).</param>
/// <param name="rgVocabulary">Specifies the vocabulary; the model is returned unchanged when it is null or empty.</param>
/// <returns>The (possibly resized) model description.</returns>
public string ResizeModel(Log log, string strModel, BucketCollection rgVocabulary)
{
    if (rgVocabulary == null || rgVocabulary.Count == 0)
        return strModel;

    int nVocabCount = rgVocabulary.Count;
    NetParameter p = NetParameter.FromProto(RawProto.Parse(strModel));
    string strEmbedName = "";
    EmbedParameter embed = null;
    string strIpName = "";
    InnerProductParameter ip = null;

    // When several layers of a kind exist, the last one found wins.
    foreach (LayerParameter layer in p.layer)
    {
        if (layer.type == LayerParameter.LayerType.EMBED)
        {
            strEmbedName = layer.name;
            embed = layer.embed_param;
        }
        else if (layer.type == LayerParameter.LayerType.INNERPRODUCT)
        {
            strIpName = layer.name;
            ip = layer.inner_product_param;
        }
    }

    if (embed != null && embed.input_dim != (uint)nVocabCount)
    {
        log.WriteLine("WARNING: Embed layer '" + strEmbedName + "' input dim changed from " + embed.input_dim.ToString() + " to " + nVocabCount.ToString() + " to accomodate for the vocabulary count.");
        embed.input_dim = (uint)nVocabCount;
    }

    // Guard against models without an inner product layer, which would otherwise
    // cause a NullReferenceException here.
    if (ip != null && ip.num_output != (uint)nVocabCount)
    {
        log.WriteLine("WARNING: InnerProduct layer '" + strIpName + "' num_output changed from " + ip.num_output.ToString() + " to " + nVocabCount.ToString() + " to accomodate for the vocabulary count.");
        ip.num_output = (uint)nVocabCount;
    }

    m_rgVocabulary = rgVocabulary;

    RawProto proto = p.ToProto("root");
    return proto.ToString();
}
678 }
679}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
ProjectEx CurrentProject
Returns the currently loaded project.
The BucketCollection contains a set of Buckets.
int Count
Returns the number of Buckets.
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
Definition: ConnectInfo.cs:14
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Definition: CryptoRandom.cs:14
The Log class provides general output in text form.
Definition: Log.cs:13
string GetSolverSetting(string strParam)
Get a setting from the solver descriptor.
Definition: ProjectEx.cs:453
int OriginalID
Get/set the original project ID.
Definition: ProjectEx.cs:541
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
override string ToString()
Returns the RawProto as its full prototxt string.
Definition: RawProto.cs:681
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
Specifies the parameters used by the EmbedLayer.
Specifies the parameters for the InnerProductLayer.
uint num_output
The number of outputs for the layer.
Specifies the base parameter for all layers.
string name
Specifies the name of this LayerParameter.
LayerType type
Specifies the type of this LayerParameter.
EmbedParameter embed_param
Returns the parameter set when initialized with LayerType.EMBED
InnerProductParameter inner_product_param
Returns the parameter set when initialized with LayerType.INNERPRODUCT
LayerType
Specifies the layer type.
Specifies the parameters use to create a Net
Definition: NetParameter.cs:18
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
The ConvertOutputArgs is passed to the OnConvertOutput event.
Definition: EventArgs.cs:311
The GetDataArgs is passed to the OnGetData event to retrieve data.
Definition: EventArgs.cs:402
The GetStatusArgs is passed to the OnGetStatus event.
Definition: EventArgs.cs:166
double Loss
Returns the loss value.
Definition: EventArgs.cs:262
int MaxFrames
Returns the maximum frame count.
Definition: EventArgs.cs:246
int Frames
Returns the total frame count across all agents.
Definition: EventArgs.cs:238
double TotalReward
Returns the total rewards.
Definition: EventArgs.cs:278
double LearningRate
Returns the current learning rate.
Definition: EventArgs.cs:270
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
(Deprecated - use MyCaffeTrainerDual instead.) The MyCaffeTrainerRNN is used to perform recurrent ne...
void Initialize(string strProperties, IXMyCaffeCustomTrainerCallback icallback)
Initializes a new custom trainer by loading the key-value pair of properties into the property set.
bool IsTestingSupported
Returns whether or not Testing is supported.
double GetProperty(string strProp)
Returns a specific property value.
void CleanUp()
Releases any resources used by the component.
void OnShutdown()
The OnShutdown callback fires when shutting down the trainer.
virtual void dispose()
Override to dispose of resources used.
void OnGetData(GetDataArgs e)
The OnGetData callback fires from within the Train method and is used to get a new observation data.
byte[] Run(Component mycaffe, int nN, out string type)
Run the network using the run technique implemented by this trainer.
MyCaffeTrainerRNN(IContainer container)
The constructor.
void OpenUi()
Open the user interface for the trainer, if one exists.
void OnConvertOutput(ConvertOutputArgs e)
The OnConvertOutput callback fires from within the Run method and is used to convert the network outp...
void OnInitialize(InitializeArgs e)
The OnInitialize callback fires when initializing the trainer.
string ResizeModel(Log log, string strModel, BucketCollection rgVocabulary)
The ResizeModel method gives the custom trainer the opportunity to resize the model if needed.
void OnWait(WaitArgs e)
The OnWait callback fires when waiting for a shutdown.
void Train(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION, TRAIN_STEP step=TRAIN_STEP.NONE)
Create a new trainer and use it to run a training cycle.
TRAINING_CATEGORY TrainingCategory
Returns the training category of the custom trainer (default = RECURRENT).
void OnUpdateStatus(GetStatusArgs e)
The OnUpdateStatus callback fires on each iteration within the Train method.
virtual void testAccuracyUpdate(TestAccuracyUpdateArgs e)
Override called by the OnTestAccuracyUpdate event fired from within the Run method and is used to giv...
float[] Run(Component mycaffe, int nN)
Create a new trainer and use it to run a single run cycle.
bool GetUpdateSnapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
virtual DatasetDescriptor get_dataset_override(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
virtual BucketCollection preloaddata(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride=null, ConnectInfo ci=null)
The preloaddata method gives the custom trainer an opportunity to pre-load any data.
virtual string get_information()
Returns information describing the specific trainer, such as the gym used, if any.
string Information
Returns information describing the trainer.
virtual TRAINING_CATEGORY category
Override when using a training method other than the RECURRENT method (the default).
void OnTestAccuracyUpdate(TestAccuracyUpdateArgs e)
The OnTestAccuracyUpdate callback fires from within the Run method and is used to give the recipient ...
PropertySet m_properties
Specifies the properties parsed from the key-value pair passed to the Initialize method.
virtual void openUi()
Called by OpenUi, override this when a UI (via WCF) should be displayed.
DatasetDescriptor GetDatasetOverride(int nProjectID, ConnectInfo ci=null)
Returns a dataset override to use (if any) instead of the project's dataset. If there is no dataset o...
virtual void initialize(InitializeArgs e)
Override called by the Initialize method of the trainer.
virtual bool getData(GetDataArgs e)
Override called by the OnGetData event fired by the Trainer to retrieve a new set of observation coll...
virtual bool get_update_snapshot(out int nIteration, out double dfAccuracy)
Returns true when the training is ready for a snap-shot, false otherwise.
virtual IxTrainerRNN create_trainerF(Component caffe)
Optionally overridden to return a new type of trainer.
virtual IxTrainerRNN create_trainerD(Component caffe)
Optionally overridden to return a new type of trainer.
virtual void shutdown()
Override called from within the CleanUp method.
void Test(Component mycaffe, int nIterationOverride, ITERATOR_TYPE type=ITERATOR_TYPE.ITERATION)
Create a new trainer and use it to run a test cycle.
virtual string name
Overridden to give the actual name of the custom trainer.
ConnectInfo m_dsCi
Optionally, specifies the dataset connection info, or null.
int m_nProjectID
Specifies the project ID of the project held by the instance of MyCaffe.
bool IsRunningSupported
Returns whether or not Running is supported.
string Name
Returns the name of the custom trainer. This method calls the 'name' override.
bool IsTrainingSupported
Returns whether or not Training is supported.
virtual bool convertOutput(ConvertOutputArgs e)
Override called by the OnConvertOutput event fired by the Trainer to convert the network output into ...
BucketCollection PreloadData(Log log, CancelEvent evtCancel, int nProjectID, PropertySet propertyOverride=null, ConnectInfo ci=null)
The PreloadData method gives the custom trainer an opportunity to pre-load any data.
The TestAccuracyUpdateArgs are passed to the OnTestAccuracyUpdate event.
Definition: EventArgs.cs:553
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
int Wait
Returns the amount of time to wait in milliseconds.
Definition: EventArgs.cs:81
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
Definition: Component.cs:18
The IXMyCaffeCustomTrainerCallback interface is used to call back to the parent running the custom tr...
Definition: Interfaces.cs:199
void Update(TRAINING_CATEGORY cat, Dictionary< string, double > rgValues)
The Update method updates the parent with the global iteration, reward and loss.
The IXMyCaffeCustomTrainerCallbackRNN interface is used to call back to the parent running the custom...
Definition: Interfaces.cs:212
PropertySet GetRunProperties()
The GetRunProperties method is used to query the properties used when Running, if any.
The IXMyCaffeCustomTrainer interface is used by the MyCaffeCustomTrainer components that provide va...
Definition: Interfaces.cs:158
The IxTrainerCallbackRNN provides functions used by each trainer to 'call-back' to the parent for inf...
Definition: Interfaces.cs:348
bool Initialize()
Initialize the trainer.
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network.
bool Test(int nN, ITERATOR_TYPE type)
Test the network.
bool Shutdown(int nWait)
Shutdown the trainer.
The IxTrainerRL interface is implemented by each RL Trainer.
Definition: Interfaces.cs:279
float[] Run(int nN, PropertySet runProp)
Run a number of 'nN' samples on the trainer.
The descriptors namespace contains all descriptor used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
TRAINING_CATEGORY
Defines the category of training.
Definition: Interfaces.cs:34
Stage
Specifies the stage under which to run a custom trainer.
Definition: Interfaces.cs:88
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.gym namespace contains all classes related to the Gym's supported by MyCaffe.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.trainers namespace contains all reinforcement and recurrent learning trainers.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12