MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
SGDSolver.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.IO;
using MyCaffe.basecode;
using MyCaffe.db.image;
using MyCaffe.common;
using MyCaffe.param;

namespace MyCaffe.solvers
{
    /// <summary>
    /// Stochastic Gradient Descent solver with momentum that updates the weights using a linear
    /// combination of the negative gradient and the previous weight update.
    /// </summary>
    /// <typeparam name="T">Specifies the base type (float or double).</typeparam>
    public class SGDSolver<T> : Solver<T>
    {
        /// <summary>
        /// History maintains the historical momentum data.
        /// </summary>
        protected BlobCollection<T> m_colHistory = new BlobCollection<T>();
        // protected BlobCollection<T> m_colUpdate = new BlobCollection<T>(); // not used in GPU version
        /// <summary>
        /// Temp maintains update related data and is not needed in snapshots.
        /// </summary>
        protected BlobCollection<T> m_colTemp = new BlobCollection<T>();

        /// <summary>
        /// The SGDSolver constructor.
        /// </summary>
        public SGDSolver(CudaDnn<T> cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
            : base(cuda, log, p, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws)
        {
            PreSolve();
        }

        /// <summary>
        /// Releases all resources (GPU and Host) used by the Solver.
        /// </summary>
        protected override void dispose()
        {
            if (m_colHistory != null)
            {
                m_colHistory.Dispose();
                m_colHistory = null;
            }

            if (m_colTemp != null)
            {
                m_colTemp.Dispose();
                m_colTemp = null;
            }

            base.dispose();
        }

        /// <summary>
        /// Returns the history BlobCollection containing historical momentum data.
        /// </summary>
        public BlobCollection<T> history
        {
            get { return m_colHistory; }
        }

        /// <summary>
        /// Runs the pre-solve which prepares the Solver to start Solving.
        /// </summary>
        public void PreSolve()
        {
            BlobCollection<T> colNetParams = m_net.learnable_parameters;
            m_colHistory.Clear(true);
            // m_colUpdate.Clear(true);
            m_colTemp.Clear(true);

            for (int i = 0; i < colNetParams.Count; i++)
            {
                List<int> rgShape = colNetParams[i].shape();

                m_colHistory.Add(new Blob<T>(m_cuda, m_log, rgShape, false)); // diff never used
                // m_colUpdate.Add(new Blob<T>(m_cuda, m_log, rgShape, false));
                m_colTemp.Add(new Blob<T>(m_cuda, m_log, rgShape, false)); // diff never used
            }
        }

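        // Note (illustrative): PreSolve allocates one momentum 'history' blob and one
        // scratch 'temp' blob per learnable parameter, each shaped exactly like that
        // parameter. For a hypothetical net with a 32x3x5x5 conv kernel and a 32-element
        // bias, m_colHistory would hold two blobs shaped { 32, 3, 5, 5 } and { 32 },
        // giving every weight its own momentum slot that persists across iterations
        // (and, via the snapshot methods below, across restarts).
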
        /// <summary>
        /// Return the current learning rate.
        /// </summary>
        /// <remarks>
        /// The currently implemented learning rate policies are as follows:
        ///   - fixed: always return base_lr.
        ///   - step: return base_lr * gamma ^ floor(iter / stepsize)
        ///   - exp: return base_lr * gamma ^ iter
        ///   - inv: return base_lr * (1 + gamma * iter) ^ (-power)
        ///   - multistep: similar to step, but allows non-uniform steps defined by stepvalue.
        ///   - poly: polynomial decay reaching zero at max_iter:
        ///       return base_lr * (1 - iter / max_iter) ^ power
        ///   - sigmoid: sigmoidal decay:
        ///       return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
        /// where base_lr, gamma, stepsize, stepvalue, max_iter and power are defined in the
        /// SolverParameter, and iter is the current iteration.
        /// </remarks>
        /// <param name="nIterationOverride">Optionally, specifies an iteration override, or -1 (the default) to use the current iteration.</param>
        /// <returns>The learning rate is returned.</returns>
        public double GetLearningRate(int nIterationOverride = -1)
        {
            double dfRate = 0;

            if (nIterationOverride == -1)
                nIterationOverride = m_nIter;

            switch (m_param.lr_policy)
            {
                case "fixed":
                    dfRate = m_param.base_lr;
                    break;

                case "step":
                    m_log.CHECK_GT(m_param.stepsize, 0, "The stepsize must be greater than 0.");
                    m_nCurrentStep = nIterationOverride / m_param.stepsize;
                    m_log.CHECK_GE(m_param.gamma, 0, "The gamma must be greater than or equal to 0.");
                    dfRate = m_param.base_lr * Math.Pow(m_param.gamma, m_nCurrentStep);
                    break;

                case "exp":
                    m_log.CHECK_GE(m_param.gamma, 0, "The gamma must be greater than or equal to 0.");
                    dfRate = m_param.base_lr * Math.Pow(m_param.gamma, nIterationOverride);
                    break;

                case "inv":
                    m_log.CHECK_GE(m_param.gamma, 0, "The gamma must be greater than or equal to 0.");
                    dfRate = m_param.base_lr * Math.Pow(1.0 + m_param.gamma * nIterationOverride, -1.0 * m_param.power);
                    break;

                case "multistep":
                    if (m_nCurrentStep < m_param.stepvalue.Count && nIterationOverride >= m_param.stepvalue[m_nCurrentStep])
                    {
                        m_nCurrentStep++;
                        m_log.WriteLine("MultiStep Status: Iteration " + nIterationOverride.ToString() + ", step = " + m_nCurrentStep.ToString());
                    }
                    m_log.CHECK_GE(m_param.gamma, 0, "The gamma must be greater than or equal to 0.");
                    dfRate = m_param.base_lr * Math.Pow(m_param.gamma, m_nCurrentStep);
                    break;

                case "poly":
                    dfRate = m_param.base_lr * Math.Pow(1.0 - ((double)nIterationOverride / (double)m_param.max_iter), m_param.power);
                    break;

                case "sigmoid":
                    m_log.CHECK_GE(m_param.gamma, 0, "The gamma must be greater than or equal to 0.");
                    m_log.CHECK_GT(m_param.stepsize, 0, "The stepsize must be greater than 0.");
                    dfRate = m_param.base_lr * (1.0 / (1.0 + Math.Exp(-1.0 * m_param.gamma * (nIterationOverride - m_param.stepsize))));
                    break;

                default:
                    m_log.FAIL("Unknown learning rate policy: " + m_param.lr_policy);
                    break;
            }

            return dfRate;
        }

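        // Worked example (illustrative values, not defaults): with lr_policy = "step",
        // base_lr = 0.01, gamma = 0.1 and stepsize = 100000, the schedule is
        //   iter      0..99999  -> 0.01 * 0.1^0 = 0.01
        //   iter 100000..199999 -> 0.01 * 0.1^1 = 0.001
        //   iter 200000..299999 -> 0.01 * 0.1^2 = 0.0001
        // i.e. the rate drops by a factor of gamma every stepsize iterations.
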
        /// <summary>
        /// Compute the update values and apply them to the training Net.
        /// </summary>
        /// <param name="nIterationOverride">Optionally, specifies an iteration override, or -1 (the default) to use the current iteration.</param>
        /// <returns>The learning rate used is returned.</returns>
        public override double ApplyUpdate(int nIterationOverride = -1)
        {
            double dfRate = GetLearningRate(nIterationOverride);

            if (LearningRateOverride > 0)
                dfRate = LearningRateOverride;

            if (m_param.display > 0 && (m_nIter % m_param.display) == 0)
            {
                string strOut = "Iteration " + m_nIter.ToString() + ", lr = " + dfRate.ToString() + ", Loss = " + m_dfSmoothedLoss.ToString();
                if (m_dfIterAccuracy.HasValue)
                    strOut += ", Iter Accuracy = " + m_dfIterAccuracy.Value.ToString() + " (" + m_dfIterAccuracy.Value.ToString("P3") + ")";

                m_log.WriteLine(strOut);
            }

            ClipGradients();

            for (int i = 0; i < m_net.learnable_parameters.Count; i++)
            {
                Normalize(i);
                Regularize(i);
                ComputeUpdateValue(i, dfRate, nIterationOverride);
            }

            m_net.Update();

            // Increment the internal iter_ counter -- its value should always indicate
            // the number of times the weights have been updated.
            m_nIter++;

            return dfRate;
        }

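        // Per-iteration pipeline (a sketch of the order above, assuming the standard
        // Caffe-style training loop in which the base Solver runs forward/backward to
        // fill the parameter diffs before calling ApplyUpdate):
        //   1. ClipGradients      - rescale all diffs if their global L2 norm is too large.
        //   2. Normalize(i)       - divide each diff by iter_size when gradients accumulate.
        //   3. Regularize(i)      - add the weight-decay term to each diff.
        //   4. ComputeUpdateValue - fold each diff into the momentum history.
        //   5. m_net.Update()     - subtract each final diff from its parameter data.
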
        /// <summary>
        /// Restore the state of the Solver.
        /// </summary>
        /// <param name="rgState">Specifies the state bytes to restore.</param>
        protected override void RestoreSolverState(byte[] rgState)
        {
            SolverState state = m_persist.LoadSolverState(rgState);

            m_nIter = state.iter;
            m_nCurrentStep = state.current_step;

            m_log.CHECK_EQ(state.history.Count, m_colHistory.Count, "Incorrect length of state history blobs.");
            m_log.WriteLine("SGDSolver: restoring state history.");

            for (int i = 0; i < m_colHistory.Count; i++)
            {
                m_colHistory[i].FromProto(state.history[i]);
            }
        }

        /// <summary>
        /// Take a snapshot of the Solver state.
        /// </summary>
        /// <returns>The state bytes are returned.</returns>
        protected override byte[] SnapshotSolverState()
        {
            SolverState state = new SolverState();
            state.iter = m_nIter;
            state.current_step = m_nCurrentStep;

            foreach (Blob<T> blob in m_colHistory)
            {
                state.history.Add(blob.ToProto());
            }

            return m_persist.SaveSolverState(state);
        }

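        // Note: the snapshot intentionally includes the momentum history, not just the
        // iteration counters. If training resumed with a zeroed history, the first updates
        // after a restore would behave as if momentum were 0 until the history re-filled,
        // causing a small but avoidable jump in the effective step size.
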
        /// <summary>
        /// Normalize a learnable Blob of the training Net.
        /// </summary>
        /// <param name="param_id">Specifies the index of the learnable parameter Blob.</param>
        public virtual void Normalize(int param_id)
        {
            if (m_param.iter_size == 1)
                return;

            // Scale gradient to counterbalance accumulation.
            BlobCollection<T> colNetParams = m_net.learnable_parameters;

            if (!colNetParams[param_id].DiffExists)
                return;

            double dfAccumNormalization = 1.0 / m_param.iter_size;
            m_cuda.scal(colNetParams[param_id].count(), dfAccumNormalization, colNetParams[param_id].mutable_gpu_diff);
        }

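        // Example (illustrative): with iter_size = 4 and batch_size = 16, gradients from
        // four forward/backward passes (an effective batch of 64) accumulate in each diff,
        // so scal(count, 1/4, diff) rescales the accumulated sum back to the average
        // gradient before the update is computed.
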
        /// <summary>
        /// Regularize a learnable Blob of the training Net.
        /// </summary>
        /// <param name="param_id">Specifies the index of the learnable parameter Blob.</param>
        public virtual void Regularize(int param_id)
        {
            BlobCollection<T> colNetParams = m_net.learnable_parameters;

            if (!colNetParams[param_id].DiffExists)
                return;

            List<double?> rgNetParamWeightDecay = m_net.params_weight_decay;
            double dfWeightDecay = m_param.weight_decay;
            double dfLocalDecay = dfWeightDecay * rgNetParamWeightDecay[param_id].GetValueOrDefault(0);

            if (dfLocalDecay > 0)
            {
                switch (m_param.regularization_type)
                {
                    case "L2":
                        // add weight decay
                        m_cuda.axpy(colNetParams[param_id].count(), dfLocalDecay, colNetParams[param_id].gpu_data, colNetParams[param_id].mutable_gpu_diff);
                        break;

                    case "L1":
                        m_cuda.sign(colNetParams[param_id].count(), colNetParams[param_id].gpu_data, m_colTemp[param_id].mutable_gpu_data);
                        m_cuda.axpy(colNetParams[param_id].count(), dfLocalDecay, m_colTemp[param_id].gpu_data, colNetParams[param_id].mutable_gpu_diff);
                        break;
                }
            }
        }

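        // Regularization math (illustrative): both branches add the derivative of the
        // penalty term to the existing gradient via axpy(n, a, x, y) => y += a * x:
        //   L2: penalty (decay/2) * ||w||^2  ->  diff += decay * w
        //   L1: penalty  decay    * |w|      ->  diff += decay * sign(w)
        // For example, with weight_decay = 0.0005 and a per-parameter decay multiplier
        // of 1, L2 adds 0.0005 * w[i] to each gradient element.
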
        /// <summary>
        /// Compute the SGD update value that will be applied to a learnable Blob in the training Net.
        /// </summary>
        /// <param name="param_id">Specifies the index of the learnable parameter Blob.</param>
        /// <param name="dfRate">Specifies the learning rate.</param>
        /// <param name="nIterationOverride">Optionally, specifies an iteration override, or -1 (the default) to use the current iteration.</param>
        public virtual void ComputeUpdateValue(int param_id, double dfRate, int nIterationOverride = -1)
        {
            BlobCollection<T> colNetParams = m_net.learnable_parameters;

            if (!colNetParams[param_id].DiffExists)
                return;

            List<double?> net_params_lr = m_net.params_lr;
            T fMomentum = Utility.ConvertVal<T>(m_param.momentum);
            T fLocalRate = Utility.ConvertVal<T>(dfRate * net_params_lr[param_id].GetValueOrDefault(0));

            // Compute the update to history, then copy it to the parameter diff.
            if (m_colHistory != null)
                m_cuda.sgd_update(colNetParams[param_id].count(), colNetParams[param_id].mutable_gpu_diff, m_colHistory[param_id].mutable_gpu_data, fMomentum, fLocalRate);
        }

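        // Momentum update (a sketch of the classic Caffe SGD rule that sgd_update is
        // expected to implement; the kernel itself lives in the low-level CUDA DLL):
        //   history = momentum * history + local_rate * diff
        //   diff    = history
        // so that the subsequent m_net.Update() applies
        //   W(t+1) = W(t) - (momentum * V(t) + local_rate * grad)
        // i.e. a linear combination of the negative gradient and the previous weight update,
        // matching the class description above.
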
        /// <summary>
        /// Clip the gradients of all learnable Blobs in the training Net.
        /// </summary>
        public virtual void ClipGradients()
        {
            double dfClipGradients = m_param.clip_gradients;

            if (dfClipGradients < 0)
                return;

            BlobCollection<T> colNetParams = m_net.learnable_parameters;
            double dfSumsqDiff = 0;

            for (int i = 0; i < colNetParams.Count; i++)
            {
                if (colNetParams[i].DiffExists)
                    dfSumsqDiff += Utility.ConvertVal<T>(colNetParams[i].sumsq_diff());
            }

            double dfL2NormDiff = Math.Sqrt(dfSumsqDiff);

            if (dfL2NormDiff > dfClipGradients)
            {
                double dfScaleFactor = dfClipGradients / dfL2NormDiff;

                if (m_param.enable_clip_gradient_status)
                    m_log.WriteLine("Gradient clipping: scaling down gradients (L2 norm " + dfL2NormDiff.ToString() + " > " + dfClipGradients.ToString() + ") by scale factor " + dfScaleFactor.ToString());

                for (int i = 0; i < colNetParams.Count; i++)
                {
                    if (colNetParams[i].DiffExists)
                        colNetParams[i].scale_diff(Utility.ConvertVal<T>(dfScaleFactor));
                }
            }
        }
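
        // Clipping arithmetic (illustrative): with clip_gradients = 10 and a global
        // gradient L2 norm of 20, every diff is scaled by 10 / 20 = 0.5, so the update
        // keeps its direction but its overall magnitude is capped at the clip threshold.
        // Note the norm is computed over ALL learnable parameters jointly, not per-blob.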
    }
}