MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
TrainerRNNSimple.cs
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Drawing;
5using System.Linq;
6using System.Runtime.InteropServices;
7using System.Text;
8using System.Threading;
9using System.Threading.Tasks;
10using MyCaffe.basecode;
11using MyCaffe.common;
12using MyCaffe.fillers;
13using MyCaffe.gym;
14using MyCaffe.layers;
15using MyCaffe.param;
16using MyCaffe.solvers;
17
19{
/// <summary>
/// The TrainerRNNSimple is a minimal trainer that feeds sliding windows of the data history returned by
/// the trainer callback into the training net and, when training, back-propagates and steps the solver
/// once per window.
/// </summary>
/// <typeparam name="T">Specifies the base type used by the MyCaffeControl, Net, Solver and Blob objects.</typeparam>
public class TrainerRNNSimple<T> : IxTrainerRNN, IDisposable
{
    IxTrainerCallback m_icallback;
    MyCaffeControl<T> m_mycaffe;
    PropertySet m_properties;
    CryptoRandom m_random;
    BucketCollection m_rgVocabulary = null;
    bool m_bUsePreloadData = true;
    // Created lazily on the first call to run() and re-used by every later Train/Test call.
    GetDataArgs m_getDataTrainArgs = null;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="mycaffe">Specifies the MyCaffeControl that supplies the net, solver, log and cancel event.</param>
    /// <param name="properties">Specifies the trainer properties; 'BlobNames' is read when running and 'UsePreLoadData' (default = true) is read here.</param>
    /// <param name="random">Specifies the random number generator (stored, not used by the code in this class).</param>
    /// <param name="icallback">Specifies the callback used for initialization, data retrieval, waiting, status updates and shutdown.</param>
    /// <param name="rgVocabulary">Specifies the vocabulary (stored, not used by the code in this class).</param>
    public TrainerRNNSimple(MyCaffeControl<T> mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback, BucketCollection rgVocabulary)
    {
        m_icallback = icallback;
        m_mycaffe = mycaffe;
        m_properties = properties;
        m_random = random;
        m_rgVocabulary = rgVocabulary;
        m_bUsePreloadData = properties.GetPropertyAsBool("UsePreLoadData", true);
    }

    /// <summary>
    /// Releases all resources used. This trainer holds nothing that needs releasing, so this is a no-op.
    /// </summary>
    public void Dispose()
    {
    }

    /// <summary>
    /// Initialize the trainer by resetting the cancel event and notifying the callback.
    /// </summary>
    /// <returns>Always returns <i>true</i>.</returns>
    public bool Initialize()
    {
        m_mycaffe.CancelEvent.Reset();
        m_icallback.OnInitialize(new InitializeArgs(m_mycaffe));
        return true;
    }

    /// <summary>
    /// Ask the callback to wait in 250 unit increments until at least nWait has been requested.
    /// </summary>
    /// <param name="nWait">Specifies the total amount to wait; rounded up to the next multiple of 250.</param>
    private void wait(int nWait)
    {
        int nWaitInc = 250;
        int nTotalWait = 0;

        while (nTotalWait < nWait)
        {
            m_icallback.OnWait(new WaitArgs(nWaitInc));
            nTotalWait += nWaitInc;
        }
    }

    /// <summary>
    /// Shutdown the trainer: signal the cancel event, wait, then notify the callback.
    /// </summary>
    /// <param name="nWait">Specifies the amount to wait after signalling the cancel event.</param>
    /// <returns>Always returns <i>true</i>.</returns>
    public bool Shutdown(int nWait)
    {
        if (m_mycaffe != null)
        {
            m_mycaffe.CancelEvent.Set();
            wait(nWait);
        }

        m_icallback.OnShutdown();

        return true;
    }

    /// <summary>
    /// Send the current progress to the callback.
    /// </summary>
    /// <param name="nIteration">Specifies the current iteration.</param>
    /// <param name="nMaxIteration">Specifies the total number of iterations requested.</param>
    /// <param name="dfAccuracy">Specifies the rolling accuracy.</param>
    /// <param name="dfLoss">Specifies the loss of the last forward pass (0 when no pass was run).</param>
    /// <param name="dfLearningRate">Specifies the learning rate to report.</param>
    private void updateStatus(int nIteration, int nMaxIteration, double dfAccuracy, double dfLoss, double dfLearningRate)
    {
        GetStatusArgs args = new GetStatusArgs(0, nIteration, nIteration, nMaxIteration, dfAccuracy, 0, 0, 0, dfLoss, dfLearningRate);
        m_icallback.OnUpdateStatus(args);
    }

    /// <summary>
    /// Compute the fraction of (target, predicted) pairs whose absolute difference is below the threshold.
    /// </summary>
    /// <param name="rg">Specifies the (target, predicted) pairs.</param>
    /// <param name="fThreshold">Specifies the exclusive upper bound on the absolute difference that counts as a match.</param>
    /// <returns>The fraction of matching pairs in the range [0,1], or 0 when there are no pairs.</returns>
    private float computeAccuracy(List<Tuple<float, float>> rg, float fThreshold)
    {
        // Guard the division below: 0/0 in floating point silently yields NaN.
        if (rg == null || rg.Count == 0)
            return 0;

        int nMatch = 0;

        foreach (Tuple<float, float> item in rg)
        {
            if (Math.Abs(item.Item1 - item.Item2) < fThreshold)
                nMatch++;
        }

        return (float)nMatch / (float)rg.Count;
    }

    /// <summary>
    /// Not implemented by this trainer; only resets the cancel event.
    /// </summary>
    /// <param name="nN">Not used.</param>
    /// <param name="runProp">Not used.</param>
    /// <returns>Always returns <i>null</i>.</returns>
    public float[] Run(int nN, PropertySet runProp)
    {
        m_mycaffe.CancelEvent.Reset();
        return null;
    }

    /// <summary>
    /// Not implemented by this trainer; only resets the cancel event.
    /// </summary>
    /// <param name="nN">Not used.</param>
    /// <param name="runProp">Not used.</param>
    /// <param name="type">Always set to the empty string.</param>
    /// <returns>Always returns <i>null</i>.</returns>
    public byte[] Run(int nN, PropertySet runProp, out string type)
    {
        m_mycaffe.CancelEvent.Reset();
        type = "";
        return null;
    }

    /// <summary>
    /// Run the test cycle: the same loop as Train, but with forward passes only (no back-propagation, no solver step).
    /// </summary>
    /// <param name="nN">Specifies the number of iterations to run.</param>
    /// <param name="type">Specifies the iterator type (not used by this trainer).</param>
    /// <returns>Always returns <i>false</i>.</returns>
    public bool Test(int nN, ITERATOR_TYPE type)
    {
        return run(nN, type, TRAIN_STEP.NONE, Phase.TEST);
    }

    /// <summary>
    /// Train the network: each iteration runs a forward pass, a backward pass and a single solver step.
    /// </summary>
    /// <param name="nN">Specifies the number of iterations to run.</param>
    /// <param name="type">Specifies the iterator type (not used by this trainer).</param>
    /// <param name="step">Specifies the stepping mode (not used by this trainer).</param>
    /// <returns>Always returns <i>false</i>.</returns>
    public bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
    {
        return run(nN, type, step, Phase.TRAIN);
    }

    /// <summary>
    /// Parse the 'BlobNames' property value of the form 'key~blobname|key~blobname|...'.
    /// </summary>
    /// <param name="strBlobNames">Specifies the raw property value.</param>
    /// <returns>A map from logical key (e.g. 'x') to the name of the blob within the net.</returns>
    private static Dictionary<string, string> parseBlobNames(string strBlobNames)
    {
        Dictionary<string, string> rgMap = new Dictionary<string, string>();

        foreach (string strEntry in strBlobNames.Split('|'))
        {
            string[] rgstrParts = strEntry.Split('~');
            if (rgstrParts.Length != 2)
                throw new Exception("Invalid BlobNames property, expected 'name~blobname' entries separated by '|'!");

            rgMap.Add(rgstrParts[0], rgstrParts[1]);
        }

        return rgMap;
    }

    /// <summary>
    /// Look up the blob mapped to a logical key.
    /// </summary>
    /// <param name="net">Specifies the net to search.</param>
    /// <param name="rgNames">Specifies the key to blob-name map.</param>
    /// <param name="strKey">Specifies the logical key.</param>
    /// <returns>The blob, or <i>null</i> when the key is not present in the map.</returns>
    private static Blob<T> findBlob(Net<T> net, Dictionary<string, string> rgNames, string strKey)
    {
        string strBlobName;

        if (!rgNames.TryGetValue(strKey, out strBlobName))
            return null;

        return net.blob_by_name(strBlobName);
    }

    /// <summary>
    /// Throw when a required blob was not resolved.
    /// </summary>
    /// <param name="blob">Specifies the blob returned by findBlob.</param>
    /// <param name="strKey">Specifies the logical key used in the error message.</param>
    private static void verifyFound(Blob<T> blob, string strKey)
    {
        if (blob == null)
            throw new Exception("The '" + strKey + "' blob was not found in the 'BlobNames' property!");
    }

    /// <summary>
    /// Throw when a blob does not hold the expected number of items.
    /// </summary>
    /// <param name="blob">Specifies the blob to check.</param>
    /// <param name="strKey">Specifies the logical key used in the error message.</param>
    /// <param name="nExpected">Specifies the expected item count.</param>
    private static void verifyCount(Blob<T> blob, string strKey, int nExpected)
    {
        if (blob.count() != nExpected)
            throw new Exception("The '" + strKey + "' blob must have a count of '" + nExpected.ToString() + "'!");
    }

    /// <summary>
    /// Shared implementation of Train and Test.
    /// </summary>
    /// <remarks>
    /// The training net (Phase.TRAIN) and the shared GetDataArgs are used for both phases; the phase only
    /// decides whether the backward pass and solver step are run.
    /// </remarks>
    /// <param name="nN">Specifies the number of iterations to run.</param>
    /// <param name="type">Specifies the iterator type (not used).</param>
    /// <param name="step">Specifies the stepping mode (not used).</param>
    /// <param name="phase">Specifies the phase; back-propagation and the solver step only occur for Phase.TRAIN.</param>
    /// <returns>Always returns <i>false</i>.</returns>
    private bool run(int nN, ITERATOR_TYPE type, TRAIN_STEP step, Phase phase)
    {
        m_mycaffe.CancelEvent.Reset();

        if (m_getDataTrainArgs == null)
        {
            PropertySet propInit = new PropertySet();
            propInit.SetProperty("TrainingStart", "0");

            m_getDataTrainArgs = new GetDataArgs(Phase.TRAIN, 0, m_mycaffe, m_mycaffe.Log, m_mycaffe.CancelEvent, true, -1);
            m_getDataTrainArgs.ExtraProperties = propInit;
            m_icallback.OnGetData(m_getDataTrainArgs);
        }

        // Always write the prediction override into the property set that is attached to the shared
        // GetDataArgs. Writing into a per-call PropertySet would only reach the callback on the very
        // first call; later calls would leave the callback looking at a stale prediction.
        PropertySet prop = m_getDataTrainArgs.ExtraProperties;
        if (prop == null)
        {
            prop = new PropertySet();
            m_getDataTrainArgs.ExtraProperties = prop;
        }

        m_getDataTrainArgs.Action = 0;
        m_getDataTrainArgs.Reset = false;

        Net<T> net = m_mycaffe.GetInternalNet(Phase.TRAIN);
        Solver<T> solver = m_mycaffe.GetInternalSolver();

        InputLayer<T> input = net.layers[0] as InputLayer<T>;
        if (input == null)
            throw new Exception("Missing expected input layer!");

        int nBatchSize = input.layer_param.input_param.shape[0].dim[0];
        if (nBatchSize != 1)
            throw new Exception("Expected batch size of 1!");

        // The window length comes from the first input shape, the output size from the fourth.
        int nInputDim = input.layer_param.input_param.shape[0].dim[1];
        int nOutputDim = input.layer_param.input_param.shape[3].dim[1];

        Dictionary<string, string> rgBlobNames = parseBlobNames(m_properties.GetProperty("BlobNames"));

        // Resolve everything first, then validate, so that failures surface in the same order as before.
        Blob<T> blobX = findBlob(net, rgBlobNames, "x");
        Blob<T> blobTt = findBlob(net, rgBlobNames, "tt");
        Blob<T> blobMask = findBlob(net, rgBlobNames, "mask");
        Blob<T> blobTarget = findBlob(net, rgBlobNames, "target");
        Blob<T> blobXhat = findBlob(net, rgBlobNames, "xhat");

        verifyFound(blobX, "x");
        verifyFound(blobTt, "tt");
        verifyFound(blobMask, "mask");
        verifyFound(blobTarget, "target");
        verifyFound(blobXhat, "xhat");

        verifyCount(blobX, "x", nInputDim);
        verifyCount(blobTt, "tt", nInputDim);
        verifyCount(blobMask, "mask", nInputDim);
        verifyCount(blobTarget, "target", nOutputDim);
        verifyCount(blobXhat, "xhat", nOutputDim);

        float[] rgInput = new float[nInputDim];
        float[] rgTimeSteps = new float[nInputDim];
        float[] rgMask = new float[nInputDim];
        float[] rgTarget = new float[nOutputDim];

        // Rolling window of the most recent (target, predicted) pairs used for the reported accuracy.
        List<Tuple<float, float>> rgAccHistory = new List<Tuple<float, float>>();

        for (int i = 0; i < nN; i++)
        {
            double dfLoss = 0;
            double dfAccuracy = 0;
            float fPredictedY = 0;
            float fTargetY = 0;

            m_icallback.OnGetData(m_getDataTrainArgs);

            if (m_getDataTrainArgs.CancelEvent.WaitOne(0))
                break;

            if (m_mycaffe.CancelEvent.WaitOne(0))
                break;

            List<DataPoint> rgHistory = m_getDataTrainArgs.State.History;
            DataPoint dpLast = (rgHistory.Count > 0) ? rgHistory[rgHistory.Count - 1] : null;

            fTargetY = (dpLast != null) ? dpLast.Target : -1;

            if (rgHistory.Count >= nInputDim)
            {
                // Copy the newest nInputDim history points into the input buffers, oldest first.
                int nStart = rgHistory.Count - nInputDim;

                for (int j = 0; j < nInputDim; j++)
                {
                    DataPoint dp = rgHistory[nStart + j];
                    rgInput[j] = dp.Inputs[0];
                    rgTimeSteps[j] = dp.Time;
                    rgMask[j] = dp.Mask[0];
                    // Only the first target slot is written; after the loop it holds the newest point's target.
                    rgTarget[0] = dp.Target;
                }

                blobX.mutable_cpu_data = Utility.ConvertVec<T>(rgInput);
                blobTt.mutable_cpu_data = Utility.ConvertVec<T>(rgTimeSteps);
                blobMask.mutable_cpu_data = Utility.ConvertVec<T>(rgMask);
                blobTarget.mutable_cpu_data = Utility.ConvertVec<T>(rgTarget);

                net.Forward(out dfLoss);

                if (phase == Phase.TRAIN)
                {
                    net.Backward();
                    solver.Step(1);
                }

                float[] rgOutput = Utility.ConvertVecF<T>(blobXhat.mutable_cpu_data);
                fPredictedY = rgOutput[0];

                prop.SetProperty("override_prediction", fPredictedY.ToString());
            }
            else
            {
                // Not enough history yet to fill a window; pause briefly before asking for more data.
                Thread.Sleep(50);
            }

            rgAccHistory.Add(new Tuple<float, float>(fTargetY, fPredictedY));
            if (rgAccHistory.Count > 100)
                rgAccHistory.RemoveAt(0);

            dfAccuracy = computeAccuracy(rgAccHistory, 0.005f);

            updateStatus(i, nN, dfAccuracy, dfLoss, solver.parameter.base_lr);
        }

        return false;
    }
}
342}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
CancelEvent CancelEvent
Returns the CancelEvent used.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
Solver< T > GetInternalSolver()
Get the internal solver.
Log Log
Returns the Log (for output) used.
The BucketCollection contains a set of Buckets.
void Reset()
Resets the event clearing any signaled state.
Definition: CancelEvent.cs:279
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
void Set()
Sets the event to the signaled state.
Definition: CancelEvent.cs:270
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Definition: CryptoRandom.cs:14
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Definition: PropertySet.cs:267
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
Definition: PropertySet.cs:211
The Utility class provides general utility functions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
Definition: Blob.cs:1461
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
List< Layer< T > > layers
Returns the layers.
Definition: Net.cs:2003
BlobCollection< T > Forward()
Run forward with the input Blobs already fed separately.
Definition: Net.cs:1445
void Backward(int nStart=int.MaxValue, int nEnd=0)
The network backward should take no input and output, since it solely computes the gradient w....
Definition: Net.cs:1499
Blob< T > blob_by_name(string strName, bool bThrowExceptionOnError=true)
Returns a blob given its name.
Definition: Net.cs:2245
The DataPoint contains the data used when training.
Definition: Interfaces.cs:236
float Target
Returns the target value.
Definition: Interfaces.cs:287
The InputLayer provides data to the Net by assigning top Blobs directly. This layer is initialized wi...
Definition: InputLayer.cs:22
LayerParameter layer_param
Returns the LayerParameter for this Layer.
Definition: Layer.cs:899
List< BlobShape > shape
Define N shapes to set a shape for each top. Define 1 shape to set the same shape for every top....
InputParameter input_param
Returns the parameter set when initialized with LayerType.INPUT
double base_lr
The base learning rate (default = 0.01).
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
SolverParameter parameter
Returns the SolverParameter used.
Definition: Solver.cs:1221
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
Definition: Solver.cs:818
The GetDataArgs is passed to the OnGetData event to retrieve data.
Definition: EventArgs.cs:402
CancelEvent CancelEvent
Returns the cancel event.
Definition: EventArgs.cs:509
PropertySet ExtraProperties
Get/set extra properties.
Definition: EventArgs.cs:458
bool Reset
Returns whether or not to reset the observation environment or not.
Definition: EventArgs.cs:543
int Action
Returns the action to run. If less than zero, this parameter is ignored.
Definition: EventArgs.cs:526
StateBase State
Specifies the state data of the observations.
Definition: EventArgs.cs:517
The GetStatusArgs is passed to the OnGetStatus event.
Definition: EventArgs.cs:166
The InitializeArgs is passed to the OnInitialize event.
Definition: EventArgs.cs:90
List< DataPoint > History
Get/set the data history (if any exists).
Definition: StateBase.cs:81
The WaitArgs is passed to the OnWait event.
Definition: EventArgs.cs:65
The TrainerRNNSimple implements a very simple RNN trainer inspired by adepierre's GitHub site referen...
bool Train(int nN, ITERATOR_TYPE type, TRAIN_STEP step)
Train the network using a modified PG training algorithm optimized for GPU use.
TrainerRNNSimple(MyCaffeControl< T > mycaffe, PropertySet properties, CryptoRandom random, IxTrainerCallback icallback, BucketCollection rgVocabulary)
The constructor.
byte[] Run(int nN, PropertySet runProp, out string type)
Run a single cycle on the environment after the delay.
void Dispose()
Releases all resources used.
bool Initialize()
Initialize the trainer.
bool Test(int nN, ITERATOR_TYPE type)
Run the test cycle - runs the same loop as Train but with forward passes only (no backward pass or solver step).
float[] Run(int nN, PropertySet runProp)
Run a single cycle on the environment after the delay.
bool Shutdown(int nWait)
Shutdown the trainer.
The IxTrainerCallback provides functions used by each trainer to 'call-back' to the parent for inform...
Definition: Interfaces.cs:303
The IxTrainerRL interface is implemented by each RL Trainer.
Definition: Interfaces.cs:279
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.fillers namespace contains all fillers including the Filler class.
The MyCaffe.gym namespace contains all classes related to the Gym's supported by MyCaffe.
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
ITERATOR_TYPE
Specifies the iterator type to use.
Definition: Interfaces.cs:22
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12