Deep learning software for Windows C# programmers.
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading;
6using System.IO;
7using System.Diagnostics;
8using System.Collections;
9using MyCaffe.basecode;
10using MyCaffe.db.image;
11using MyCaffe.common;
12using MyCaffe.param;
17namespace MyCaffe.solvers
27 public abstract class Solver<T> : IDisposable
28 {
32 protected CudaDnn<T> m_cuda;
36 protected Log m_log;
44 protected Net<T> m_net;
48 protected List<Net<T>> m_rgTestNets = new List<Net<T>>();
52 protected int m_nIter;
56 protected int m_nCurrentStep;
60 protected List<double> m_rgLosses = new List<double>();
61 AutoResetEvent m_evtCompleted = new AutoResetEvent(false);
62 bool m_bEnableTest = true;
63 bool m_bEnableBlobDebugging = false;
64 bool m_bEnableBreakOnNan = false;
65 bool m_bEnableDetailedNanDetection = false;
66 bool m_bEnableSingleStep = false;
70 protected double m_dfSmoothedLoss = 0;
74 protected double? m_dfIterAccuracy = null;
75 Blob<T> m_blobAccuracy = null;
76 CancelEvent m_evtCancel;
77 AutoResetEvent m_evtForceSnapshot;
78 AutoResetEvent m_evtForceTest;
82 protected int m_nSolverCount = 1;
86 protected int m_nSolverRank = 0;
94 protected double m_dfLearningRateOverride = 0;
95 double m_dfLastAccuracy = 0;
96 double m_dfLastError = double.MaxValue;
97 double m_dfBestAccuracy = 0;
98 double m_dfBestError = double.MaxValue;
99 IXDatabaseBase m_db = null;
100 int m_nTrainingIterationOverride = -1;
101 int m_nTestingIterationOverride = -1;
102 object m_tag = null;
103 bool m_bWeightsUpdated = false;
104 static object m_syncGetRi = new object();
105 Blob<T> m_blobBatchInputData = null;
106 double m_dfAverageTestTime = 0;
108 int m_nTrainingTimeLimitInMinutes = 0;
109 long m_hWorkspaceData = 0; // shared among the layers and nets, only grows in size.
110 ulong m_lWorkspaceSizeInBytes = 0;
111 bool m_bFirstNanError = true;
112 List<double> m_rgAverageAccuracyWindow = null;
113 bool m_bForceTest = false;
118 public event EventHandler OnStart;
122 public event EventHandler OnAborted;
126 public event EventHandler<GradientsReadyArgs> OnGradientsReady;
130 public event EventHandler<SnapshotArgs> OnSnapshot;
134 public event EventHandler<TrainingIterationArgs<T>> OnTrainingIteration;
138 public event EventHandler<TestingIterationArgs<T>> OnTestingIteration;
142 public event EventHandler<TestResultArgs<T>> OnTestResults;
146 public event EventHandler<TestArgs> OnTest;
150 public event EventHandler OnTestStart;
155 public event EventHandler<CustomForwardBackArgs<T>> OnCustomForwardBack;
159 public event EventHandler<WorkspaceArgs> OnGetWorkspace;
163 public event EventHandler<WorkspaceArgs> OnSetWorkspace;
/// <summary>
/// Constructs the solver: stores the CUDA connection, log, cancel/force events, database and
/// persistence object, optionally attaches external workspace handlers, pre-fills the accuracy
/// averaging window with zeros, and finally calls Init to build the nets.
/// </summary>
/// <param name="cuda">CUDA connection stored in m_cuda.</param>
/// <param name="log">Log stored in m_log and used for all output.</param>
/// <param name="p">Solver parameters; passed on to Init.</param>
/// <param name="evtCancel">Cancel event stored in m_evtCancel.</param>
/// <param name="evtForceSnapshot">Event polled by the forceSnapshot property (may be null).</param>
/// <param name="evtForceTest">Event polled by the forceTest property (may be null).</param>
/// <param name="db">Database stored in m_db and handed to the nets.</param>
/// <param name="persist">Persistence object stored in m_persist.</param>
/// <param name="nSolverCount">Number of solvers; copied into the training net parameter.</param>
/// <param name="nSolverRank">Rank of this solver; rank 0 is the root solver.</param>
/// <param name="shareNet">Optional net passed on to Init.</param>
/// <param name="getws">Optional handler attached to OnGetWorkspace.</param>
/// <param name="setws">Optional handler attached to OnSetWorkspace.</param>
181 public Solver(CudaDnn<T> cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
182 {
183 m_cuda = cuda;
184 m_log = log;
185 m_evtCancel = evtCancel;
186 m_evtForceSnapshot = evtForceSnapshot;
187 m_evtForceTest = evtForceTest;
189 if (m_log.IsEnabled)
// NOTE(review): listing lines 190-191 are absent, so the statement this 'if' originally
// guarded is not visible here. As printed, the 'if' would guard the 'm_db = db' assignment
// below, which is almost certainly not intended -- confirm against the original file.
192 m_db = db;
193 m_persist = persist;
194 m_nSolverCount = nSolverCount;
195 m_nSolverRank = nSolverRank;
// External workspace handlers, when supplied, take over from the solver's own workspace
// (see net_OnGetWorkspace / net_OnSetWorkspace, which forward to these events first).
197 if (getws != null)
198 OnGetWorkspace += new EventHandler<WorkspaceArgs>(getws);
200 if (setws != null)
201 OnSetWorkspace += new EventHandler<WorkspaceArgs>(setws);
// Pre-fill the averaging window with zeros so that TestAll can add one entry and remove
// the oldest on each call while keeping the window length constant.
203 if (p.accuracy_average_window > 0)
204 {
205 m_rgAverageAccuracyWindow = new List<double>();
206 for (int i = 0; i < p.accuracy_average_window; i++)
207 {
208 m_rgAverageAccuracyWindow.Add(0);
209 }
210 }
212 Init(p, shareNet);
213 }
/// <summary>
/// Releases the resources held by the solver by delegating to the overridable
/// <c>dispose</c> method.
/// </summary>
public void Dispose()
{
    this.dispose();
}
227 {
228 get { return m_dfLearningRateOverride; }
229 set { m_dfLearningRateOverride = value; }
230 }
237 {
238 int nTimingCount = 0;
239 double dfTotalTime = 0;
240 return fireOnTrainingIterationEvent(false, 0, 0, ref nTimingCount, ref dfTotalTime);
241 }
/// <summary>
/// Raises the OnTrainingIteration event on the root solver and, when blob debugging with
/// break-on-NaN is enabled, reports the first (and possibly last) blob containing a NaN.
/// </summary>
/// <param name="bFwdPassNanFree">True when the forward pass produced no NaN; only used to label the pass in the log.</param>
/// <param name="dfLoss">Loss of the current iteration, forwarded in the event arguments.</param>
/// <param name="dfLastLearningRate">Learning rate last applied, forwarded in the event arguments.</param>
/// <param name="nTimingCount">Number of timed passes; reset to 0 once the event has been raised.</param>
/// <param name="dfTotalTime">Accumulated pass time; reset to 0 once the event has been raised.</param>
/// <returns>False when a NaN blob was detected (training should stop), otherwise true.</returns>
private bool fireOnTrainingIterationEvent(bool bFwdPassNanFree, double dfLoss, double dfLastLearningRate, ref int nTimingCount, ref double dfTotalTime)
{
    // Nothing to report unless this is the root solver and somebody is listening.
    if (!is_root_solver || OnTrainingIteration == null)
        return true;

    DebugInformation<T> debugInfo = null;
    string strNanBlobFirst = null;

    if (m_bEnableBlobDebugging)
    {
        debugInfo = TrainingNet.GetDebugInformation(m_bEnableDetailedNanDetection);

        if (m_bEnableBreakOnNan && debugInfo != null)
        {
            string strNanTypeFirst;
            strNanBlobFirst = debugInfo.DetectFirstNaN(out strNanTypeFirst);

            if (strNanBlobFirst != null)
            {
                string strPassName = bFwdPassNanFree ? "Backward" : "Forward";
                m_log.WriteLine("First NaN detected in the '" + strNanTypeFirst + "' of blob '" + strNanBlobFirst + "' after " + strPassName + " pass.");

                string strNanTypeLast;
                string strNanBlobLast = debugInfo.DetectLastNaN(out strNanTypeLast);

                // The last NaN is only logged when both its blob and its type differ from the first.
                if (strNanBlobLast != strNanBlobFirst && strNanTypeFirst != strNanTypeLast)
                    m_log.WriteLine("Last NaN detected in the '" + strNanTypeLast + "' of blob '" + strNanBlobLast + "' after " + strPassName + " pass.");
            }
        }
    }

    double dfAverageTime = 0;
    if (nTimingCount > 0)
        dfAverageTime = dfTotalTime / nTimingCount;

    OnTrainingIteration(this, new TrainingIterationArgs<T>(m_nIter, m_dfLastAccuracy, dfLoss, m_dfSmoothedLoss, m_dfBestError, m_bWeightsUpdated, m_net.ActiveLabelCounts, m_net.LabelQueryHitPercents, m_net.LabelQueryEpochs, m_net.BoostQueryHitPercents, dfLastLearningRate, dfAverageTime, debugInfo));

    // The timing accumulators are consumed by the event and start over afterwards.
    dfTotalTime = 0;
    nTimingCount = 0;

    if (strNanBlobFirst == null)
        return true;

    m_log.WriteLine("Training is now stopping at iteration " + m_nIter.ToString("N0") + " as the first NaN has been detected ('" + strNanBlobFirst + "').");
    return false;
}
292 {
293 get { return m_nTrainingTimeLimitInMinutes; }
294 set { m_nTrainingTimeLimitInMinutes = value; }
295 }
301 {
302 get { return m_snapshotWeightUpdatemMethod; }
303 set { m_snapshotWeightUpdatemMethod = value; }
304 }
310 {
311 get { return m_db; }
312 }
/// <summary>
/// Frees everything the solver owns: the training net, all test nets, the batch input blob
/// and the shared workspace memory.
/// </summary>
protected virtual void dispose()
{
    Net<T> trainNet = m_net;
    if (trainNet != null)
    {
        trainNet.Dispose();
        m_net = null;
    }

    for (int i = 0; i < m_rgTestNets.Count; i++)
    {
        m_rgTestNets[i].Dispose();
    }

    m_rgTestNets.Clear();

    Blob<T> batchInput = m_blobBatchInputData;
    if (batchInput != null)
    {
        batchInput.Dispose();
        m_blobBatchInputData = null;
    }

    // A zero handle means no workspace was ever allocated by this solver.
    if (m_hWorkspaceData == 0)
        return;

    m_cuda.DisableGhostMemory();
    m_cuda.FreeMemory(m_hWorkspaceData);
    m_cuda.ResetGhostMemory();
    m_hWorkspaceData = 0;
    m_lWorkspaceSizeInBytes = 0;
}
/// <summary>
/// Gets or sets whether the testing passes run by Step and Solve (via TestAll) are enabled.
/// </summary>
public bool EnableTesting
{
    get
    {
        return m_bEnableTest;
    }
    set
    {
        m_bEnableTest = value;
    }
}
361 {
362 get { return m_bEnableBlobDebugging; }
363 set { m_bEnableBlobDebugging = value; }
364 }
373 {
374 get { return TrainingNet.EnableLayerDebugging; }
375 set { TrainingNet.EnableLayerDebugging = value; }
376 }
382 {
383 get { return m_bEnableBreakOnNan; }
384 set { m_bEnableBreakOnNan = value; }
385 }
395 {
396 get { return m_bEnableDetailedNanDetection; }
397 set { m_bEnableDetailedNanDetection = value; }
398 }
404 {
405 get { return m_bEnableSingleStep; }
406 set { m_bEnableSingleStep = value; }
407 }
/// <summary>
/// Gets or sets the weights-updated flag that is forwarded with each training iteration event.
/// </summary>
public bool WeightsUpdated
{
    get
    {
        return m_bWeightsUpdated;
    }
    set
    {
        m_bWeightsUpdated = value;
    }
}
/// <summary>
/// Gets or sets an arbitrary object attached to the solver; the solver only stores it.
/// </summary>
public object Tag
{
    get
    {
        return m_tag;
    }
    set
    {
        m_tag = value;
    }
}
431 {
432 get
433 {
434 if (m_rgTestNets.Count == 0)
435 return null;
437 return m_rgTestNets[0];
438 }
439 }
445 {
446 get { return m_net; }
447 }
/// <summary>
/// Initializes the solver from the given parameters: stores them, validates average_loss,
/// builds the training net and the test nets, and resets the iteration counters.
/// </summary>
/// <param name="p">Solver parameters; stored in m_param.</param>
/// <param name="shareNet">Optional net passed on to InitTrainNet.</param>
454 public void Init(SolverParameter p, Net<T> shareNet = null)
455 {
456 m_log.WriteLine("Initializing solver from parameters: " + p.DebugString());
457 m_param = p;
458 m_log.CHECK_GE(m_param.average_loss, 1, "Average loss should be non-negative and >= 1.0.");
460 if (m_param.random_seed >= 0)
// NOTE(review): listing lines 461-462 are absent, so the statement guarded by the
// random_seed check is not visible (presumably it seeds the random generator -- confirm).
// As printed, the 'if' would guard the InitTrainNet call below, which cannot be intended.
463 // Scaffolding code.
464 InitTrainNet(shareNet);
465 InitTestNets();
467 if (is_root_solver)
468 m_log.WriteLine("Solver scaffolding done.");
// Start counting iterations and steps from zero.
470 Reset();
472 m_log.WriteLine("INFO: Solver created for " + m_param.eval_type.ToString() + " (NOTE: Detection is only for SSD models).", true);
473 }
/// <summary>
/// Resets the current iteration and the current step back to zero.
/// </summary>
public void Reset()
{
    m_nIter = m_nCurrentStep = 0;
}
/// <summary>
/// Creates the training net from exactly one of <c>net_param</c> or <c>train_net_param</c>
/// of the solver parameters and looks up the optional "accuracy" blob.
/// </summary>
/// <param name="shareNet">Optional net handed to the Net constructor.</param>
/// <exception cref="Exception">
/// Any failure is re-thrown with the prefix "Initializing Training Net: "; the original
/// exception is attached as InnerException so its type and stack trace are not lost.
/// </exception>
protected void InitTrainNet(Net<T> shareNet = null)
{
    try
    {
        int num_train_nets = ((m_param.net_param != null) ? 1 : 0) + ((m_param.train_net_param != null) ? 1 : 0);
        string field_names = "net_param, train_net_param";
        m_log.CHECK_GE(num_train_nets, 1, "SolverParameter must specify a train net using one of these fields: " + field_names);
        m_log.CHECK_LE(num_train_nets, 1, "SolverParameter must not contain more than one of these fields specifying a train_net: " + field_names);
        NetParameter net_param = null;

        if (m_param.train_net_param != null)
        {
            m_log.WriteLine("Creating training net specified in train_net_param.");
            net_param = m_param.train_net_param.Clone(true);
        }

        if (m_param.net_param != null)
        {
            m_log.WriteLine("Creating training net specified in net_param.");
            net_param = m_param.net_param.Clone(true);
        }

        // Set the correct NetState.  We start with the solver defaults (lowest
        // precedence); then, merge in any NetState specified by the net_param itself;
        // finally, merge in any NetState specified by the train-state (highest
        // precedence).
        NetState net_state = new NetState();
        net_state.phase = Phase.TRAIN;
        net_state.MergeFrom(net_param.state);
        net_state.MergeFrom(m_param.train_state);
        net_param.state = net_state;
        net_param.solver_count = m_nSolverCount;
        net_param.solver_rank = m_nSolverRank;
        m_net = new Net<T>(m_cuda, m_log, net_param, m_evtCancel, m_db, Phase.NONE, m_evtCompleted, shareNet, net_OnGetWorkspace, net_OnSetWorkspace);
        m_net.OnGetIteration += net_OnGetIteration;

        // Looked up by name; the result is only used when it is non-null.
        m_blobAccuracy = m_net.FindBlob("accuracy");
    }
    catch (Exception excpt)
    {
        // Keep the same exception type and message as before, but preserve the cause.
        throw new Exception("Initializing Training Net: " + excpt.Message, excpt);
    }
}
/// <summary>
/// Handles a net's request for workspace memory. The request is forwarded to an external
/// OnSetWorkspace handler when one is attached; otherwise the solver's own workspace is
/// grown (never shrunk) to the requested size.
/// </summary>
/// <param name="sender">Originator of the request; forwarded unchanged.</param>
/// <param name="e">Carries the requested workspace size in bytes; a size of 0 is ignored.</param>
private void net_OnSetWorkspace(object sender, WorkspaceArgs e)
{
    if (e.WorkspaceSizeInBytes == 0)
        return;

    if (OnSetWorkspace != null)
    {
        OnSetWorkspace(sender, e);
        return;
    }

    m_cuda.DisableGhostMemory();

    try
    {
        if (e.WorkspaceSizeInBytes > m_lWorkspaceSizeInBytes)
        {
            if (m_hWorkspaceData != 0)
            {
                m_cuda.FreeMemory(m_hWorkspaceData);

                // Forget the freed handle right away so that a failing allocation below
                // cannot leave a stale handle/size pair behind.
                m_hWorkspaceData = 0;
                m_lWorkspaceSizeInBytes = 0;
            }

            ulong lCount = CudaDnn<T>.ConvertByteSizeToCount(e.WorkspaceSizeInBytes);
            m_hWorkspaceData = m_cuda.AllocMemory((long)lCount);

            // Record the new size only after the allocation has succeeded.
            m_lWorkspaceSizeInBytes = e.WorkspaceSizeInBytes;
        }
    }
    finally
    {
        // Always undo DisableGhostMemory, even when freeing or allocating throws.
        m_cuda.ResetGhostMemory();
    }
}
/// <summary>
/// Answers a net's query for the current workspace: an attached OnGetWorkspace handler
/// takes precedence, otherwise the solver's own workspace handle and size are returned.
/// </summary>
/// <param name="sender">Originator of the query; forwarded unchanged.</param>
/// <param name="e">Receives the workspace handle and its size in bytes.</param>
private void net_OnGetWorkspace(object sender, WorkspaceArgs e)
{
    if (OnGetWorkspace == null)
    {
        e.WorkspaceData = m_hWorkspaceData;
        e.WorkspaceSizeInBytes = m_lWorkspaceSizeInBytes;
    }
    else
    {
        OnGetWorkspace(sender, e);
    }
}
/// <summary>
/// Supplies the training net with the solver's current iteration for the TRAIN phase.
/// </summary>
/// <param name="sender">Originator of the query (unused).</param>
/// <param name="e">Receives the phase and the iteration.</param>
private void net_OnGetIteration(object sender, GetIterationArgs e)
{
    int nCurrentIteration = m_nIter;
    e.SetIteration(Phase.TRAIN, nCurrentIteration);
}
/// <summary>
/// Creates one test net per <c>test_iter</c> entry: first the nets given by
/// <c>test_net_param</c>, then as many copies of the generic <c>net_param</c> as there are
/// remaining <c>test_iter</c> entries.
/// </summary>
/// <exception cref="Exception">
/// Any failure is re-thrown with the prefix "Initializing Testing Nets: "; the original
/// exception is attached as InnerException so its type and stack trace are not lost.
/// </exception>
protected void InitTestNets()
{
    try
    {
        int num_generic_nets = ((m_param.net_param != null) ? 1 : 0);
        int num_test_net_params = m_param.test_net_param.Count;
        int num_test_nets = num_test_net_params;

        if (num_generic_nets > 0)
            m_log.CHECK_GE(m_param.test_iter.Count, num_test_nets, "test_iter must be specified for each test network.");
        else
            m_log.CHECK_EQ(m_param.test_iter.Count, num_test_nets, "test_iter must be specified for each test network.");

        // If we have a generic net (specified by net or net_param, rather than
        // test_net or test_net_param), we may have an unlimited number of actual
        // test networks -- the actual number is given by the number of remaining
        // test_iters after any test nets specified by test_net_param and/or test_net
        // are evaluated.
        int num_generic_net_instances = m_param.test_iter.Count - num_test_nets;
        int num_test_net_instances = num_test_nets + num_generic_net_instances;

        if (m_param.test_state.Count > 0)
            m_log.CHECK_EQ(m_param.test_state.Count, num_test_net_instances, "test_state must be unspecified or specified once per test net.");

        if (num_test_net_instances > 0)
            m_log.CHECK_GT(m_param.test_interval, 0, "The test interval must be greater than zero.");

        List<string> sources = new List<string>();
        List<NetParameter> net_params = new List<NetParameter>();

        for (int i = 0; i < num_test_net_params; i++)
        {
            sources.Add("test_net_param");
            net_params.Add(m_param.test_net_param[i].Clone());
        }

        int remaining_test_nets = m_param.test_iter.Count - num_test_net_params;

        if (m_param.net_param != null)
        {
            for (int i = 0; i < remaining_test_nets; i++)
            {
                sources.Add("net_param");
                net_params.Add(m_param.net_param.Clone());
            }
        }

        m_rgTestNets = new List<Net<T>>();

        for (int i = 0; i < num_test_net_instances; i++)
        {
            // Set the correct NetState.  We start with the solver defaults (lowest
            // precedence); then, merge in any NetState specified by the net_param
            // itself; finally, merge in any NetState specified by the test_state
            // (highest precedence).
            NetState net_state = new NetState();
            net_state.phase = Phase.TEST;
            net_state.MergeFrom(net_params[i].state);

            if (m_param.test_state.Count > 0)
                net_state.MergeFrom(m_param.test_state[i]);

            net_params[i].state = net_state;

            m_log.WriteLine("Creating test net (#" + i.ToString() + ") specified by " + sources[i], true);
            Net<T> net = new Net<T>(m_cuda, m_log, net_params[i], m_evtCancel, m_db, Phase.NONE, null, TrainingNet, net_OnGetWorkspace, net_OnSetWorkspace);

            m_rgTestNets.Add(net);
            m_rgTestNets[i].set_debug_info(m_param.debug_info);
        }
    }
    catch (Exception excpt)
    {
        // Keep the same exception type and message as before, but preserve the cause.
        throw new Exception("Initializing Testing Nets: " + excpt.Message, excpt);
    }
}
660 {
661 get { return m_cuda; }
662 }
/// <summary>
/// Returns the active label counts string reported by the training net.
/// </summary>
public string ActiveLabelCounts
{
    get
    {
        string strCounts = m_net.ActiveLabelCounts;
        return strCounts;
    }
}
676 {
677 get { return m_net.LabelQueryHitPercents; }
678 }
/// <summary>
/// Returns the label query epochs string reported by the training net.
/// </summary>
public string LabelQueryEpochs
{
    get
    {
        string strEpochs = m_net.LabelQueryEpochs;
        return strEpochs;
    }
}
692 {
693 get { return m_nIter; }
694 }
700 {
701 get { return m_param.max_iter; }
702 }
708 {
709 get
710 {
711 int nIters = m_param.max_iter - m_nIter;
713 if (m_nTrainingIterationOverride > 0)
714 nIters = m_nTrainingIterationOverride;
716 return nIters;
717 }
718 }
724 {
725 get
726 {
727 int nIters = (m_param.test_iter.Count == 0) ? 0 : m_param.test_iter[0];
729 if (m_nTestingIterationOverride > 0)
730 nIters = m_nTestingIterationOverride;
732 return nIters;
733 }
734 }
/// <summary>
/// Runs the full optimization on the root solver: optionally restores weights/state, calls
/// Step for the requested number of iterations, takes a final snapshot when appropriate,
/// and finishes with a display-only forward pass and a test pass.
/// </summary>
/// <param name="nIterationOverride">Number of iterations to run; values &lt;= 0 fall back to TrainingIterations.</param>
/// <param name="rgWeights">Optional weights to restore before solving.</param>
/// <param name="rgState">Optional solver state to restore before solving.</param>
/// <param name="step">Optional single-step mode passed on to Step.</param>
744 public virtual void Solve(int nIterationOverride = -1, byte[] rgWeights = null, byte[] rgState = null, TRAIN_STEP step = TRAIN_STEP.NONE)
745 {
746 m_log.CHECK(is_root_solver, "Solve is only supported by the root solver.");
// NOTE(review): the next statement is truncated in this listing (the operand after '+' is
// missing), so it does not compile as printed -- restore it from the original file.
747 m_log.WriteLine("Solving " +;
748 m_log.WriteLine("Learing Rate Policy: " + m_param.lr_policy);
// Restore is invoked when either array is supplied, so rgWeights may be null here.
750 if (rgWeights != null || rgState != null)
751 Restore(rgWeights, rgState);
753 // For a network that is trained by the solver, no bottom or top vecs
754 // should be given, and we will just provide dummy vecs.
755 int start_iter = m_nIter;
757 if (nIterationOverride <= 0)
758 nIterationOverride = TrainingIterations;
760 if (!Step(nIterationOverride, step))
761 return;
763 // If we haven't already, save a snapshot after optimization, unless
764 // overriden by setting snapshot_after_train = false.
765 if (step == TRAIN_STEP.NONE && (m_param.snapshot_after_train && (m_param.snapshot == 0 || (m_nIter % m_param.snapshot) != 0)))
766 Snapshot(false, true);
767 else if (m_net.learnable_parameters.SnapshotRequested(true))
768 Snapshot(true, false);
770 if (m_evtCancel.WaitOne(0))
771 {
772 m_log.WriteLine("Optimization stopped early.");
773 return;
774 }
776 // After the optimization is done, run an additional train and test pass to
777 // display the train and test loss/outputs if appropriate (based on the
778 // display and test_interval settings, respectively).  Unlike in the rest of
779 // training, for the train net we only run a forward pass as we've already
780 // updated the parameters 'max_iter' times -- this final pass is only done to
781 // display the loss, which is computed in the forward pass.
782 if (m_param.display > 0 && (m_nIter % m_param.display) == 0)
783 {
784 double dfLoss;
785 m_net.Forward(out dfLoss);
787 UpdateSmoothedLoss(dfLoss, start_iter);
788 m_log.WriteLine("Iteration " + m_nIter + ", loss = " + m_dfSmoothedLoss.ToString());
789 }
// NOTE(review): listing line 791 (the condition that owns the block below) is absent;
// as printed, the final test pass runs whenever testing is enabled -- confirm the
// original condition.
792 {
793 if (m_bEnableTest)
794 TestAll();
795 }
797 m_log.WriteLine("Optimization done.");
// The batch input blob is released once solving has finished.
799 if (m_blobBatchInputData != null)
800 {
801 m_blobBatchInputData.Dispose();
802 m_blobBatchInputData = null;
803 }
804 }
/// <summary>
/// Runs up to <paramref name="nIters"/> training iterations: optional test pass, accumulated
/// forward/backward passes, smoothed loss update, weight update, snapshot, and the training
/// iteration event. The loop ends early on cancel, on the completed event, when single
/// stepping, when the training time limit is exceeded, or when updates are not applied.
/// </summary>
/// <param name="nIters">Number of iterations to run, counted from the current iteration.</param>
/// <param name="step">Single-step mode; any value other than NONE ends the loop after one iteration.</param>
/// <param name="bZeroDiffs">When true, the parameter diffs are cleared at the start of each iteration.</param>
/// <param name="bApplyUpdates">When true (and not FORWARD-only), ApplyUpdate is called; when false the loop ends after one iteration.</param>
/// <param name="bDisableOutput">Suppresses the display and single-step log output.</param>
/// <param name="bDisableProgress">Suppresses the progress update on the log.</param>
/// <param name="dfLossOverride">When set, replaces the computed loss of the iteration.</param>
/// <param name="bAllowSnapshot">When true, snapshots are allowed even while single stepping.</param>
/// <returns>Every return statement visible in this listing returns true; failures surface as exceptions.</returns>
/// <remarks>
/// NOTE(review): several interior lines of this method are missing from the listing (marked
/// below), so the method does not compile as printed.
/// </remarks>
818 public bool Step(int nIters, TRAIN_STEP step = TRAIN_STEP.NONE, bool bZeroDiffs = true, bool bApplyUpdates = true, bool bDisableOutput = false, bool bDisableProgress = false, double? dfLossOverride = null, bool? bAllowSnapshot = null)
819 {
// Remembered so the finally block can tell an aborted run from a normal one.
820 Exception err = null;
822 try
823 {
824 BlobCollection<T> colBottom = new BlobCollection<T>();
825 int start_iter = m_nIter;
826 int stop_iter = m_nIter + nIters;
828 m_rgLosses.Clear();
831 // Break on first NaN is a debugging tool
832 // that causes the network to stop training
833 // right after a NaN is discovered either
834 // just after the forward pass or just
835 // after the backward pass.
836 m_net.EnableBreakOnFirstNaN = m_bEnableBreakOnNan && m_bEnableBlobDebugging;
837 m_net.EnableDetailedNanDetection = m_bEnableDetailedNanDetection & m_bEnableBlobDebugging;
// 'sw' paces the periodic display; 'swTimeout' measures the training time limit.
839 Stopwatch sw = new Stopwatch();
840 sw.Start();
842 Stopwatch swTimeout = new Stopwatch();
843 swTimeout.Start();
845 while (m_nIter < stop_iter && !m_evtCompleted.WaitOne(0))
846 {
847 // zero-init the params.
848 if (bZeroDiffs)
849 m_net.ClearParamDiffs();
851 if (OnStart != null)
852 OnStart(this, new EventArgs());
// NOTE(review): the condition below is cut off -- listing line 857 (its closing part) is absent.
854 if (step == TRAIN_STEP.NONE && (forceTest ||
855 (m_param.test_interval > 0 &&
856 (m_nIter % m_param.test_interval) == 0 &&
858 {
859 if (m_bEnableTest && is_root_solver)
860 m_dfLastAccuracy = TestAll();
862 // Break out of the while loop because a stop was requested while testing.
863 if (m_evtCancel.WaitOne(0))
864 break;
865 }
867 // on_start currently not used, so no event added.
868 bool bDisplay1 = (is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0 && !bDisableOutput) ? true : false;
869 m_net.set_debug_info(bDisplay1 && m_param.debug_info);
871 // accumulate the loss and gradient
872 double dfLoss = 0;
873 double dfLossTotal = 0;
874 double? dfAccuracyTotal = null;
875 int nIterCount = 0;
877 Stopwatch swTiming = new Stopwatch();
878 double dfTotalTime = 0;
879 int nTimingCount = 0;
880 bool bFwdPassNanFree = true;
882 for (int i = 0; i < m_param.iter_size; i++)
883 {
884 double dfLocalLoss;
885 double? dfLocalAccuracy = null;
887 swTiming.Restart();
889 if (OnCustomForwardBack != null)
890 {
// NOTE(review): listing line 891, which must declare 'args', is absent.
892 OnCustomForwardBack(this, args);
893 bFwdPassNanFree = args.FwdPassNanFree;
894 dfLocalLoss = args.LocalLoss;
895 }
896 else
897 {
898 bFwdPassNanFree = m_net.ForwardBackward(colBottom, out dfLocalLoss, step);
900 if (m_blobAccuracy != null)
901 dfLocalAccuracy = Utility.ConvertVal<T>(m_blobAccuracy.GetData(0));
902 }
// An invalid loss is logged once; the flag is re-armed by the timed display below.
904 if (double.IsNaN(dfLocalLoss) || double.IsInfinity(dfLocalLoss))
905 {
906 if (m_bFirstNanError)
907 {
908 m_log.WriteError(new Exception("The local loss at iteration " + m_nIter.ToString() + " is invalid (NAN or INFINITY)!"));
909 m_bFirstNanError = false;
910 }
911 }
913 if (dfLocalAccuracy.HasValue)
914 {
915 if (!dfAccuracyTotal.HasValue)
916 dfAccuracyTotal = 0;
918 dfAccuracyTotal = dfAccuracyTotal + dfLocalAccuracy.Value;
919 }
921 dfLossTotal += dfLocalLoss;
922 swTiming.Stop();
924 dfTotalTime += swTiming.Elapsed.TotalMilliseconds;
925 nTimingCount++;
926 nIterCount++;
928 if (!bFwdPassNanFree)
929 break;
930 }
// NOTE(review): nIterCount is 0 when iter_size <= 0, making this a 0/0 division (NaN) -- confirm iter_size >= 1 is guaranteed.
932 dfLoss = dfLossTotal / nIterCount;
933 dfLoss = dfLossOverride.GetValueOrDefault(dfLoss);
935 if (dfAccuracyTotal.HasValue)
936 m_dfIterAccuracy = dfAccuracyTotal.Value / nIterCount;
938 // average the loss across iterations for smoothed reporting
939 UpdateSmoothedLoss(dfLoss, start_iter);
941 bool bDisplay = false;
942 if (!bDisplay1 && sw.ElapsedMilliseconds > 2000 && !bDisableOutput)
943 {
944 bDisplay = true;
945 m_bFirstNanError = true;
946 sw.Restart();
947 }
// NOTE(review): bDisplay is only set to true when bDisplay1 is false, so as printed the
// condition below can never be true and the output block is unreachable -- confirm intent.
949 if (bDisplay && bDisplay1)
950 {
951 m_log.WriteLine("Iteration " + m_nIter.ToString() + ", loss = " + m_dfSmoothedLoss.ToString());
953 BlobCollection<T> colResult = m_net.output_blobs;
954 int score_index = 0;
956 if (is_root_solver)
957 {
958 for (int j = 0; j < colResult.Count; j++)
959 {
960 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
961 int nIdx = m_net.output_blob_indices[j];
962 string output_name = m_net.blob_names[nIdx];
963 double loss_weight = m_net.blob_loss_weights[nIdx];
964 double dfTotalLossWeight = 0;
965 int nResultCount = colResult[j].count();
967 for (int k = 0; k < nResultCount; k++)
968 {
// NOTE(review): listing line 969 (the 'if' that pairs with the 'else' below) is absent.
970 {
971 string strOut = "";
973 if (loss_weight != 0)
974 strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * result_vec[k]).ToString() + " loss)";
976 m_log.WriteLine(" Train net output #" + score_index.ToString() + ": " + output_name + " = " + result_vec[k].ToString() + strOut);
977 score_index++;
978 }
979 else
980 {
981 dfTotalLossWeight += loss_weight * result_vec[k];
982 }
983 }
// NOTE(review): listing line 985 (the condition owning the block below) is absent.
986 {
987 double dfAverage = dfTotalLossWeight / nResultCount;
988 m_log.WriteLine(" Average weighted score = " + dfAverage.ToString() + " for '" + output_name + "' - averaged over " + nResultCount.ToString("N0") + " results.");
989 }
990 }
991 }
992 }
// NOTE(review): listing lines 995-996 (the statement guarded by this 'if') are absent; as
// printed, the 'if' would guard the declaration below, which is not valid C#.
994 if (OnGradientsReady != null && bFwdPassNanFree)
997 double dfLastLearningRate = 0;
999 if (step != TRAIN_STEP.FORWARD && bApplyUpdates)
1000 dfLastLearningRate = ApplyUpdate(m_nIter);
1002 if (m_evtCancel.WaitOne(0))
1003 break;
1005 if (!bDisableProgress)
1006 m_log.Progress = (double)m_nIter / (double)stop_iter;
1008 bool bSnapshotTaken = false;
// Read once: the forceSnapshot getter polls an auto-reset event.
1009 bool bForceSnapshot = forceSnapshot;
// Snapshot when forced, when scheduled by the snapshot interval, or when the last test
// accuracy beats the best accuracy seen so far.
1011 if ((step == TRAIN_STEP.NONE || bAllowSnapshot.GetValueOrDefault(false)) && (is_root_solver && bFwdPassNanFree &&
1012 (bForceSnapshot ||
1013 (m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0) ||
1014 (m_dfLastAccuracy > m_dfBestAccuracy))))
1015 {
1016 bSnapshotTaken = true;
1017 Snapshot(bForceSnapshot, ((m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0)) ? true : false);
1019 if (m_dfLastAccuracy > m_dfBestAccuracy)
1020 m_dfBestAccuracy = m_dfLastAccuracy;
1021 }
1023 //-------------------------------------
1024 // Call the training iteration event
1025 // on the root solver.
1026 //-------------------------------------
// NOTE(review): the boolean result (false when a NaN was detected) is ignored here.
1027 fireOnTrainingIterationEvent(bFwdPassNanFree, dfLoss, dfLastLearningRate, ref nTimingCount, ref dfTotalTime);
1029 //-------------------------------------
1030 // If single stepping, stop the solver.
1031 //-------------------------------------
1032 if (step != TRAIN_STEP.NONE || m_bEnableSingleStep)
1033 {
1034 if (step == TRAIN_STEP.BOTH)
1035 {
1036 if (!bDisableOutput)
1037 m_log.WriteLine("Single step (both) triggered - solving stopped after a single forward/backward pass.");
1038 }
1039 else if (step == TRAIN_STEP.FORWARD)
1040 {
1041 if (!bDisableOutput)
1042 m_log.WriteLine("Single step (forward) triggered - solving stopped after a single forward pass.");
1043 }
1044 else if (step == TRAIN_STEP.BACKWARD)
1045 {
1046 if (!bDisableOutput)
1047 m_log.WriteLine("Single step (backward) triggered - solving stopped after a single backward pass.");
1048 }
1049 else
1050 {
1051 // When single stepping, force the snapshot so as to allow
1052 // debugging the net visually.
1053 if (!bSnapshotTaken)
1054 Snapshot(true, false);
1055 }
1056 break;
1057 }
1059 //-------------------------------------
1060 // If a time-limit has been imposed
1061 // and we have exceeded it, stop
1062 // training.
1063 //-------------------------------------
1064 if (m_nTrainingTimeLimitInMinutes > 0 && swTimeout.Elapsed.TotalMinutes > m_nTrainingTimeLimitInMinutes)
1065 {
1066 m_log.WriteLine("A training time-limit of " + m_nTrainingTimeLimitInMinutes.ToString("N0") + " minutes has been exceeded - training will now stop.");
1067 return true;
1068 }
1070 if (!bApplyUpdates)
1071 break;
1072 }
1074 return true;
1075 }
1076 catch (Exception excpt)
1077 {
1078 err = excpt;
// NOTE(review): 'throw excpt;' resets the stack trace; a bare 'throw;' would preserve it.
1079 throw excpt;
1080 }
1081 finally
1082 {
// OnAborted fires both when an exception escaped and when a cancel was requested.
1083 if (err != null || m_evtCancel.WaitOne(0))
1084 {
1085 if (OnAborted != null)
1086 OnAborted(this, new EventArgs());
1087 }
1088 }
1089 }
/// <summary>
/// Loads the given weights into the training net and, when a state is supplied, restores
/// the solver state from it.
/// </summary>
/// <param name="rgWeights">Weights passed to the training net's LoadWeights.</param>
/// <param name="rgState">Optional solver state; skipped when null.</param>
/// <param name="strSkipBlobTypes">Optional blob types to skip, passed on to LoadWeights.</param>
public void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes = null)
{
    m_net.LoadWeights(rgWeights, m_persist, null, null, strSkipBlobTypes);

    if (rgState == null)
        return;

    m_log.WriteLine("Restoring previous solver state from restore state...");
    RestoreSolverState(rgState);
}
/// <summary>
/// Raises the OnSnapshot event with the current accuracy, error and iteration. Nothing is
/// raised when no handler is attached, or when snapshots are disabled and the request is
/// not forced.
/// </summary>
/// <param name="bForced">True when the snapshot was explicitly requested; overrides the DISABLED update method.</param>
/// <param name="bScheduled">True when the snapshot falls on the regular snapshot interval.</param>
/// <param name="bUpdateDatabase">Forwarded to the handler in the snapshot arguments.</param>
public void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase = true)
{
    m_log.WriteLine("Starting snap shot...");
    m_log.CHECK(is_root_solver, "Snapshot only supported on the root solver.");

    if (OnSnapshot == null)
        return;

    bool bSnapshotsDisabled = (m_snapshotWeightUpdatemMethod == SNAPSHOT_WEIGHT_UPDATE_METHOD.DISABLED);

    if (bSnapshotsDisabled && !bForced)
    {
        m_log.WriteLine("WARNING: Snapshot UPDATE_METHOD = DISABLED.");
        return;
    }

    SnapshotArgs snapshotArgs = GetSnapshotArgs(null, null, m_dfLastAccuracy, m_dfLastError, m_nIter, m_snapshotWeightUpdatemMethod);
    snapshotArgs.Forced = bForced;
    snapshotArgs.Scheduled = bScheduled;
    snapshotArgs.UpdateDatabase = bUpdateDatabase;

    OnSnapshot(this, snapshotArgs);
    m_log.WriteLine("Snapshot completed.");
}
/// <summary>
/// Fills the snapshot's weight bytes on demand by saving the training net's weights;
/// does nothing when the training net no longer exists.
/// </summary>
/// <param name="sender">Originator of the request (unused).</param>
/// <param name="e">Receives the saved weights in its Data member.</param>
private void args_OnGetWeights(object sender, GetBytesArgs e)
{
    if (m_net == null)
        return;

    e.Data = m_net.SaveWeights(m_persist, m_param.snapshot_diff);
}
/// <summary>
/// Handler attached to SnapshotArgs.OnGetState by GetSnapshotArgs.
/// </summary>
/// <param name="sender">Originator of the request.</param>
/// <param name="e">Arguments of the request.</param>
1144 private void args_OnGetState(object sender, GetBytesArgs e)
1145 {
// NOTE(review): listing line 1146 (the body of this handler) is absent; presumably it
// fills e.Data with the solver state, mirroring args_OnGetWeights -- confirm against
// the original file.
1147 }
/// <summary>
/// Builds the SnapshotArgs for a snapshot and attaches the handlers that supply the
/// state and weight bytes.
/// </summary>
/// <param name="rgState">Solver state bytes (may be null).</param>
/// <param name="rgWeights">Weight bytes (may be null).</param>
/// <param name="dfAccuracy">Accuracy to record; an accuracy of exactly 0 is replaced with 0.0001.</param>
/// <param name="dfError">Error to record.</param>
/// <param name="nIteration">Iteration to record.</param>
/// <param name="wtUpdt">Snapshot weight update method to record.</param>
/// <returns>The populated snapshot arguments.</returns>
1159 public SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
1160 {
// An exact zero accuracy is never recorded; a small positive value is used instead.
1161 if (dfAccuracy == 0)
1162 dfAccuracy = 0.0001;
1164 SnapshotArgs args = new SnapshotArgs(rgState, rgWeights, dfAccuracy, dfError, nIteration, wtUpdt);
// NOTE(review): listing lines 1165-1167 are absent; statements that further configure
// 'args' may be missing here -- confirm against the original file.
1168 args.SingleStep = m_bEnableSingleStep;
1169 args.OnGetState += args_OnGetState;
1170 args.OnGetWeights += args_OnGetWeights;
1172 return args;
1173 }
1179 {
1180 get { return m_nTrainingIterationOverride; }
1181 set { m_nTrainingIterationOverride = value; }
1182 }
1188 {
1189 get { return m_nTestingIterationOverride; }
1190 set { m_nTestingIterationOverride = value; }
1191 }
/// <summary>
/// Returns the auto-reset event that is handed to the training net and polled by the
/// Step loop to end training.
/// </summary>
public AutoResetEvent CompletedEvent
{
    get
    {
        return m_evtCompleted;
    }
}
1205 {
1206 get { return m_evtCancel; }
1207 }
/// <summary>
/// Returns the current smoothed loss.
/// </summary>
public double smoothed_loss
{
    get
    {
        return m_dfSmoothedLoss;
    }
}
1221 {
1222 get { return m_param; }
1223 }
1229 {
1230 get { return m_net; }
1231 }
/// <summary>
/// Returns the list of test nets created by InitTestNets.
/// </summary>
public List<Net<T>> test_nets
{
    get
    {
        return m_rgTestNets;
    }
}
/// <summary>
/// Returns the current training iteration.
/// </summary>
public int iter
{
    get
    {
        return m_nIter;
    }
}
1253 {
1254 get { return m_param.type; }
1255 }
/// <summary>
/// Returns whether a forced snapshot has been signaled. Reading the property polls the
/// auto-reset event without waiting; false is returned when no event was supplied.
/// </summary>
protected bool forceSnapshot
{
    get
    {
        return m_evtForceSnapshot != null && m_evtForceSnapshot.WaitOne(0);
    }
}
/// <summary>
/// Returns whether a forced test has been signaled. Reading the property polls the
/// auto-reset event without waiting and remembers the result; false is returned (and
/// nothing is remembered) when no event was supplied.
/// </summary>
public bool forceTest
{
    get
    {
        if (m_evtForceTest != null)
        {
            m_bForceTest = m_evtForceTest.WaitOne(0);
            return m_bForceTest;
        }

        return false;
    }
}
/// <summary>
/// Returns the number of solvers given at construction.
/// </summary>
public int solver_count
{
    get
    {
        return m_nSolverCount;
    }
}
/// <summary>
/// Returns the rank of this solver given at construction.
/// </summary>
public int solver_rank
{
    get
    {
        return m_nSolverRank;
    }
}
/// <summary>
/// Returns true when this solver is the root solver, i.e. its rank is 0.
/// </summary>
public bool is_root_solver
{
    get
    {
        return m_nSolverRank == 0;
    }
}
/// <summary>
/// Runs one test pass per test net (or a single pass with net id 0 when no test nets
/// exist), averages the accuracies, optionally smooths the result over the accuracy
/// window, and raises OnTestingIteration.
/// </summary>
/// <remarks>
/// Previously the accuracy measured when there are no test nets was computed and then
/// discarded (the method always returned 0 in that case) and its test time was never
/// reported; both are now included in the result.
/// </remarks>
/// <param name="nIterationOverride">Number of test iterations, forwarded to the test pass.</param>
/// <returns>The (optionally window-averaged) accuracy, or 0 when testing was canceled.</returns>
public double TestAll(int nIterationOverride = -1)
{
    double dfTotalAccuracy = 0;
    double dfTotalTime = 0;
    int nTotalCount = 0;

    // Without dedicated test nets a single pass using net id 0 is still run.
    int nPassCount = Math.Max(m_rgTestNets.Count, 1);

    for (int test_net_id = 0; test_net_id < nPassCount; test_net_id++)
    {
        if (m_evtCancel.WaitOne(0))
            return 0;

        if (OnTest != null)
        {
            // An attached handler replaces the built-in test pass.
            TestArgs args = new TestArgs(nIterationOverride, test_net_id);
            OnTest(this, args);
            dfTotalAccuracy += args.Accuracy;
        }
        else
        {
            dfTotalAccuracy += testOne(nIterationOverride, test_net_id);
        }

        dfTotalTime += m_dfAverageTestTime;
        nTotalCount++;
    }

    // nTotalCount is at least 1 here because nPassCount is at least 1.
    double dfAccuracy = dfTotalAccuracy / nTotalCount;

    if (m_rgAverageAccuracyWindow != null)
    {
        // Slide the fixed-length window: append the newest value, drop the oldest.
        m_rgAverageAccuracyWindow.Add(dfAccuracy);
        m_rgAverageAccuracyWindow.RemoveAt(0);
        dfAccuracy = m_rgAverageAccuracyWindow.Average();
    }

    if (OnTestingIteration != null)
    {
        double dfTime = dfTotalTime / nTotalCount;
        OnTestingIteration(this, new TestingIterationArgs<T>(m_nIter, dfAccuracy, dfTime));
    }

    return dfAccuracy;
}
/// <summary>
/// Runs a single test pass using the test method selected by the solver's evaluation type.
/// </summary>
/// <param name="nIterationOverride">Number of test iterations, forwarded to the test method.</param>
/// <param name="nTestNetId">Index of the test net to use.</param>
/// <returns>The accuracy returned by the selected test method.</returns>
private double testOne(int nIterationOverride = -1, int nTestNetId = 0)
{
    // SSD detection has its own test; every other evaluation type uses classification.
    if (m_param.eval_type == SolverParameter.EvaluationType.DETECTION)
        return TestDetection(nIterationOverride, nTestNetId);

    return TestClassification(nIterationOverride, nTestNetId);
}
1396 public double TestDetection(int nIterationOverride = -1, int nTestNetId = 0)
1397 {
1398 Stopwatch sw = new Stopwatch();
1399 BBoxUtility<T> bboxUtil = new BBoxUtility<T>(m_cuda, m_log);
1401 try
1402 {
1403 if (is_root_solver)
1404 m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");
1406 Net<T> test_net = m_net;
1408 if (m_rgTestNets.Count > nTestNetId)
1409 {
1410 m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
1411 m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
1412 test_net = m_rgTestNets[nTestNetId];
1413 }
1415 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllTruePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1416 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllFalsePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1417 Dictionary<int, Dictionary<int, int>> rgAllNumPos = new Dictionary<int, Dictionary<int, int>>();
1419 double dfLoss = 0;
1421 if (nIterationOverride <= 0)
1422 nIterationOverride = TestingIterations;
1424 int nIter = nIterationOverride;
1425 sw.Start();
1427 for (int i = 0; i < nIter; i++)
1428 {
1429 // Check to see if stoppage of testing/training has been requested.
1430 if (m_evtCancel.WaitOne(0))
1431 break;
1433 if (OnTestStart != null)
1434 OnTestStart(this, new EventArgs());
1436 double iter_loss;
1437 BlobCollection<T> colResult = test_net.Forward(out iter_loss);
1440 dfLoss += iter_loss;
1442 for (int j = 0; j < colResult.Count; j++)
1443 {
1444 m_log.CHECK_EQ(colResult[j].width, 5, "The width must be = 5 for SSD.");
1445 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1446 int num_det = colResult[j].height;
1448 for (int k = 0; k < num_det; k++)
1449 {
1450 int item_id = (int)result_vec[k * 5];
1451 int nLabel = (int)result_vec[k * 5 + 1];
1453 // Special row for storing number of positives for a label.
1454 if (item_id == -1)
1455 {
1456 if (!rgAllNumPos.ContainsKey(j))
1457 rgAllNumPos.Add(j, new Dictionary<int, int>());
1459 if (!rgAllNumPos[j].ContainsKey(nLabel))
1460 rgAllNumPos[j].Add(nLabel, (int)result_vec[k * 5 + 2]);
1461 else
1462 rgAllNumPos[j][nLabel] += (int)result_vec[k * 5 + 2];
1463 }
1464 // Normal row storing detection status.
1465 else
1466 {
1467 float fScore = (float)result_vec[k * 5 + 2];
1468 int tp = (int)result_vec[k * 5 + 3];
1469 int fp = (int)result_vec[k * 5 + 4];
1471 // Ignore such case, which happens when a detection bbox is matched to
1472 // a difficult gt bbox and we don't evaluate on difficult gt bbox.
1473 if (tp == 0 && fp == 0)
1474 continue;
1476 if (!rgAllTruePos.ContainsKey(j))
1477 rgAllTruePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1479 if (!rgAllTruePos[j].ContainsKey(nLabel))
1480 rgAllTruePos[j].Add(nLabel, new List<Tuple<float, int>>());
1482 if (!rgAllFalsePos.ContainsKey(j))
1483 rgAllFalsePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1485 if (!rgAllFalsePos[j].ContainsKey(nLabel))
1486 rgAllFalsePos[j].Add(nLabel, new List<Tuple<float, int>>());
1488 rgAllTruePos[j][nLabel].Add(new Tuple<float, int>(fScore, tp));
1489 rgAllFalsePos[j][nLabel].Add(new Tuple<float, int>(fScore, fp));
1490 }
1491 }
1492 }
1494 if (sw.Elapsed.TotalMilliseconds > 1000)
1495 {
1496 m_log.Progress = (double)i / (double)nIter;
1497 m_log.WriteLine("Testing at " + m_log.Progress.ToString("P") + " " + i.ToString() + " of " + nIter.ToString() + "...");
1498 sw.Restart();
1499 }
1500 }
1502 if (m_evtCancel.WaitOne(0))
1503 {
1504 m_log.WriteLine("Test interrupted.");
1505 return 0;
1506 }
1509 {
1510 dfLoss /= m_param.test_iter[nTestNetId];
1511 m_log.WriteLine("Test loss: " + dfLoss.ToString());
1512 }
1514 float fTotalmAP = 0;
1515 for (int i = 0; i < rgAllTruePos.Count; i++)
1516 {
1517 if (!rgAllTruePos.ContainsKey(i))
1518 m_log.FAIL("Missing output_blob true_pos: " + i.ToString());
1520 Dictionary<int, List<Tuple<float, int>>> rgTruePos = rgAllTruePos[i];
1522 if (!rgAllFalsePos.ContainsKey(i))
1523 m_log.FAIL("Missing output_blob false_pos: " + i.ToString());
1525 Dictionary<int, List<Tuple<float, int>>> rgFalsePos = rgAllFalsePos[i];
1527 if (!rgAllNumPos.ContainsKey(i))
1528 m_log.FAIL("Missing output_blob num_pos: " + i.ToString());
1530 Dictionary<int, int> rgNumPos = rgAllNumPos[i];
1532 Dictionary<int, float> rgAPs = new Dictionary<int, float>();
1533 float fmAP = 0.0f;
1535 // Sort true_pos and false_pos with descending scores.
1536 foreach (KeyValuePair<int, int> kv in rgNumPos)
1537 {
1538 int nLabel = kv.Key;
1539 int nLabelNumPos = kv.Value;
1541 if (!rgTruePos.ContainsKey(nLabel))
1542 {
1543 m_log.WriteLine("WARNING: Missing true_pos for label: " + nLabel.ToString() + "!");
1544 continue;
1545 }
1546 List<Tuple<float, int>> rgLabelTruePos = rgTruePos[nLabel];
1548 if (!rgFalsePos.ContainsKey(nLabel))
1549 {
1550 m_log.WriteLine("WARNING: Missing false_pos for label: " + nLabel.ToString() + "!");
1551 continue;
1552 }
1553 List<Tuple<float, int>> rgLabelFalsePos = rgFalsePos[nLabel];
1555 List<float> rgPrec;
1556 List<float> rgRec;
1557 float fAp = bboxUtil.ComputeAP(rgLabelTruePos, nLabelNumPos, rgLabelFalsePos, m_param.ap_version, out rgPrec, out rgRec);
1559 if (!rgAPs.ContainsKey(nLabel))
1560 rgAPs.Add(nLabel, fAp);
1561 else
1562 rgAPs[nLabel] = fAp;
1564 fmAP += fAp;
1567 m_log.WriteLine("class " + nLabel.ToString() + ": " + fAp.ToString());
1568 }
1570 fmAP /= rgNumPos.Count;
1572 int nOutputBlobIdx = test_net.output_blob_indices[i];
1573 string strOutputName = test_net.blob_names[nOutputBlobIdx];
1575 m_log.WriteLine(" Test net output #" + i.ToString() + ": " + strOutputName + " = " + fmAP.ToString());
1576 fTotalmAP += fmAP;
1577 }
1579 return fTotalmAP / rgAllTruePos.Count;
1580 }
1581 catch (Exception excpt)
1582 {
1583 throw excpt;
1584 }
1585 finally
1586 {
1587 bboxUtil.Dispose();
1588 }
1589 }
/// <summary>
/// Runs the test net for a number of iterations and returns a classification score.
/// </summary>
/// <remarks>
/// Each iteration runs a forward pass, accumulates the loss and raises OnTestStart / OnTestResults when handlers
/// are attached.  When a handler marks the accuracy as valid, the reported accuracies are collected and the result
/// is the sum of the collected scores divided by the sum of their recorded ids (1 per reported accuracy).  Otherwise
/// every element of every test net output is accumulated across iterations and the result is the mean of the
/// ACCURACY-typed output with the lowest numeric Tag (output #0 when no such output is found).
/// </remarks>
/// <param name="nIterationOverride">Specifies the number of test iterations to run; zero or negative values use TestingIterations.</param>
/// <param name="nTestNetId">Specifies the index of the test net to use; when no test net exists at this index the training net is used.</param>
/// <returns>The final score, or 0 when the test is cancelled or no scores were collected.</returns>
public double TestClassification(int nIterationOverride = -1, int nTestNetId = 0)
{
    bool bDisplay = is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0;

    // A forced test always reports, regardless of the display interval.
    if (m_bForceTest)
    {
        m_bForceTest = false;
        bDisplay = true;
    }

    if (bDisplay)
        m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");

    Net<T> test_net = m_net;

    if (m_rgTestNets.Count > nTestNetId)
    {
        m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
        m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
        test_net = m_rgTestNets[nTestNetId];
    }

    List<double> test_score = new List<double>();
    List<int> test_score_output_id = new List<int>();
    double dfLoss = 0;

    if (nIterationOverride <= 0)
        nIterationOverride = TestingIterations;

    int nIter = nIterationOverride;

    // 'sw' throttles progress reporting; 'swTiming' measures each iteration.
    Stopwatch sw = new Stopwatch();
    sw.Start();

    double dfTotalTiming = 0;
    int nTestCount = 0;
    int nAccuracyIdx = 0;
    int nMinRank = int.MaxValue;
    bool bAccuracyValid = false;
    Stopwatch swTiming = new Stopwatch();

    for (int i = 0; i < nIter; i++)
    {
        // Check to see if stoppage of testing/training has been requested.
        if (m_evtCancel.WaitOne(0))
            break;

        if (OnTestStart != null)
            OnTestStart(this, new EventArgs());

        swTiming.Restart();

        double iter_loss;
        BlobCollection<T> colResult = test_net.Forward(out iter_loss);

        dfLoss += iter_loss;

        // Give an external handler the chance to compute the accuracy itself.
        TestResultArgs<T> args = new TestResultArgs<T>(colResult);
        if (OnTestResults != null)
        {
            OnTestResults(this, args);
            if (args.AccuracyValid)
            {
                test_score.Add(args.Accuracy);
                test_score_output_id.Add(1);
                bAccuracyValid = true;
            }
        }

        if (!args.AccuracyValid)
        {
            if (i == 0)
            {
                // First iteration: create one score slot per output element and remember its output blob.
                for (int j = 0; j < colResult.Count; j++)
                {
                    double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
                    int nCount = colResult[j].count();

                    for (int k = 0; k < nCount; k++)
                    {
                        test_score.Add(result_vec[k]);
                        test_score_output_id.Add(j);
                    }

                    // Among the accuracy outputs, prefer the one with the lowest rank stored in the blob Tag.
                    if (colResult[j].type == BLOB_TYPE.ACCURACY)
                    {
                        int nRank = (int)getNumber(colResult[j].Tag, 0);
                        if (nRank < nMinRank)
                        {
                            nMinRank = nRank;
                            nAccuracyIdx = j;
                        }
                    }
                }
            }
            else
            {
                // Later iterations: accumulate into the slots created on the first iteration.
                int idx = 0;

                for (int j = 0; j < colResult.Count; j++)
                {
                    double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
                    int nCount = colResult[j].count();

                    for (int k = 0; k < nCount; k++)
                    {
                        test_score[idx] += result_vec[k];
                        idx++;
                    }
                }
            }
        }

        swTiming.Stop();
        dfTotalTiming += swTiming.Elapsed.TotalMilliseconds;
        nTestCount++;

        // Report progress at most once every two seconds.
        if (sw.ElapsedMilliseconds > 2000)
        {
            double dfPct = (double)i / (double)nIter;

            if (bDisplay)
            {
                m_log.Progress = dfPct;
                m_log.WriteLine("Testing '" + test_net.name + "' at " + dfPct.ToString("P"));
            }

            sw.Restart();
        }
    }

    m_dfAverageTestTime = (nTestCount > 0) ? dfTotalTiming / nTestCount : 0;

    if (m_evtCancel.WaitOne(0))
    {
        m_log.WriteLine("Test interrupted.");
        return 0;
    }

    // NOTE(review): the loss is averaged over the configured test_iter entry, not over the nIter actually run.
    dfLoss /= m_param.test_iter[nTestNetId];
    m_log.WriteLine("Test loss: " + dfLoss.ToString());

    double dfFinalScore = 0;

    if (bAccuracyValid)
    {
        dfFinalScore = test_score.Sum();
        int nTotal = test_score_output_id.Sum();
        dfFinalScore /= nTotal;
    }
    else
    {
        bool bFinalScoreSet = false;

        for (int i = 0; i < test_score.Count; i++)
        {
            int nIdxTestScore = test_score_output_id[i];
            int output_blob_index = test_net.output_blob_indices[nIdxTestScore];
            string output_name = test_net.blob_names[output_blob_index];
            double loss_weight = test_net.blob_loss_weights[output_blob_index];
            double dfMeanScore = test_score[i] / nIter;

            if (bDisplay)
            {
                string strOut = "";

                if (loss_weight != 0)
                    strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * dfMeanScore).ToString() + " loss)";

                m_log.WriteLine(" Test net output #" + i.ToString() + ": " + output_name + " = " + dfMeanScore.ToString() + strOut);
            }

            // nAccuracyIdx is an output blob index while 'i' indexes the flattened scores, so match on the
            // recorded output id (first element of that output) instead of comparing 'i' directly.
            if (!bFinalScoreSet && nIdxTestScore == nAccuracyIdx)
            {
                dfFinalScore = dfMeanScore;
                bFinalScoreSet = true;
            }
        }
    }

    if (test_score.Count == 0)
        return 0;

    return dfFinalScore;
}
/// <summary>
/// Converts a boxed primitive numeric value into a double.
/// </summary>
/// <param name="value">Specifies the boxed value to convert (may be null).</param>
/// <param name="dfDefault">Specifies the value returned when <i>value</i> is null or is not one of the handled numeric types.</param>
/// <returns>The numeric value as a double, or <i>dfDefault</i> when no conversion applies.</returns>
private double getNumber(object value, double dfDefault)
{
    double dfResult = dfDefault;

    if (value == null)
        return dfResult;

    // The checks are mutually exclusive, so at most one branch assigns the result.
    if (value is sbyte)
        dfResult = (sbyte)value;
    else if (value is byte)
        dfResult = (byte)value;
    else if (value is short)
        dfResult = (short)value;
    else if (value is ushort)
        dfResult = (ushort)value;
    else if (value is int)
        dfResult = (int)value;
    else if (value is uint)
        dfResult = (uint)value;
    else if (value is long)
        dfResult = (long)value;
    else if (value is ulong)
        dfResult = (ulong)value;
    else if (value is float)
        dfResult = (float)value;
    else if (value is double)
        dfResult = (double)value;
    else if (value is decimal)
        dfResult = (double)(decimal)value;

    return dfResult;
}
/// <summary>
/// Folds a new loss value into the running (smoothed) loss and refreshes the last/best error trackers.
/// </summary>
/// <param name="dfLoss">Specifies the new loss value.</param>
/// <param name="nStartIter">Specifies the starting iteration; once the loss window is full, (current iteration - nStartIter) selects the slot that is replaced.</param>
/// <param name="nAverageLoss">Specifies the number of losses in the averaging window; 0 (the default) uses m_param.average_loss.</param>
public void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss = 0)
{
    int nWindow = (nAverageLoss == 0) ? m_param.average_loss : nAverageLoss;

    if (m_rgLosses.Count >= nWindow)
    {
        // The window is full: replace the slot's previous loss and adjust the average by the difference.
        int nSlot = (m_nIter - nStartIter) % nWindow;
        double dfReplaced = m_rgLosses[nSlot];

        m_dfSmoothedLoss += (dfLoss - dfReplaced) / nWindow;
        m_rgLosses[nSlot] = dfLoss;
    }
    else
    {
        // The window is still filling: extend it and update the running mean.
        m_rgLosses.Add(dfLoss);

        int nSamples = m_rgLosses.Count;
        m_dfSmoothedLoss = (m_dfSmoothedLoss * (nSamples - 1) + dfLoss) / nSamples;
    }

    // When the weights-updated flag is set, the smoothed loss restarts from the new loss and the flag is cleared.
    if (m_bWeightsUpdated)
    {
        m_bWeightsUpdated = false;
        m_dfSmoothedLoss = dfLoss;
    }

    m_dfLastError = m_dfSmoothedLoss;

    if (m_dfBestError > m_dfLastError)
        m_dfBestError = m_dfLastError;
}
/// <summary>
/// Abstract hook that each concrete solver must implement; this base class supplies no body.
/// NOTE(review): presumably applies the solver's parameter update for the current iteration - confirm against the derived solvers.
/// </summary>
/// <param name="nIterationOverride">Optional iteration override (default = -1). NOTE(review): the meaning of -1 is defined by the implementations - confirm.</param>
/// <returns>A double whose meaning is defined by the implementing solver. NOTE(review): confirm what the value represents.</returns>
1860 public abstract double ApplyUpdate(int nIterationOverride = -1);
/// <summary>
/// Abstract hook that each concrete solver must implement to produce its state as a byte array.
/// NOTE(review): presumably the returned bytes are what RestoreSolverState later consumes - confirm.
/// </summary>
/// <returns>The solver state as a byte array.</returns>
1865 protected abstract byte[] SnapshotSolverState();
/// <summary>
/// Abstract hook that each concrete solver must implement to load its state from a byte array.
/// NOTE(review): presumably accepts the bytes produced by SnapshotSolverState - confirm.
/// </summary>
/// <param name="rgState">Specifies the solver state as a byte array.</param>
1870 protected abstract void RestoreSolverState(byte[] rgState);
/// <summary>
/// Creates a solver from a project by building the SolverParameter and delegating to the SolverParameter overload of Create.
/// </summary>
/// <remarks>
/// The solver parameter is parsed from the project's SolverDescription when one exists, otherwise a default
/// SolverParameter is used.  When the solver parameter carries no net_param, the net parameter is parsed from the
/// project's ModelDescription and tagged with the project ID.  All remaining arguments are passed through unchanged.
/// </remarks>
/// <param name="cuda">Passed through to the created solver.</param>
/// <param name="log">Passed through to the created solver.</param>
/// <param name="p">Specifies the project supplying the solver description, model description and ID.</param>
/// <param name="evtCancel">Passed through to the created solver.</param>
/// <param name="evtForceSnapshot">Passed through to the created solver.</param>
/// <param name="evtForceTest">Passed through to the created solver.</param>
/// <param name="db">Passed through to the created solver.</param>
/// <param name="persist">Passed through to the created solver.</param>
/// <param name="nSolverCount">Passed through to the created solver (default = 1).</param>
/// <param name="nSolverRank">Passed through to the created solver (default = 0).</param>
/// <param name="shareNet">Passed through to the created solver (default = null).</param>
/// <param name="getws">Passed through to the created solver (default = null).</param>
/// <param name="setws">Passed through to the created solver (default = null).</param>
/// <returns>The new solver is returned.</returns>
public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
{
    SolverParameter solverParam;

    if (p.SolverDescription == null)
    {
        // No solver description in the project: fall back to the default solver settings.
        solverParam = new SolverParameter();
    }
    else
    {
        solverParam = SolverParameter.FromProto(RawProto.Parse(p.SolverDescription));
    }

    // Only pull the model from the project when the solver description did not embed one.
    if (solverParam.net_param == null)
    {
        solverParam.net_param = NetParameter.FromProto(RawProto.Parse(p.ModelDescription));
        solverParam.net_param.ProjectID = p.ID;
    }

    return Create(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
}
/// <summary>
/// Creates the solver that matches the type named in the solver parameter.
/// </summary>
/// <remarks>
/// Every argument is forwarded, unchanged, to the constructor of the selected solver class.
/// </remarks>
/// <param name="cuda">Passed through to the solver constructor.</param>
/// <param name="log">Passed through to the solver constructor.</param>
/// <param name="solverParam">Specifies the solver parameter; its <i>type</i> selects the solver class to create.</param>
/// <param name="evtCancel">Passed through to the solver constructor.</param>
/// <param name="evtForceSnapshot">Passed through to the solver constructor.</param>
/// <param name="evtForceTest">Passed through to the solver constructor.</param>
/// <param name="db">Passed through to the solver constructor.</param>
/// <param name="persist">Passed through to the solver constructor.</param>
/// <param name="nSolverCount">Passed through to the solver constructor (default = 1).</param>
/// <param name="nSolverRank">Passed through to the solver constructor (default = 0).</param>
/// <param name="shareNet">Passed through to the solver constructor (default = null).</param>
/// <param name="getws">Passed through to the solver constructor (default = null).</param>
/// <param name="setws">Passed through to the solver constructor (default = null).</param>
/// <returns>The new solver is returned.</returns>
/// <exception cref="NotImplementedException">Thrown when the solver type is not one of the handled types.</exception>
public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
{
    switch (solverParam.type)
    {
        case SolverParameter.SolverType.SGD:
            return new SGDSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.NESTEROV:
            return new NesterovSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADAGRAD:
            return new AdaGradSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADADELTA:
            return new AdaDeltaSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADAM:
            return new AdamSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADAMW:
            return new AdamWSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.RMSPROP:
            return new RmsPropSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        default:
            throw new NotImplementedException("The solver " + solverParam.type.ToString() + " is not implemented yet!");
    }
}
1970 }
1972#pragma warning disable 1591
/// <summary>
/// Holds two OutputDataCollection instances: one for error outputs and one for accuracy outputs.
/// </summary>
public class OutputCollection
{
    readonly OutputDataCollection m_colErrors;
    readonly OutputDataCollection m_colAccuracies;

    /// <summary>
    /// Creates the collection with empty error and accuracy sets.
    /// </summary>
    public OutputCollection()
    {
        m_colErrors = new OutputDataCollection();
        m_colAccuracies = new OutputDataCollection();
    }

    /// <summary>
    /// Returns the error outputs.
    /// </summary>
    public OutputDataCollection Errors
    {
        get { return m_colErrors; }
    }

    /// <summary>
    /// Returns the accuracy outputs.
    /// </summary>
    public OutputDataCollection Accuracies
    {
        get { return m_colAccuracies; }
    }
}
/// <summary>
/// A list of OutputData items that can be looked up by name.
/// </summary>
public class OutputDataCollection : IEnumerable<OutputData>
{
    readonly List<OutputData> m_rgItems;

    /// <summary>
    /// Creates an empty collection.
    /// </summary>
    public OutputDataCollection()
    {
        m_rgItems = new List<OutputData>();
    }

    /// <summary>
    /// Returns the underlying list of items.
    /// </summary>
    public List<OutputData> Data
    {
        get { return m_rgItems; }
    }

    /// <summary>
    /// Returns the number of items in the collection.
    /// </summary>
    public int Count
    {
        get { return m_rgItems.Count; }
    }

    /// <summary>
    /// Gets or sets the item at a given index.
    /// </summary>
    /// <param name="nIdx">Specifies the index of the item.</param>
    public OutputData this[int nIdx]
    {
        get { return m_rgItems[nIdx]; }
        set { m_rgItems[nIdx] = value; }
    }

    /// <summary>
    /// Adds a value to the item with the given name, creating the item first when it does not exist yet.
    /// </summary>
    /// <param name="nTotal">Specifies the total passed on to OutputData.Add.</param>
    /// <param name="strName">Specifies the name of the item to update.</param>
    /// <param name="nIdx">Specifies the index given to the item when it is created.</param>
    /// <param name="dfVal">Specifies the value passed on to OutputData.Add.</param>
    public void Add(int nTotal, string strName, int nIdx, double dfVal)
    {
        OutputData item = Find(strName);

        if (item == null)
        {
            item = new OutputData(strName, nIdx);
            m_rgItems.Add(item);
        }

        item.Add(nTotal, dfVal);
    }

    /// <summary>
    /// Finds the first item with the given name.
    /// </summary>
    /// <param name="strName">Specifies the name to look for.</param>
    /// <returns>The matching item, or null when no item has that name.</returns>
    public OutputData Find(string strName)
    {
        return m_rgItems.Find(item => item.Name == strName);
    }

    /// <summary>
    /// Returns an enumerator over the items.
    /// </summary>
    public IEnumerator<OutputData> GetEnumerator()
    {
        return m_rgItems.GetEnumerator();
    }

    /// <summary>
    /// Returns a non-generic enumerator over the items.
    /// </summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
/// <summary>
/// A named output value that is updated as a weighted blend of its previous value and each new value.
/// </summary>
public class OutputData
{
    /// <summary>
    /// Creates the output data with a value of 0.
    /// </summary>
    /// <param name="strName">Specifies the name of the output.</param>
    /// <param name="nIdx">Specifies the index of the output.</param>
    public OutputData(string strName, int nIdx)
    {
        Name = strName;
        Index = nIdx;
        Value = 0;
    }

    /// <summary>
    /// Returns the index given at construction.
    /// </summary>
    public int Index { get; private set; }

    /// <summary>
    /// Returns the name given at construction.
    /// </summary>
    public string Name { get; private set; }

    /// <summary>
    /// Gets or sets the current value.
    /// </summary>
    public double Value { get; set; }

    /// <summary>
    /// Blends a new value into the current value, weighting the new value by 1/nTotal and the current value by (1 - 1/nTotal).
    /// </summary>
    /// <param name="nTotal">Specifies the total that determines the weight of the new value.</param>
    /// <param name="dfVal">Specifies the new value.</param>
    public void Add(int nTotal, double dfVal)
    {
        double dfNewWeight = 1.0 / (double)nTotal;
        double dfOldWeight = 1.0 - dfNewWeight;

        Value = (Value * dfOldWeight) + (dfNewWeight * dfVal);
    }
}
2088#pragma warning restore 1591
