MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
AdamWSolver.cs
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading;
6using System.IO;
7using MyCaffe.basecode;
8using MyCaffe.db.image;
9using MyCaffe.common;
10using MyCaffe.param;
11
namespace MyCaffe.solvers
{
    /// <summary>
    /// Use the AdamW solver, which uses gradient-based optimization like Adam with decoupled
    /// (detached) weight decay.
    /// </summary>
    /// <remarks>
    /// History layout: entries [0, N) of m_colHistory (allocated by SGDSolver's pre-solve) are used
    /// as the first-moment estimates 'm'. Entries [N, 2N) (allocated by AdamPreSolve) are used as the
    /// second-moment estimates 'v', where N is the number of learnable parameters.
    /// </remarks>
    /// <typeparam name="T">Specifies the numeric base type used by the solver (presumably float or double).</typeparam>
    public class AdamWSolver<T> : SGDSolver<T>
    {
        /// <summary>
        /// The decoupled weight decay rate, read from SolverParameter.adamw_decay in the constructor.
        /// </summary>
        private readonly double m_dfDetachedWeightDecayRate;

        /// <summary>
        /// The AdamWSolver constructor.
        /// </summary>
        /// <remarks>
        /// All arguments are forwarded to the SGDSolver base constructor. Afterwards the decoupled
        /// decay rate is read from <paramref name="p"/> and the extra Adam history is allocated.
        /// </remarks>
        /// <param name="cuda">Specifies the instance of CudaDnn to use.</param>
        /// <param name="log">Specifies the Log for output.</param>
        /// <param name="p">Specifies the SolverParameter (its adamw_decay supplies the decoupled decay rate).</param>
        /// <param name="evtCancel">Specifies a CancelEvent used to cancel training.</param>
        /// <param name="evtForceSnapshot">Specifies an event that presumably forces a snapshot when set.</param>
        /// <param name="evtForceTest">Specifies an event that presumably forces a test cycle when set.</param>
        /// <param name="db">Specifies the in-memory database.</param>
        /// <param name="persist">Specifies the persistence object used to load and save weights.</param>
        /// <param name="nSolverCount">Optionally, specifies the number of solvers (default = 1).</param>
        /// <param name="nSolverRank">Optionally, specifies the rank of this solver (default = 0).</param>
        /// <param name="shareNet">Optionally, specifies a Net to share (default = null).</param>
        /// <param name="getws">Optionally, specifies the get-workspace callback (default = null).</param>
        /// <param name="setws">Optionally, specifies the set-workspace callback (default = null).</param>
        public AdamWSolver(CudaDnn<T> cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
            : base(cuda, log, p, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws)
        {
            m_dfDetachedWeightDecayRate = p.adamw_decay;

            // ComputeUpdateValue reads m_colHistory[param_id + N] (the 'v' estimates), so the
            // second set of history blobs must exist before the first update is computed.
            AdamPreSolve();
        }

        /// <summary>
        /// Runs the AdamW pre-solve, which prepares the solver to start solving by appending one
        /// history blob per learnable parameter after the entries added by SGDSolver's pre-solve.
        /// </summary>
        public virtual void AdamPreSolve()
        {
            // Add the extra history entries for Adam after those from
            // SGDSolver::PreSolve; they become entries [N, 2N) of m_colHistory.
            BlobCollection<T> colNetParams = m_net.learnable_parameters;

            for (int i = 0; i < colNetParams.Count; i++)
            {
                List<int> rgShape = colNetParams[i].shape();
                Blob<T> blob = new Blob<T>(m_cuda, m_log, rgShape);
                m_colHistory.Add(blob);
            }
        }

        /// <summary>
        /// Computes the AdamW update for one learnable blob in the training Net.
        /// </summary>
        /// <param name="param_id">Specifies the index of the learnable parameter to update.</param>
        /// <param name="dfRate">Specifies the global learning rate, scaled by the parameter's lr multiplier.</param>
        /// <param name="nIterationOverride">Optionally, specifies the iteration to use; when -1 the current iteration (m_nIter) is used.</param>
        public override void ComputeUpdateValue(int param_id, double dfRate, int nIterationOverride = -1)
        {
            BlobCollection<T> colNetParams = m_net.learnable_parameters;
            Blob<T> blobParam = colNetParams[param_id];

            // Parameters without a diff have nothing to update.
            if (!blobParam.DiffExists)
            {
                return;
            }

            if (nIterationOverride == -1)
            {
                nIterationOverride = m_nIter;
            }

            // Per-parameter multipliers scale both the learning rate and the decoupled decay;
            // a missing multiplier is treated as 0.
            double dfLocalRate = dfRate * m_net.params_lr[param_id].GetValueOrDefault(0);
            double dfLocalDecay = m_dfDetachedWeightDecayRate * m_net.params_weight_decay[param_id].GetValueOrDefault(0);

            T fBeta1 = Utility.ConvertVal<T>(m_param.momentum);
            T fBeta2 = Utility.ConvertVal<T>(m_param.momentum2);
            T fEpsHat = Utility.ConvertVal<T>(m_param.delta);

            // Aliases into the history: 'm' lives at [param_id], 'v' at [param_id + N].
            int nUpdateHistoryOffset = colNetParams.Count;
            Blob<T> val_m = m_colHistory[param_id];
            Blob<T> val_v = m_colHistory[param_id + nUpdateHistoryOffset];

            // 1-based time step, presumably used by the kernel for bias correction.
            int nT = nIterationOverride + 1;
            int nN = blobParam.count();

            // NOTE(review): the kernel receives the diff, both moment buffers and the parameter
            // data handle; confirm in CudaDnn.adamw_update exactly which buffers it writes.
            m_cuda.adamw_update(nN,
                blobParam.mutable_gpu_diff,
                val_m.mutable_gpu_data,
                val_v.mutable_gpu_data,
                fBeta1,
                fBeta2,
                fEpsHat,
                Utility.ConvertVal<T>(dfLocalRate),
                Utility.ConvertVal<T>(dfLocalDecay),
                blobParam.mutable_gpu_data,
                nT);
        }
    }
}
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
The Log class provides general output in text form.
Definition: Log.cs:13
The Utility class provides general utility functions.
Definition: Utility.cs:35
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1487
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:969
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter.
Definition: Net.cs:23
The SolverParameter is a parameter for the solver, specifying the train and test networks.
double delta
Numerical stability for RMSProp, AdaGrad, AdaDelta, Adam and AdamW solvers (default = 1e-08).
double momentum2
An additional momentum property for the Adam and AdamW solvers (default = 0.999).
double adamw_decay
Specifies the 'AdamW' detached weight decay value used by the 'AdamW' solver (default = 0....
double momentum
Specifies the momentum value - used by all solvers EXCEPT the 'AdaGrad' and 'RMSProp' solvers....
Use the AdamW solver, which uses gradient-based optimization like Adam with decoupled weight decay.
Definition: AdamWSolver.cs:23
AdamWSolver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The AdamWSolver constructor.
Definition: AdamWSolver.cs:42
override void ComputeUpdateValue(int param_id, double dfRate, int nIterationOverride=-1)
Computes the AdamWSolver update value that will be applied to a learnable blob in the training Net.
Definition: AdamWSolver.cs:72
virtual void AdamPreSolve()
Runs the AdamWSolver pre-solve, which prepares the solver to start solving.
Definition: AdamWSolver.cs:52
Stochastic Gradient Descent solver with momentum updates weights by a linear combination of the negat...
Definition: SGDSolver.cs:22
BlobCollection< T > m_colHistory
History maintains the historical momentum data.
Definition: SGDSolver.cs:26
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
Definition: Solver.cs:40
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
Definition: Solver.cs:32
Net< T > net
Returns the main training Net.
Definition: Solver.cs:1229
int m_nIter
Specifies the current iteration.
Definition: Solver.cs:52
Net< T > m_net
Specifies the training Net.
Definition: Solver.cs:44
Log m_log
Specifies the Log for output.
Definition: Solver.cs:36
The IXDatabaseBase interface defines the general interface to the in-memory database.
Definition: Interfaces.cs:444
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:187
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12