MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
Solver.cs
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading;
6using System.IO;
7using System.Diagnostics;
8using System.Collections;
9using MyCaffe.basecode;
10using MyCaffe.db.image;
11using MyCaffe.common;
12using MyCaffe.param;
13
17namespace MyCaffe.solvers
18{
27 public abstract class Solver<T> : IDisposable
28 {
32 protected CudaDnn<T> m_cuda;
36 protected Log m_log;
44 protected Net<T> m_net;
48 protected List<Net<T>> m_rgTestNets = new List<Net<T>>();
52 protected int m_nIter;
56 protected int m_nCurrentStep;
60 protected List<double> m_rgLosses = new List<double>();
61 AutoResetEvent m_evtCompleted = new AutoResetEvent(false);
62 bool m_bEnableTest = true;
63 bool m_bEnableBlobDebugging = false;
64 bool m_bEnableBreakOnNan = false;
65 bool m_bEnableDetailedNanDetection = false;
66 bool m_bEnableSingleStep = false;
70 protected double m_dfSmoothedLoss = 0;
74 protected double? m_dfIterAccuracy = null;
75 Blob<T> m_blobAccuracy = null;
76 CancelEvent m_evtCancel;
77 AutoResetEvent m_evtForceSnapshot;
78 AutoResetEvent m_evtForceTest;
82 protected int m_nSolverCount = 1;
86 protected int m_nSolverRank = 0;
94 protected double m_dfLearningRateOverride = 0;
95 double m_dfLastAccuracy = 0;
96 double m_dfLastError = double.MaxValue;
97 double m_dfBestAccuracy = 0;
98 double m_dfBestError = double.MaxValue;
99 IXDatabaseBase m_db = null;
100 int m_nTrainingIterationOverride = -1;
101 int m_nTestingIterationOverride = -1;
102 object m_tag = null;
103 bool m_bWeightsUpdated = false;
104 static object m_syncGetRi = new object();
105 Blob<T> m_blobBatchInputData = null;
106 double m_dfAverageTestTime = 0;
107 SNAPSHOT_WEIGHT_UPDATE_METHOD m_snapshotWeightUpdatemMethod = SNAPSHOT_WEIGHT_UPDATE_METHOD.FAVOR_ACCURACY;
108 int m_nTrainingTimeLimitInMinutes = 0;
109 long m_hWorkspaceData = 0; // shared among the layers and nets, only grows in size.
110 ulong m_lWorkspaceSizeInBytes = 0;
111 bool m_bFirstNanError = true;
112 List<double> m_rgAverageAccuracyWindow = null;
113 bool m_bForceTest = false;
114
118 public event EventHandler OnStart;
122 public event EventHandler OnAborted;
126 public event EventHandler<GradientsReadyArgs> OnGradientsReady;
130 public event EventHandler<SnapshotArgs> OnSnapshot;
134 public event EventHandler<TrainingIterationArgs<T>> OnTrainingIteration;
138 public event EventHandler<TestingIterationArgs<T>> OnTestingIteration;
142 public event EventHandler<TestResultArgs<T>> OnTestResults;
146 public event EventHandler<TestArgs> OnTest;
150 public event EventHandler OnTestStart;
155 public event EventHandler<CustomForwardBackArgs<T>> OnCustomForwardBack;
159 public event EventHandler<WorkspaceArgs> OnGetWorkspace;
163 public event EventHandler<WorkspaceArgs> OnSetWorkspace;
164
        /// <summary>
        /// The Solver constructor.
        /// </summary>
        /// <param name="cuda">Specifies the CudaDnn connection to Cuda.</param>
        /// <param name="log">Specifies the Log for output.</param>
        /// <param name="p">Specifies the SolverParameter settings.</param>
        /// <param name="evtCancel">Specifies the CancelEvent used to cancel training.</param>
        /// <param name="evtForceSnapshot">Specifies the event used to force a snapshot.</param>
        /// <param name="evtForceTest">Specifies the event used to force a test cycle.</param>
        /// <param name="db">Specifies the in-memory database (may be null).</param>
        /// <param name="persist">Specifies the persistence object used to load/save weights and state.</param>
        /// <param name="nSolverCount">Specifies the number of solvers participating in a multi-GPU session (default = 1).</param>
        /// <param name="nSolverRank">Specifies this solver's rank within the session (default = 0, the root solver).</param>
        /// <param name="shareNet">Optionally, specifies a net whose weights are shared with the net created here.</param>
        /// <param name="getws">Optionally, specifies an external handler supplying the shared workspace.</param>
        /// <param name="setws">Optionally, specifies an external handler receiving workspace size requests.</param>
        public Solver(CudaDnn<T> cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
        {
            m_cuda = cuda;
            m_log = log;
            m_evtCancel = evtCancel;
            m_evtForceSnapshot = evtForceSnapshot;
            m_evtForceTest = evtForceTest;

            // NOTE(review): the statement guarded by this 'if' appears to have been
            // elided from this copy of the source; as written the 'm_db' assignment
            // below becomes its body, which looks unintentional - confirm against
            // the original source.
            if (m_log.IsEnabled)

            m_db = db;
            m_persist = persist;
            m_nSolverCount = nSolverCount;
            m_nSolverRank = nSolverRank;

            // Wire up any externally supplied workspace handlers.
            if (getws != null)
                OnGetWorkspace += new EventHandler<WorkspaceArgs>(getws);

            if (setws != null)
                OnSetWorkspace += new EventHandler<WorkspaceArgs>(setws);

            // Pre-fill the accuracy averaging window with zeros so the moving
            // average has a full window from the first test cycle onward.
            if (p.accuracy_average_window > 0)
            {
                m_rgAverageAccuracyWindow = new List<double>();
                for (int i = 0; i < p.accuracy_average_window; i++)
                {
                    m_rgAverageAccuracyWindow.Add(0);
                }
            }

            Init(p, shareNet);
        }
214
218 public void Dispose()
219 {
220 dispose();
221 }
222
227 {
228 get { return m_dfLearningRateOverride; }
229 set { m_dfLearningRateOverride = value; }
230 }
231
237 {
238 int nTimingCount = 0;
239 double dfTotalTime = 0;
240 return fireOnTrainingIterationEvent(false, 0, 0, ref nTimingCount, ref dfTotalTime);
241 }
242
        /// <summary>
        /// Fire the OnTrainingIteration event on the root solver, optionally collecting
        /// blob debug information and scanning it for NaN values.
        /// </summary>
        /// <param name="bFwdPassNanFree">Specifies whether the forward pass completed without NaNs (used only to label which pass produced a NaN).</param>
        /// <param name="dfLoss">Specifies the raw loss of the iteration.</param>
        /// <param name="dfLastLearningRate">Specifies the last learning rate applied.</param>
        /// <param name="nTimingCount">Specifies the number of timing samples accumulated; reset to 0 after the event fires.</param>
        /// <param name="dfTotalTime">Specifies the accumulated timing total in milliseconds; reset to 0 after the event fires.</param>
        /// <returns>Returns false when a NaN was detected (training should stop), otherwise true.</returns>
        private bool fireOnTrainingIterationEvent(bool bFwdPassNanFree, double dfLoss, double dfLastLearningRate, ref int nTimingCount, ref double dfTotalTime)
        {
            // Only the root solver reports, and only when someone is listening.
            if (is_root_solver && OnTrainingIteration != null)
            {
                string strFirstNanBlob = null;
                DebugInformation<T> dbgInfo = null;

                if (m_bEnableBlobDebugging)
                {
                    dbgInfo = TrainingNet.GetDebugInformation(m_bEnableDetailedNanDetection);

                    // When break-on-NaN is enabled, scan the debug info for the
                    // first (and last) blob containing a NaN.
                    if (m_bEnableBreakOnNan && dbgInfo != null)
                    {
                        string strType;
                        strFirstNanBlob = dbgInfo.DetectFirstNaN(out strType);

                        if (strFirstNanBlob != null)
                        {
                            string strPass = (!bFwdPassNanFree) ? "Forward" : "Backward";
                            m_log.WriteLine("First NaN detected in the '" + strType + "' of blob '" + strFirstNanBlob + "' after " + strPass + " pass.");

                            string strTypeLast;
                            string strLastNanBlob = dbgInfo.DetectLastNaN(out strTypeLast);

                            // Only report the last NaN when it differs from the first.
                            if (strLastNanBlob != strFirstNanBlob && strType != strTypeLast)
                                m_log.WriteLine("Last NaN detected in the '" + strTypeLast + "' of blob '" + strLastNanBlob + "' after " + strPass + " pass.");
                        }
                    }
                }

                // Average iteration time over the accumulated timing samples.
                double dfTime = (nTimingCount > 0) ? (dfTotalTime / nTimingCount) : 0;
                OnTrainingIteration(this, new TrainingIterationArgs<T>(m_nIter, m_dfLastAccuracy, dfLoss, m_dfSmoothedLoss, m_dfBestError, m_bWeightsUpdated, m_net.ActiveLabelCounts, m_net.LabelQueryHitPercents, m_net.LabelQueryEpochs, m_net.BoostQueryHitPercents, dfLastLearningRate, dfTime, dbgInfo));
                // Reset the caller's timing accumulators once reported.
                dfTotalTime = 0;
                nTimingCount = 0;

                // A detected NaN stops training.
                if (strFirstNanBlob != null)
                {
                    m_log.WriteLine("Training is now stopping at iteration " + m_nIter.ToString("N0") + " as the first NaN has been detected ('" + strFirstNanBlob + "').");
                    return false;
                }
            }

            return true;
        }
287
292 {
293 get { return m_nTrainingTimeLimitInMinutes; }
294 set { m_nTrainingTimeLimitInMinutes = value; }
295 }
296
301 {
302 get { return m_snapshotWeightUpdatemMethod; }
303 set { m_snapshotWeightUpdatemMethod = value; }
304 }
305
310 {
311 get { return m_db; }
312 }
313
317 protected virtual void dispose()
318 {
319 if (m_net != null)
320 {
321 m_net.Dispose();
322 m_net = null;
323 }
324
325 foreach (Net<T> net in m_rgTestNets)
326 {
327 net.Dispose();
328 }
329
330 m_rgTestNets.Clear();
331
332 if (m_blobBatchInputData != null)
333 {
334 m_blobBatchInputData.Dispose();
335 m_blobBatchInputData = null;
336 }
337
338 if (m_hWorkspaceData != 0)
339 {
340 m_cuda.DisableGhostMemory();
341 m_cuda.FreeMemory(m_hWorkspaceData);
342 m_cuda.ResetGhostMemory();
343 m_hWorkspaceData = 0;
344 m_lWorkspaceSizeInBytes = 0;
345 }
346 }
347
        /// <summary>
        /// Get or set whether testing cycles run during training (default = true).
        /// </summary>
        public bool EnableTesting
        {
            get { return m_bEnableTest; }
            set { m_bEnableTest = value; }
        }
356
361 {
362 get { return m_bEnableBlobDebugging; }
363 set { m_bEnableBlobDebugging = value; }
364 }
365
373 {
374 get { return TrainingNet.EnableLayerDebugging; }
375 set { TrainingNet.EnableLayerDebugging = value; }
376 }
377
382 {
383 get { return m_bEnableBreakOnNan; }
384 set { m_bEnableBreakOnNan = value; }
385 }
386
395 {
396 get { return m_bEnableDetailedNanDetection; }
397 set { m_bEnableDetailedNanDetection = value; }
398 }
399
404 {
405 get { return m_bEnableSingleStep; }
406 set { m_bEnableSingleStep = value; }
407 }
408
        /// <summary>
        /// Get or set the flag noting that the weights have been updated.
        /// </summary>
        public bool WeightsUpdated
        {
            get { return m_bWeightsUpdated; }
            set { m_bWeightsUpdated = value; }
        }
417
        /// <summary>
        /// Get or set an arbitrary user-defined object associated with the solver.
        /// </summary>
        public object Tag
        {
            get { return m_tag; }
            set { m_tag = value; }
        }
426
431 {
432 get
433 {
434 if (m_rgTestNets.Count == 0)
435 return null;
436
437 return m_rgTestNets[0];
438 }
439 }
440
445 {
446 get { return m_net; }
447 }
448
        /// <summary>
        /// Initialize the solver from its parameters, scaffolding the training and
        /// testing nets and resetting the iteration counters.
        /// </summary>
        /// <param name="p">Specifies the SolverParameter used to configure the solver.</param>
        /// <param name="shareNet">Optionally, specifies a net whose weights are shared with the training net.</param>
        public void Init(SolverParameter p, Net<T> shareNet = null)
        {
            m_log.WriteLine("Initializing solver from parameters: " + p.DebugString());
            m_param = p;
            m_log.CHECK_GE(m_param.average_loss, 1, "Average loss should be non-negative and >= 1.0.");

            // NOTE(review): the statement guarded by this 'if' (likely the random
            // seed setup) appears to have been elided from this copy of the source;
            // as written, InitTrainNet becomes its body - confirm against the
            // original source.
            if (m_param.random_seed >= 0)

            // Scaffolding code.
            InitTrainNet(shareNet);
            InitTestNets();

            if (is_root_solver)
                m_log.WriteLine("Solver scaffolding done.");

            Reset();

            m_log.WriteLine("INFO: Solver created for " + m_param.eval_type.ToString() + " (NOTE: Detection is only for SSD models).", true);
        }
474
478 public void Reset()
479 {
480 m_nIter = 0;
481 m_nCurrentStep = 0;
482 }
483
488 protected void InitTrainNet(Net<T> shareNet = null)
489 {
490 try
491 {
492 int num_train_nets = ((m_param.net_param != null) ? 1 : 0) + ((m_param.train_net_param != null) ? 1 : 0);
493 string field_names = "net_param, train_net_param";
494 m_log.CHECK_GE(num_train_nets, 1, "SolverParameter must specify a train net using one of these fields: " + field_names);
495 m_log.CHECK_LE(num_train_nets, 1, "SolverParameter must not contain more than one of these fields specifying a train_net: " + field_names);
496 NetParameter net_param = null;
497
498 if (m_param.train_net_param != null)
499 {
500 m_log.WriteLine("Creating training net specified in train_net_param.");
501 net_param = m_param.train_net_param.Clone(true);
502 }
503
504 if (m_param.net_param != null)
505 {
506 m_log.WriteLine("Creating training net specified in net_param.");
507 net_param = m_param.net_param.Clone(true);
508 }
509
510 // Set the correct NetState. We start with the solver defaults (lowest
511 // precedence); then, merge in any NetState specified by the net_param itself;
512 // finally, merge in any NetState specified by the train-state (highest
513 // precedence).
514 NetState net_state = new NetState();
515 net_state.phase = Phase.TRAIN;
516 net_state.MergeFrom(net_param.state);
517 net_state.MergeFrom(m_param.train_state);
518 net_param.state = net_state;
519 net_param.solver_count = m_nSolverCount;
520 net_param.solver_rank = m_nSolverRank;
521 m_net = new Net<T>(m_cuda, m_log, net_param, m_evtCancel, m_db, Phase.NONE, m_evtCompleted, shareNet, net_OnGetWorkspace, net_OnSetWorkspace);
522 m_net.OnGetIteration += net_OnGetIteration;
523
524 m_blobAccuracy = m_net.FindBlob("accuracy");
525 }
526 catch(Exception excpt)
527 {
528 throw new Exception("Initializing Training Net: " + excpt.Message);
529 }
530 }
531
        /// <summary>
        /// Event handler called by a net requesting workspace memory of a given size.
        /// The shared workspace only ever grows; smaller requests are satisfied by
        /// the existing allocation.
        /// </summary>
        /// <param name="sender">Specifies the sender (the requesting net).</param>
        /// <param name="e">Specifies the workspace arguments carrying the requested size.</param>
        private void net_OnSetWorkspace(object sender, WorkspaceArgs e)
        {
            // Nothing to do for an empty request.
            if (e.WorkspaceSizeInBytes == 0)
                return;

            // When an external workspace provider is attached, delegate to it.
            if (OnSetWorkspace != null)
            {
                OnSetWorkspace(sender, e);
                return;
            }

            // Ghost memory must be disabled while (re)allocating the shared workspace.
            m_cuda.DisableGhostMemory();

            if (e.WorkspaceSizeInBytes > m_lWorkspaceSizeInBytes)
            {
                m_lWorkspaceSizeInBytes = e.WorkspaceSizeInBytes;

                // Free the previous (smaller) workspace before allocating the new one.
                if (m_hWorkspaceData != 0)
                    m_cuda.FreeMemory(m_hWorkspaceData);

                ulong lCount = CudaDnn<T>.ConvertByteSizeToCount(m_lWorkspaceSizeInBytes);
                m_hWorkspaceData = m_cuda.AllocMemory((long)lCount);
            }

            m_cuda.ResetGhostMemory();
        }
558
559 private void net_OnGetWorkspace(object sender, WorkspaceArgs e)
560 {
561 if (OnGetWorkspace != null)
562 {
563 OnGetWorkspace(sender, e);
564 return;
565 }
566
567 e.WorkspaceData = m_hWorkspaceData;
568 e.WorkspaceSizeInBytes = m_lWorkspaceSizeInBytes;
569 }
570
        /// <summary>
        /// Event handler called by a net requesting the current training iteration.
        /// </summary>
        /// <param name="sender">Specifies the sender (the requesting net).</param>
        /// <param name="e">Specifies the arguments filled in with the TRAIN phase and current iteration.</param>
        private void net_OnGetIteration(object sender, GetIterationArgs e)
        {
            e.SetIteration(Phase.TRAIN, m_nIter);
        }
575
579 protected void InitTestNets()
580 {
581 try
582 {
583 int num_generic_nets = ((m_param.net_param != null) ? 1 : 0);
584 int num_test_net_params = m_param.test_net_param.Count;
585 int num_test_nets = num_test_net_params;
586
587 if (num_generic_nets > 0)
588 m_log.CHECK_GE(m_param.test_iter.Count, num_test_nets, "test_iter must be specified fore each test network.");
589 else
590 m_log.CHECK_EQ(m_param.test_iter.Count, num_test_nets, "test_iter must be specified fore each test network.");
591
592 // If we have a generic net (specified by net or net_param, rather than
593 // test_net or test_net_param), we may have an unlimited number of actual
594 // test networks -- the actual number is given by the number of remaining
595 // test_iters after any test nets specified by test_net_param and/or test_net
596 // are evaluated.
597 int num_generic_net_instances = m_param.test_iter.Count - num_test_nets;
598 int num_test_net_instances = num_test_nets + num_generic_net_instances;
599
600 if (m_param.test_state.Count > 0)
601 m_log.CHECK_EQ(m_param.test_state.Count, num_test_net_instances, "test_state must be unspecified or specified once per test net.");
602
603 if (num_test_net_instances > 0)
604 m_log.CHECK_GT(m_param.test_interval, 0, "The test interval must be greater than zero.");
605
606 List<string> sources = new List<string>();
607 List<NetParameter> net_params = new List<NetParameter>();
608
609 for (int i = 0; i < num_test_net_params; i++)
610 {
611 sources.Add("test_net_param");
612 net_params.Add(m_param.test_net_param[i].Clone());
613 }
614
615 int remaining_test_nets = m_param.test_iter.Count - num_test_net_params;
616
617 if (m_param.net_param != null)
618 {
619 for (int i = 0; i < remaining_test_nets; i++)
620 {
621 sources.Add("net_param");
622 net_params.Add(m_param.net_param.Clone());
623 }
624 }
625
626 m_rgTestNets = new List<Net<T>>();
627
628 for (int i = 0; i < num_test_net_instances; i++)
629 {
630 // Set the correct NetState. We start with the solver defaults (lowest
631 // precedence); then, merge in any NetState specified by the net_param
632 // itself; finally, merge in any NetState specified by the test_state
633 // (highest precedence).
634 NetState net_state = new NetState();
635 net_state.phase = Phase.TEST;
636 net_state.MergeFrom(net_params[i].state);
637
638 if (m_param.test_state.Count > 0)
639 net_state.MergeFrom(m_param.test_state[i]);
640
641 net_params[i].state = net_state;
642
643 m_log.WriteLine("Creating test net (#" + i.ToString() + ") specified by " + sources[i], true);
644 Net<T> net = new Net<T>(m_cuda, m_log, net_params[i], m_evtCancel, m_db, Phase.NONE, null, TrainingNet, net_OnGetWorkspace, net_OnSetWorkspace);
645
646 m_rgTestNets.Add(net);
647 m_rgTestNets[i].set_debug_info(m_param.debug_info);
648 }
649 }
650 catch (Exception excpt)
651 {
652 throw new Exception("Initializing Testing Nets: " + excpt.Message);
653 }
654 }
655
660 {
661 get { return m_cuda; }
662 }
663
        /// <summary>
        /// Returns the active label counts reported by the training net.
        /// </summary>
        public string ActiveLabelCounts
        {
            get { return m_net.ActiveLabelCounts; }
        }
671
676 {
677 get { return m_net.LabelQueryHitPercents; }
678 }
679
        /// <summary>
        /// Returns the label query epochs reported by the training net.
        /// </summary>
        public string LabelQueryEpochs
        {
            get { return m_net.LabelQueryEpochs; }
        }
687
692 {
693 get { return m_nIter; }
694 }
695
700 {
701 get { return m_param.max_iter; }
702 }
703
708 {
709 get
710 {
711 int nIters = m_param.max_iter - m_nIter;
712
713 if (m_nTrainingIterationOverride > 0)
714 nIters = m_nTrainingIterationOverride;
715
716 return nIters;
717 }
718 }
719
724 {
725 get
726 {
727 int nIters = (m_param.test_iter.Count == 0) ? 0 : m_param.test_iter[0];
728
729 if (m_nTestingIterationOverride > 0)
730 nIters = m_nTestingIterationOverride;
731
732 return nIters;
733 }
734 }
735
        /// <summary>
        /// Run the main solving loop: optionally restore weights/state, step through
        /// training, take a final snapshot when appropriate, and run a final
        /// display/test pass.
        /// </summary>
        /// <param name="nIterationOverride">Optionally, overrides the number of training iterations (default = -1, use TrainingIterations).</param>
        /// <param name="rgWeights">Optionally, specifies weight bytes to restore before solving.</param>
        /// <param name="rgState">Optionally, specifies solver state bytes to restore before solving.</param>
        /// <param name="step">Optionally, specifies a single-step mode (default = TRAIN_STEP.NONE).</param>
        public virtual void Solve(int nIterationOverride = -1, byte[] rgWeights = null, byte[] rgState = null, TRAIN_STEP step = TRAIN_STEP.NONE)
        {
            m_log.CHECK(is_root_solver, "Solve is only supported by the root solver.");
            m_log.WriteLine("Solving " + m_net.name);
            m_log.WriteLine("Learing Rate Policy: " + m_param.lr_policy);

            if (rgWeights != null || rgState != null)
                Restore(rgWeights, rgState);

            // For a network that is trained by the solver, no bottom or top vecs
            // should be given, and we will just provide dummy vecs.
            int start_iter = m_nIter;

            if (nIterationOverride <= 0)
                nIterationOverride = TrainingIterations;

            // Step returns false when training was stopped (e.g. NaN detected).
            if (!Step(nIterationOverride, step))
                return;

            // If we haven't already, save a snapshot after optimization, unless
            // overriden by setting snapshot_after_train = false.
            if (step == TRAIN_STEP.NONE && (m_param.snapshot_after_train && (m_param.snapshot == 0 || (m_nIter % m_param.snapshot) != 0)))
                Snapshot(false, true);
            else if (m_net.learnable_parameters.SnapshotRequested(true))
                Snapshot(true, false);

            if (m_evtCancel.WaitOne(0))
            {
                m_log.WriteLine("Optimization stopped early.");
                return;
            }

            // After the optimization is done, run an additional train and test pass to
            // display the train and test loss/outputs if appropriate (based on the
            // display and test_interval settings, respectively). Unlike in the rest of
            // training, for the train net we only run a forward pass as we've already
            // updated the parameters 'max_iter' times -- this final pass is only done to
            // display the loss, which is computed in the forward pass.
            if (m_param.display > 0 && (m_nIter % m_param.display) == 0)
            {
                double dfLoss;
                m_net.Forward(out dfLoss);

                UpdateSmoothedLoss(dfLoss, start_iter);
                m_log.WriteLine("Iteration " + m_nIter + ", loss = " + m_dfSmoothedLoss.ToString());
            }

            // NOTE(review): this block appears to have lost its guarding 'if'
            // condition (likely a test_interval check) in this copy of the source -
            // confirm against the original before relying on this final-test behavior.
            {
                if (m_bEnableTest)
                    TestAll();
            }

            m_log.WriteLine("Optimization done.");

            if (m_blobBatchInputData != null)
            {
                m_blobBatchInputData.Dispose();
                m_blobBatchInputData = null;
            }
        }
805
818 public bool Step(int nIters, TRAIN_STEP step = TRAIN_STEP.NONE, bool bZeroDiffs = true, bool bApplyUpdates = true, bool bDisableOutput = false, bool bDisableProgress = false, double? dfLossOverride = null, bool? bAllowSnapshot = null)
819 {
820 Exception err = null;
821
822 try
823 {
824 BlobCollection<T> colBottom = new BlobCollection<T>();
825 int start_iter = m_nIter;
826 int stop_iter = m_nIter + nIters;
827
828 m_rgLosses.Clear();
830
831 // Break on first NaN is a debugging tool
832 // that causes the network to stop training
833 // right after a NaN is discovered either
834 // just after the forward pass or just
835 // after the backward pass.
836 m_net.EnableBreakOnFirstNaN = m_bEnableBreakOnNan && m_bEnableBlobDebugging;
837 m_net.EnableDetailedNanDetection = m_bEnableDetailedNanDetection & m_bEnableBlobDebugging;
838
839 Stopwatch sw = new Stopwatch();
840 sw.Start();
841
842 Stopwatch swTimeout = new Stopwatch();
843 swTimeout.Start();
844
845 while (m_nIter < stop_iter && !m_evtCompleted.WaitOne(0))
846 {
847 // zero-init the params.
848 if (bZeroDiffs)
849 m_net.ClearParamDiffs();
850
851 if (OnStart != null)
852 OnStart(this, new EventArgs());
853
854 if (step == TRAIN_STEP.NONE && (forceTest ||
855 (m_param.test_interval > 0 &&
856 (m_nIter % m_param.test_interval) == 0 &&
858 {
859 if (m_bEnableTest && is_root_solver)
860 m_dfLastAccuracy = TestAll();
861
862 // Break out of the while loop because a stop was requested while testing.
863 if (m_evtCancel.WaitOne(0))
864 break;
865 }
866
867 // on_start currently not used, so no event added.
868 bool bDisplay1 = (is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0 && !bDisableOutput) ? true : false;
869 m_net.set_debug_info(bDisplay1 && m_param.debug_info);
870
871 // accumulate the loss and gradient
872 double dfLoss = 0;
873 double dfLossTotal = 0;
874 double? dfAccuracyTotal = null;
875 int nIterCount = 0;
876
877 Stopwatch swTiming = new Stopwatch();
878 double dfTotalTime = 0;
879 int nTimingCount = 0;
880 bool bFwdPassNanFree = true;
881
882 for (int i = 0; i < m_param.iter_size; i++)
883 {
884 double dfLocalLoss;
885 double? dfLocalAccuracy = null;
886
887 swTiming.Restart();
888
889 if (OnCustomForwardBack != null)
890 {
892 OnCustomForwardBack(this, args);
893 bFwdPassNanFree = args.FwdPassNanFree;
894 dfLocalLoss = args.LocalLoss;
895 }
896 else
897 {
898 bFwdPassNanFree = m_net.ForwardBackward(colBottom, out dfLocalLoss, step);
899
900 if (m_blobAccuracy != null)
901 dfLocalAccuracy = Utility.ConvertVal<T>(m_blobAccuracy.GetData(0));
902 }
903
904 if (double.IsNaN(dfLocalLoss) || double.IsInfinity(dfLocalLoss))
905 {
906 if (m_bFirstNanError)
907 {
908 m_log.WriteError(new Exception("The local loss at iteration " + m_nIter.ToString() + " is invalid (NAN or INFINITY)!"));
909 m_bFirstNanError = false;
910 }
911 }
912
913 if (dfLocalAccuracy.HasValue)
914 {
915 if (!dfAccuracyTotal.HasValue)
916 dfAccuracyTotal = 0;
917
918 dfAccuracyTotal = dfAccuracyTotal + dfLocalAccuracy.Value;
919 }
920
921 dfLossTotal += dfLocalLoss;
922 swTiming.Stop();
923
924 dfTotalTime += swTiming.Elapsed.TotalMilliseconds;
925 nTimingCount++;
926 nIterCount++;
927
928 if (!bFwdPassNanFree)
929 break;
930 }
931
932 dfLoss = dfLossTotal / nIterCount;
933 dfLoss = dfLossOverride.GetValueOrDefault(dfLoss);
934
935 if (dfAccuracyTotal.HasValue)
936 m_dfIterAccuracy = dfAccuracyTotal.Value / nIterCount;
937
938 // average the loss across iterations for smoothed reporting
939 UpdateSmoothedLoss(dfLoss, start_iter);
940
941 bool bDisplay = false;
942 if (!bDisplay1 && sw.ElapsedMilliseconds > 2000 && !bDisableOutput)
943 {
944 bDisplay = true;
945 m_bFirstNanError = true;
946 sw.Restart();
947 }
948
949 if (bDisplay && bDisplay1)
950 {
951 m_log.WriteLine("Iteration " + m_nIter.ToString() + ", loss = " + m_dfSmoothedLoss.ToString());
952
953 BlobCollection<T> colResult = m_net.output_blobs;
954 int score_index = 0;
955
956 if (is_root_solver)
957 {
958 for (int j = 0; j < colResult.Count; j++)
959 {
960 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
961 int nIdx = m_net.output_blob_indices[j];
962 string output_name = m_net.blob_names[nIdx];
963 double loss_weight = m_net.blob_loss_weights[nIdx];
964 double dfTotalLossWeight = 0;
965 int nResultCount = colResult[j].count();
966
967 for (int k = 0; k < nResultCount; k++)
968 {
970 {
971 string strOut = "";
972
973 if (loss_weight != 0)
974 strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * result_vec[k]).ToString() + " loss)";
975
976 m_log.WriteLine(" Train net output #" + score_index.ToString() + ": " + output_name + " = " + result_vec[k].ToString() + strOut);
977 score_index++;
978 }
979 else
980 {
981 dfTotalLossWeight += loss_weight * result_vec[k];
982 }
983 }
984
986 {
987 double dfAverage = dfTotalLossWeight / nResultCount;
988 m_log.WriteLine(" Average weighted score = " + dfAverage.ToString() + " for '" + output_name + "' - averaged over " + nResultCount.ToString("N0") + " results.");
989 }
990 }
991 }
992 }
993
994 if (OnGradientsReady != null && bFwdPassNanFree)
996
997 double dfLastLearningRate = 0;
998
999 if (step != TRAIN_STEP.FORWARD && bApplyUpdates)
1000 dfLastLearningRate = ApplyUpdate(m_nIter);
1001
1002 if (m_evtCancel.WaitOne(0))
1003 break;
1004
1005 if (!bDisableProgress)
1006 m_log.Progress = (double)m_nIter / (double)stop_iter;
1007
1008 bool bSnapshotTaken = false;
1009 bool bForceSnapshot = forceSnapshot;
1010
1011 if ((step == TRAIN_STEP.NONE || bAllowSnapshot.GetValueOrDefault(false)) && (is_root_solver && bFwdPassNanFree &&
1012 (bForceSnapshot ||
1013 (m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0) ||
1014 (m_dfLastAccuracy > m_dfBestAccuracy))))
1015 {
1016 bSnapshotTaken = true;
1017 Snapshot(bForceSnapshot, ((m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0)) ? true : false);
1018
1019 if (m_dfLastAccuracy > m_dfBestAccuracy)
1020 m_dfBestAccuracy = m_dfLastAccuracy;
1021 }
1022
1023 //-------------------------------------
1024 // Call the training iteration event
1025 // on the root solver.
1026 //-------------------------------------
1027 fireOnTrainingIterationEvent(bFwdPassNanFree, dfLoss, dfLastLearningRate, ref nTimingCount, ref dfTotalTime);
1028
1029 //-------------------------------------
1030 // If single stepping, stop the solver.
1031 //-------------------------------------
1032 if (step != TRAIN_STEP.NONE || m_bEnableSingleStep)
1033 {
1034 if (step == TRAIN_STEP.BOTH)
1035 {
1036 if (!bDisableOutput)
1037 m_log.WriteLine("Single step (both) triggered - solving stopped after a single forward/backward pass.");
1038 }
1039 else if (step == TRAIN_STEP.FORWARD)
1040 {
1041 if (!bDisableOutput)
1042 m_log.WriteLine("Single step (forward) triggered - solving stopped after a single forward pass.");
1043 }
1044 else if (step == TRAIN_STEP.BACKWARD)
1045 {
1046 if (!bDisableOutput)
1047 m_log.WriteLine("Single step (backward) triggered - solving stopped after a single backward pass.");
1048 }
1049 else
1050 {
1051 // When single stepping, force the snapshot so as to allow
1052 // debugging the net visually.
1053 if (!bSnapshotTaken)
1054 Snapshot(true, false);
1055 }
1056 break;
1057 }
1058
1059 //-------------------------------------
1060 // If a time-limit has been imposed
1061 // and we have exceeded it, stop
1062 // training.
1063 //-------------------------------------
1064 if (m_nTrainingTimeLimitInMinutes > 0 && swTimeout.Elapsed.TotalMinutes > m_nTrainingTimeLimitInMinutes)
1065 {
1066 m_log.WriteLine("A training time-limit of " + m_nTrainingTimeLimitInMinutes.ToString("N0") + " minutes has been exceeded - training will now stop.");
1067 return true;
1068 }
1069
1070 if (!bApplyUpdates)
1071 break;
1072 }
1073
1074 return true;
1075 }
1076 catch (Exception excpt)
1077 {
1078 err = excpt;
1079 throw excpt;
1080 }
1081 finally
1082 {
1083 if (err != null || m_evtCancel.WaitOne(0))
1084 {
1085 if (OnAborted != null)
1086 OnAborted(this, new EventArgs());
1087 }
1088 }
1089 }
1090
1097 public void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes = null)
1098 {
1099 m_net.LoadWeights(rgWeights, m_persist, null, null, strSkipBlobTypes);
1100
1101 if (rgState != null)
1102 {
1103 m_log.WriteLine("Restoring previous solver state from restore state...");
1104 RestoreSolverState(rgState);
1105 }
1106 }
1107
1115 public void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase = true)
1116 {
1117 m_log.WriteLine("Starting snap shot...");
1118 m_log.CHECK(is_root_solver, "Snapshot only supported on the root solver.");
1119
1120 if (OnSnapshot == null)
1121 return;
1122
1123 if (m_snapshotWeightUpdatemMethod == SNAPSHOT_WEIGHT_UPDATE_METHOD.DISABLED && !bForced)
1124 {
1125 m_log.WriteLine("WARNING: Snapshot UPDATE_METHOD = DISABLED.");
1126 return;
1127 }
1128
1129 SnapshotArgs args = GetSnapshotArgs(null, null, m_dfLastAccuracy, m_dfLastError, m_nIter, m_snapshotWeightUpdatemMethod);
1130 args.Forced = bForced;
1131 args.Scheduled = bScheduled;
1132 args.UpdateDatabase = bUpdateDatabase;
1133
1134 OnSnapshot(this, args);
1135 m_log.WriteLine("Snapshot completed.");
1136 }
1137
1138 private void args_OnGetWeights(object sender, GetBytesArgs e)
1139 {
1140 if (m_net != null)
1141 e.Data = m_net.SaveWeights(m_persist, m_param.snapshot_diff);
1142 }
1143
1144 private void args_OnGetState(object sender, GetBytesArgs e)
1145 {
1147 }
1148
        /// <summary>
        /// Build a SnapshotArgs instance describing the current snapshot, wiring up
        /// the lazy weight/state retrieval handlers.
        /// </summary>
        /// <param name="rgState">Specifies the solver state bytes (may be null; retrieved lazily via OnGetState).</param>
        /// <param name="rgWeights">Specifies the weight bytes (may be null; retrieved lazily via OnGetWeights).</param>
        /// <param name="dfAccuracy">Specifies the accuracy; a value of 0 is bumped to a small epsilon.</param>
        /// <param name="dfError">Specifies the error.</param>
        /// <param name="nIteration">Specifies the iteration of the snapshot.</param>
        /// <param name="wtUpdt">Specifies the snapshot weight update method.</param>
        /// <returns>Returns the configured SnapshotArgs.</returns>
        public SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
        {
            // Presumably avoids a zero accuracy being treated as 'no accuracy'
            // downstream - TODO confirm.
            if (dfAccuracy == 0)
                dfAccuracy = 0.0001;

            SnapshotArgs args = new SnapshotArgs(rgState, rgWeights, dfAccuracy, dfError, nIteration, wtUpdt);

            // NOTE(review): a few initialization lines appear to have been elided
            // from this copy of the source here - confirm against the original.
            args.SingleStep = m_bEnableSingleStep;
            args.OnGetState += args_OnGetState;
            args.OnGetWeights += args_OnGetWeights;

            return args;
        }
1174
1179 {
1180 get { return m_nTrainingIterationOverride; }
1181 set { m_nTrainingIterationOverride = value; }
1182 }
1183
1188 {
1189 get { return m_nTestingIterationOverride; }
1190 set { m_nTestingIterationOverride = value; }
1191 }
1192
        /// <summary>
        /// Returns the event signaled when training has completed.
        /// </summary>
        public AutoResetEvent CompletedEvent
        {
            get { return m_evtCompleted; }
        }
1200
1205 {
1206 get { return m_evtCancel; }
1207 }
1208
        /// <summary>
        /// Returns the smoothed (averaged) loss.
        /// </summary>
        public double smoothed_loss
        {
            get { return m_dfSmoothedLoss; }
        }
1216
1221 {
1222 get { return m_param; }
1223 }
1224
1229 {
1230 get { return m_net; }
1231 }
1232
        /// <summary>
        /// Returns the list of testing nets.
        /// </summary>
        public List<Net<T>> test_nets
        {
            get { return m_rgTestNets; }
        }
1240
        /// <summary>
        /// Returns the current training iteration.
        /// </summary>
        public int iter
        {
            get { return m_nIter; }
        }
1248
1253 {
1254 get { return m_param.type; }
1255 }
1256
        /// <summary>
        /// Returns whether a snapshot has been forced via the force-snapshot event.
        /// Reading this property consumes the AutoResetEvent signal.
        /// </summary>
        protected bool forceSnapshot
        {
            get
            {
                if (m_evtForceSnapshot == null)
                    return false;

                return m_evtForceSnapshot.WaitOne(0);
            }
        }
1270
        /// <summary>
        /// Returns whether a test cycle has been forced via the force-test event.
        /// Reading this property consumes the AutoResetEvent signal and caches the
        /// result in m_bForceTest.
        /// </summary>
        public bool forceTest
        {
            get
            {
                if (m_evtForceTest == null)
                    return false;

                m_bForceTest = m_evtForceTest.WaitOne(0);
                return m_bForceTest;
            }
        }
1285
        /// <summary>
        /// Returns the number of solvers participating in the session.
        /// </summary>
        public int solver_count
        {
            get { return m_nSolverCount; }
        }
1293
        /// <summary>
        /// Returns this solver's rank within the session (0 = root solver).
        /// </summary>
        public int solver_rank
        {
            get { return m_nSolverRank; }
        }
1301
1308 public bool is_root_solver
1309 {
1310 get { return (m_nSolverRank == 0) ? true : false; }
1311 }
1312
1322 public double TestAll(int nIterationOverride = -1)
1323 {
1324 double dfTotalAccuracy = 0;
1325 double dfTotalTime = 0;
1326 int nTotalCount = 0;
1327
1328 for (int test_net_id = 0; test_net_id < m_rgTestNets.Count; test_net_id++)
1329 {
1330 if (m_evtCancel.WaitOne(0))
1331 return 0;
1332
1333 if (OnTest != null)
1334 {
1335 TestArgs args = new TestArgs(nIterationOverride, test_net_id);
1336 OnTest(this, args);
1337 dfTotalAccuracy += args.Accuracy;
1338 }
1339 else
1340 dfTotalAccuracy += testOne(nIterationOverride, test_net_id);
1341
1342 dfTotalTime += m_dfAverageTestTime;
1343 nTotalCount++;
1344 }
1345
1346 if (m_rgTestNets.Count == 0)
1347 {
1348 if (OnTest != null)
1349 {
1350 TestArgs args = new TestArgs(nIterationOverride, 0);
1351 OnTest(this, args);
1352 dfTotalAccuracy += args.Accuracy;
1353 }
1354 else
1355 dfTotalAccuracy += testOne(nIterationOverride, 0);
1356 }
1357
1358 double dfAccuracy = (m_rgTestNets.Count > 0) ? dfTotalAccuracy / m_rgTestNets.Count : 0;
1359
1360 if (m_rgAverageAccuracyWindow != null)
1361 {
1362 m_rgAverageAccuracyWindow.Add(dfAccuracy);
1363 m_rgAverageAccuracyWindow.RemoveAt(0);
1364 dfAccuracy = m_rgAverageAccuracyWindow.Average();
1365 }
1366
1367 if (OnTestingIteration != null)
1368 {
1369 double dfTime = (nTotalCount > 0) ? dfTotalTime / nTotalCount : 0;
1370 OnTestingIteration(this, new TestingIterationArgs<T>(m_nIter, dfAccuracy, dfTime));
1371 }
1372
1373 return dfAccuracy;
1374 }
1375
1376 private double testOne(int nIterationOverride = -1, int nTestNetId = 0)
1377 {
1378 switch (m_param.eval_type)
1379 {
1380 // Test SSD Detection
1381 case SolverParameter.EvaluationType.DETECTION:
1382 return TestDetection(nIterationOverride, nTestNetId);
1383
1384 // Perform regular classification Test.
1385 default:
1386 return TestClassification(nIterationOverride, nTestNetId);
1387 }
1388 }
1389
/// <summary>
/// Run an SSD detection test pass over the test net, accumulating per-label true/false
/// positives and ground-truth positive counts from each detection output blob, then
/// compute the mean Average Precision (mAP) averaged across all detection outputs.
/// </summary>
/// <param name="nIterationOverride">Optionally, specifies the number of test iterations; values &lt;= 0 use TestingIterations.</param>
/// <param name="nTestNetId">Specifies which test net to use; falls back to the training net when the index is not available.</param>
/// <returns>The mAP averaged over all detection outputs is returned, or 0 when testing is canceled.</returns>
1396 public double TestDetection(int nIterationOverride = -1, int nTestNetId = 0)
1397 {
1398 Stopwatch sw = new Stopwatch();
1399 BBoxUtility<T> bboxUtil = new BBoxUtility<T>(m_cuda, m_log);
1400
1401 try
1402 {
1403 if (is_root_solver)
1404 m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");
1405
// Default to the training net when the requested test net does not exist.
1406 Net<T> test_net = m_net;
1407
1408 if (m_rgTestNets.Count > nTestNetId)
1409 {
1410 m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
// Share the trained weights so the test net evaluates the current training state.
1411 m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
1412 test_net = m_rgTestNets[nTestNetId];
1413 }
1414
// Accumulators keyed by output blob index (j), then by label: true-positive and
// false-positive (score, flag) pairs, and the ground-truth positive counts.
1415 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllTruePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1416 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllFalsePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1417 Dictionary<int, Dictionary<int, int>> rgAllNumPos = new Dictionary<int, Dictionary<int, int>>();
1418
1419 double dfLoss = 0;
1420
1421 if (nIterationOverride <= 0)
1422 nIterationOverride = TestingIterations;
1423
1424 int nIter = nIterationOverride;
1425 sw.Start();
1426
1427 for (int i = 0; i < nIter; i++)
1428 {
1429 // Check to see if stoppage of testing/training has been requested.
1430 if (m_evtCancel.WaitOne(0))
1431 break;
1432
1433 if (OnTestStart != null)
1434 OnTestStart(this, new EventArgs());
1435
1436 double iter_loss;
1437 BlobCollection<T> colResult = test_net.Forward(out iter_loss);
1438
// NOTE(review): the listing skips a line here (original line 1439, likely an
// 'if (m_param.test_compute_loss)' guard) - verify against the full source.
1440 dfLoss += iter_loss;
1441
// Each detection output row holds 5 values: [item_id, label, score, tp, fp]
// (the width == 5 check below enforces this layout).
1442 for (int j = 0; j < colResult.Count; j++)
1443 {
1444 m_log.CHECK_EQ(colResult[j].width, 5, "The width must be = 5 for SSD.");
1445 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1446 int num_det = colResult[j].height;
1447
1448 for (int k = 0; k < num_det; k++)
1449 {
1450 int item_id = (int)result_vec[k * 5];
1451 int nLabel = (int)result_vec[k * 5 + 1];
1452
1453 // Special row for storing number of positives for a label.
1454 if (item_id == -1)
1455 {
1456 if (!rgAllNumPos.ContainsKey(j))
1457 rgAllNumPos.Add(j, new Dictionary<int, int>());
1458
// Accumulate the ground-truth positive count for this label (slot 2 of the row).
1459 if (!rgAllNumPos[j].ContainsKey(nLabel))
1460 rgAllNumPos[j].Add(nLabel, (int)result_vec[k * 5 + 2]);
1461 else
1462 rgAllNumPos[j][nLabel] += (int)result_vec[k * 5 + 2];
1463 }
1464 // Normal row storing detection status.
1465 else
1466 {
1467 float fScore = (float)result_vec[k * 5 + 2];
1468 int tp = (int)result_vec[k * 5 + 3];
1469 int fp = (int)result_vec[k * 5 + 4];
1470
1471 // Ignore such case, which happens when a detection bbox is matched to
1472 // a difficult gt bbox and we don't evaluate on difficult gt bbox.
1473 if (tp == 0 && fp == 0)
1474 continue;
1475
// Lazily create the per-blob / per-label lists before recording this detection.
1476 if (!rgAllTruePos.ContainsKey(j))
1477 rgAllTruePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1478
1479 if (!rgAllTruePos[j].ContainsKey(nLabel))
1480 rgAllTruePos[j].Add(nLabel, new List<Tuple<float, int>>());
1481
1482 if (!rgAllFalsePos.ContainsKey(j))
1483 rgAllFalsePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1484
1485 if (!rgAllFalsePos[j].ContainsKey(nLabel))
1486 rgAllFalsePos[j].Add(nLabel, new List<Tuple<float, int>>());
1487
1488 rgAllTruePos[j][nLabel].Add(new Tuple<float, int>(fScore, tp));
1489 rgAllFalsePos[j][nLabel].Add(new Tuple<float, int>(fScore, fp));
1490 }
1491 }
1492 }
1493
// Emit progress roughly once per second so long tests stay responsive in the log.
1494 if (sw.Elapsed.TotalMilliseconds > 1000)
1495 {
1496 m_log.Progress = (double)i / (double)nIter;
1497 m_log.WriteLine("Testing at " + m_log.Progress.ToString("P") + " " + i.ToString() + " of " + nIter.ToString() + "...");
1498 sw.Restart();
1499 }
1500 }
1501
1502 if (m_evtCancel.WaitOne(0))
1503 {
1504 m_log.WriteLine("Test interrupted.");
1505 return 0;
1506 }
1507
// NOTE(review): bare scope block - the guarding 'if' (original line 1508, likely
// 'if (m_param.test_compute_loss)') is truncated from this listing; as shown the
// loss is always averaged and logged. Confirm against the full source.
1509 {
1510 dfLoss /= m_param.test_iter[nTestNetId];
1511 m_log.WriteLine("Test loss: " + dfLoss.ToString());
1512 }
1513
// Compute per-output mAP; indexing 0..Count-1 assumes the blob-index keys are
// contiguous - m_log.FAIL raises if any index is missing.
1514 float fTotalmAP = 0;
1515 for (int i = 0; i < rgAllTruePos.Count; i++)
1516 {
1517 if (!rgAllTruePos.ContainsKey(i))
1518 m_log.FAIL("Missing output_blob true_pos: " + i.ToString());
1519
1520 Dictionary<int, List<Tuple<float, int>>> rgTruePos = rgAllTruePos[i];
1521
1522 if (!rgAllFalsePos.ContainsKey(i))
1523 m_log.FAIL("Missing output_blob false_pos: " + i.ToString());
1524
1525 Dictionary<int, List<Tuple<float, int>>> rgFalsePos = rgAllFalsePos[i];
1526
1527 if (!rgAllNumPos.ContainsKey(i))
1528 m_log.FAIL("Missing output_blob num_pos: " + i.ToString());
1529
1530 Dictionary<int, int> rgNumPos = rgAllNumPos[i];
1531
1532 Dictionary<int, float> rgAPs = new Dictionary<int, float>();
1533 float fmAP = 0.0f;
1534
1535 // Sort true_pos and false_pos with descending scores.
1536 foreach (KeyValuePair<int, int> kv in rgNumPos)
1537 {
1538 int nLabel = kv.Key;
1539 int nLabelNumPos = kv.Value;
1540
1541 if (!rgTruePos.ContainsKey(nLabel))
1542 {
1543 m_log.WriteLine("WARNING: Missing true_pos for label: " + nLabel.ToString() + "!");
1544 continue;
1545 }
1546 List<Tuple<float, int>> rgLabelTruePos = rgTruePos[nLabel];
1547
1548 if (!rgFalsePos.ContainsKey(nLabel))
1549 {
1550 m_log.WriteLine("WARNING: Missing false_pos for label: " + nLabel.ToString() + "!");
1551 continue;
1552 }
1553 List<Tuple<float, int>> rgLabelFalsePos = rgFalsePos[nLabel];
1554
// Average precision for this label, computed per the configured AP version.
1555 List<float> rgPrec;
1556 List<float> rgRec;
1557 float fAp = bboxUtil.ComputeAP(rgLabelTruePos, nLabelNumPos, rgLabelFalsePos, m_param.ap_version, out rgPrec, out rgRec);
1558
1559 if (!rgAPs.ContainsKey(nLabel))
1560 rgAPs.Add(nLabel, fAp);
1561 else
1562 rgAPs[nLabel] = fAp;
1563
1564 fmAP += fAp;
1565
// NOTE(review): original line 1566 is missing from this listing (likely an
// 'if (m_param.show_per_class_result)' guard for this per-class log line).
1567 m_log.WriteLine("class " + nLabel.ToString() + ": " + fAp.ToString());
1568 }
1569
// NOTE(review): rgNumPos.Count == 0 would yield NaN here - presumably every
// detection output emits at least one num_pos row; confirm upstream.
1570 fmAP /= rgNumPos.Count;
1571
1572 int nOutputBlobIdx = test_net.output_blob_indices[i];
1573 string strOutputName = test_net.blob_names[nOutputBlobIdx];
1574
1575 m_log.WriteLine(" Test net output #" + i.ToString() + ": " + strOutputName + " = " + fmAP.ToString());
1576 fTotalmAP += fmAP;
1577 }
1578
// NOTE(review): rgAllTruePos.Count == 0 (no detections collected) divides by zero.
1579 return fTotalmAP / rgAllTruePos.Count;
1580 }
1581 catch (Exception excpt)
1582 {
// NOTE(review): 'throw excpt;' resets the stack trace; a bare 'throw;' would
// preserve it (or the catch block could be removed since it only rethrows).
1583 throw excpt;
1584 }
1585 finally
1586 {
// Always release the bbox utility's resources, even on cancel or exception.
1587 bboxUtil.Dispose();
1588 }
1589 }
1590
/// <summary>
/// Run a classification test pass over the test net, averaging the output scores
/// over the test iterations and returning the final accuracy score.
/// </summary>
/// <param name="nIterationOverride">Optionally, specifies the number of test iterations; values &lt;= 0 use TestingIterations.</param>
/// <param name="nTestNetId">Specifies which test net to use; falls back to the training net when the index is not available.</param>
/// <returns>The final score is returned: the averaged event-supplied accuracy when OnTestResults provides one, otherwise the mean of the accuracy output blob; 0 on cancel or when no scores were collected.</returns>
1597 public double TestClassification(int nIterationOverride = -1, int nTestNetId = 0)
1598 {
// Only the root solver displays, and only on the configured display interval.
1599 bool bDisplay = (is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0) ? true : false;
1600
// A forced test always displays; the force flag is consumed here.
1601 if (m_bForceTest)
1602 {
1603 m_bForceTest = false;
1604 bDisplay = true;
1605 }
1606
1607 if (bDisplay)
1608 m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");
1609
// Default to the training net when the requested test net does not exist.
1610 Net<T> test_net = m_net;
1611
1612 if (m_rgTestNets.Count > nTestNetId)
1613 {
1614 m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
// Share the trained weights so the test net evaluates the current training state.
1615 m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
1616 test_net = m_rgTestNets[nTestNetId];
1617 }
1618
// test_score accumulates per-output values across iterations; test_score_output_id
// records which output blob each score came from (or 1 for event-supplied accuracies).
1619 List<double> test_score = new List<double>();
1620 List<int> test_score_output_id = new List<int>();
1621 double dfLoss = 0;
1622
1623 if (nIterationOverride <= 0)
1624 nIterationOverride = TestingIterations;
1625
1626 int nIter = nIterationOverride;
1627
1628 Stopwatch sw = new Stopwatch();
1629 sw.Start();
1630
1631 double dfTotalTiming = 0;
1632 int nTestCount = 0;
1633 int nAccuracyIdx = 0;
1634 int nMinRank = int.MaxValue;
1635 bool bAccuracyValid = false;
1636 Stopwatch swTiming = new Stopwatch();
1637
1638 for (int i = 0; i < nIter; i++)
1639 {
1640 // Check to see if stoppage of testing/training has been requested.
1641 if (m_evtCancel.WaitOne(0))
1642 break;
1643
1644 if (OnTestStart != null)
1645 OnTestStart(this, new EventArgs());
1646
1647 swTiming.Restart();
1648
1649 double iter_loss;
1650 BlobCollection<T> colResult = test_net.Forward(out iter_loss);
1651
// NOTE(review): the listing skips a line here (original line 1652, likely an
// 'if (m_param.test_compute_loss)' guard) - verify against the full source.
1653 dfLoss += iter_loss;
1654
// Give subscribers first chance to compute the accuracy from the raw results.
1655 TestResultArgs<T> args = new TestResultArgs<T>(colResult);
1656 if (OnTestResults != null)
1657 {
1658 OnTestResults(this, args);
1659 if (args.AccuracyValid)
1660 {
1661 test_score.Add(args.Accuracy);
// The constant id 1 makes Sum() of the id list equal the score count below.
1662 test_score_output_id.Add(1);
1663 bAccuracyValid = true;
1664 }
1665 }
1666
// Fall back to accumulating the raw output blob values when no subscriber
// supplied a valid accuracy for this iteration.
1667 if (!args.AccuracyValid)
1668 {
// First iteration: record every output value and locate the accuracy blob
// with the lowest rank tag to use as the final score.
1669 if (i == 0)
1670 {
1671 for (int j = 0; j < colResult.Count; j++)
1672 {
1673 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1674
1675 for (int k = 0; k < colResult[j].count(); k++)
1676 {
1677 test_score.Add(result_vec[k]);
1678 test_score_output_id.Add(j);
1679 }
1680
1681 if (colResult[j].type == BLOB_TYPE.ACCURACY)
1682 {
// The blob Tag (when numeric) ranks competing accuracy blobs; lowest wins.
1683 int nRank = (int)getNumber(colResult[j].Tag, 0);
1684 if (nRank < nMinRank)
1685 {
1686 nMinRank = nRank;
1687 nAccuracyIdx = j;
1688 }
1689 }
1690 }
1691 }
// Subsequent iterations: add element-wise into the scores recorded at i == 0.
1692 else
1693 {
1694 int idx = 0;
1695
1696 for (int j = 0; j < colResult.Count; j++)
1697 {
1698 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1699
1700 for (int k = 0; k < colResult[j].count(); k++)
1701 {
1702 test_score[idx] += result_vec[k];
1703 idx++;
1704 }
1705 }
1706 }
1707 }
1708
1709 swTiming.Stop();
1710 dfTotalTiming += swTiming.Elapsed.TotalMilliseconds;
1711 nTestCount++;
1712
// Emit progress roughly every two seconds so long tests stay responsive.
1713 if (sw.ElapsedMilliseconds > 2000)
1714 {
1715 double dfPct = (double)i / (double)nIter;
1716
1717 if (bDisplay)
1718 {
1719 m_log.Progress = dfPct;
1720 m_log.WriteLine("Testing '" + test_net.name + "' at " + dfPct.ToString("P"));
1721 }
1722
1723 sw.Restart();
1724 }
1725 }
1726
1727 m_dfAverageTestTime = (nTestCount > 0) ? dfTotalTiming / nTestCount : 0;
1728
1729 if (m_evtCancel.WaitOne(0))
1730 {
1731 m_log.WriteLine("Test interrupted.");
1732 return 0;
1733 }
1734
// NOTE(review): bare scope block - the guarding 'if' (original line 1735, likely
// 'if (m_param.test_compute_loss)') is truncated from this listing. Also note the
// division uses the configured test_iter, not the iterations actually run.
1736 {
1737 dfLoss /= m_param.test_iter[nTestNetId];
1738 m_log.WriteLine("Test loss: " + dfLoss.ToString());
1739 }
1740
1741 double dfFinalScore = 0;
1742
// Event-supplied accuracies: ids are all 1, so Sum() of ids equals the count and
// this computes the mean of the per-iteration accuracies.
1743 if (bAccuracyValid)
1744 {
1745 dfFinalScore = test_score.Sum();
1746 int nTotal = test_score_output_id.Sum();
1747 dfFinalScore /= nTotal;
1748 }
// Raw-output path: average each accumulated value over the iteration count and
// report it; the accuracy blob found at i == 0 supplies the final score.
1749 else
1750 {
1751 for (int i = 0; i < test_score.Count; i++)
1752 {
1753 int nIdxTestScore = test_score_output_id[i];
1754 int output_blob_index = test_net.output_blob_indices[nIdxTestScore];
1755 string output_name = test_net.blob_names[output_blob_index];
1756 double loss_weight = test_net.blob_loss_weights[output_blob_index];
1757 double dfMeanScore = test_score[i] / nIter;
1758 string strOut = "";
1759
1760 if (bDisplay)
1761 {
1762 if (loss_weight != 0)
1763 strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * dfMeanScore).ToString() + " loss)";
1764
1765 m_log.WriteLine(" Test net output #" + i.ToString() + ": " + output_name + " = " + dfMeanScore.ToString() + strOut);
1766 }
1767
1768 if (i == nAccuracyIdx)
1769 dfFinalScore = dfMeanScore;
1770 }
1771 }
1772
// No scores at all (e.g. canceled before the first forward) - report zero.
1773 if (test_score.Count == 0)
1774 return 0;
1775
1776 return dfFinalScore;
1777 }
1778
/// <summary>
/// Convert a boxed value of any primitive numeric type to a double.
/// </summary>
/// <param name="value">Specifies the boxed value to convert.</param>
/// <param name="dfDefault">Specifies the value returned when 'value' is null or not a recognized numeric type.</param>
/// <returns>The numeric value as a double, or 'dfDefault' when no conversion applies.</returns>
private double getNumber(object value, double dfDefault)
{
    // Pattern-match each supported primitive numeric type.  Boxed enums and any
    // other non-numeric types intentionally fall through to the default, matching
    // the semantics of the 'is' type tests ('case int n' does not match an enum).
    switch (value)
    {
        case null:
            return dfDefault;

        case sbyte sb:
            return sb;

        case byte b:
            return b;

        case short s:
            return s;

        case ushort us:
            return us;

        case int n:
            return n;

        case uint un:
            return un;

        case long l:
            return l;

        case ulong ul:
            return ul;

        case float f:
            return f;

        case double df:
            return df;

        case decimal dec:
            // Decimal has no implicit conversion to double; cast explicitly.
            return (double)dec;

        default:
            return dfDefault;
    }
}
1819
/// <summary>
/// Update the smoothed (windowed moving-average) loss with the latest iteration
/// loss, and refresh the last/best error tracking values.
/// </summary>
/// <param name="dfLoss">Specifies the loss of the current iteration.</param>
/// <param name="nStartIter">Specifies the starting iteration, used to index the circular loss window.</param>
/// <param name="nAverageLoss">Optionally, specifies the averaging window size; 0 (the default) uses the solver parameter's average_loss setting.</param>
public void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss = 0)
{
    int nWindow = (nAverageLoss == 0) ? m_param.average_loss : nAverageLoss;

    if (m_rgLosses.Count < nWindow)
    {
        // The window is still filling: append the new loss and update the
        // running mean incrementally.
        m_rgLosses.Add(dfLoss);
        int nSamples = m_rgLosses.Count;
        m_dfSmoothedLoss = (m_dfSmoothedLoss * (nSamples - 1) + dfLoss) / nSamples;
    }
    else
    {
        // The window is full: overwrite the oldest sample (circular index based
        // on the iteration) and adjust the mean by the difference it introduces.
        int nSlot = (m_nIter - nStartIter) % nWindow;
        m_dfSmoothedLoss += (dfLoss - m_rgLosses[nSlot]) / nWindow;
        m_rgLosses[nSlot] = dfLoss;
    }

    // After an external weight update the accumulated history is stale, so the
    // smoothed loss restarts from the current loss.
    if (m_bWeightsUpdated)
    {
        m_dfSmoothedLoss = dfLoss;
        m_bWeightsUpdated = false;
    }

    m_dfLastError = m_dfSmoothedLoss;

    if (m_dfLastError < m_dfBestError)
        m_dfBestError = m_dfLastError;
}
1855
/// <summary>
/// Apply the update to the weights; each derived solver implements its own
/// update rule.
/// </summary>
/// <param name="nIterationOverride">Optionally, specifies an iteration override (default = -1).</param>
/// <returns>A solver-specific value is returned. (NOTE(review): presumably the learning rate used - confirm against the derived solvers.)</returns>
1860 public abstract double ApplyUpdate(int nIterationOverride = -1);
1861
/// <summary>
/// Save the current solver state; each derived solver implements its own
/// serialization format.
/// </summary>
/// <returns>The serialized solver state is returned as a byte array.</returns>
1865 protected abstract byte[] SnapshotSolverState();
1866
/// <summary>
/// Restore a previously saved solver state; each derived solver implements its
/// own deserialization format.
/// </summary>
/// <param name="rgState">Specifies the serialized solver state to restore.</param>
1870 protected abstract void RestoreSolverState(byte[] rgState);
1871
/// <summary>
/// Create a new Solver from the solver and model descriptions held by a project.
/// </summary>
/// <param name="cuda">Specifies the instance of CudaDnn to use.</param>
/// <param name="log">Specifies the Log for output.</param>
/// <param name="p">Specifies the project supplying the solver and model descriptions.</param>
/// <param name="evtCancel">Specifies the CancelEvent used to cancel training.</param>
/// <param name="evtForceSnapshot">Specifies the event used to force a snapshot.</param>
/// <param name="evtForceTest">Specifies the event used to force a test cycle.</param>
/// <param name="db">Specifies the in-memory database.</param>
/// <param name="persist">Specifies the persistence object used to load and save weights.</param>
/// <param name="nSolverCount">Optionally, specifies the number of solvers in a multi-GPU session (default = 1).</param>
/// <param name="nSolverRank">Optionally, specifies this solver's rank (default = 0).</param>
/// <param name="shareNet">Optionally, specifies a net to share (default = null).</param>
/// <param name="getws">Optionally, specifies the workspace-get handler (default = null).</param>
/// <param name="setws">Optionally, specifies the workspace-set handler (default = null).</param>
/// <returns>The new Solver is returned.</returns>
public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
{
    SolverParameter solverParam;

    // Parse the project's solver description when present; otherwise fall back
    // to a default SolverParameter.
    if (p.SolverDescription != null)
        solverParam = SolverParameter.FromProto(RawProto.Parse(p.SolverDescription));
    else
        solverParam = new SolverParameter();

    // When the solver parameter does not embed a net description, parse the
    // project's model description and tag it with the project ID.
    if (solverParam.net_param == null)
    {
        solverParam.net_param = NetParameter.FromProto(RawProto.Parse(p.ModelDescription));
        solverParam.net_param.ProjectID = p.ID;
    }

    return Create(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
}
1912
/// <summary>
/// Create a new Solver of the type specified by the SolverParameter.
/// </summary>
/// <param name="cuda">Specifies the instance of CudaDnn to use.</param>
/// <param name="log">Specifies the Log for output.</param>
/// <param name="solverParam">Specifies the solver parameter that selects the solver type and configures it.</param>
/// <param name="evtCancel">Specifies the CancelEvent used to cancel training.</param>
/// <param name="evtForceSnapshot">Specifies the event used to force a snapshot.</param>
/// <param name="evtForceTest">Specifies the event used to force a test cycle.</param>
/// <param name="db">Specifies the in-memory database.</param>
/// <param name="persist">Specifies the persistence object used to load and save weights.</param>
/// <param name="nSolverCount">Optionally, specifies the number of solvers in a multi-GPU session (default = 1).</param>
/// <param name="nSolverRank">Optionally, specifies this solver's rank (default = 0).</param>
/// <param name="shareNet">Optionally, specifies a net to share (default = null).</param>
/// <param name="getws">Optionally, specifies the workspace-get handler (default = null).</param>
/// <param name="setws">Optionally, specifies the workspace-set handler (default = null).</param>
/// <returns>The new Solver is returned.</returns>
/// <exception cref="NotImplementedException">Thrown when the solver type is not supported.</exception>
public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
{
    // Dispatch directly on the configured solver type; every supported solver
    // derives from SGDSolver and shares the same constructor signature.
    switch (solverParam.type)
    {
        case SolverParameter.SolverType.SGD:
            return new SGDSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.NESTEROV:
            return new NesterovSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADAGRAD:
            return new AdaGradSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADADELTA:
            return new AdaDeltaSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADAM:
            return new AdamSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.ADAMW:
            return new AdamWSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        case SolverParameter.SolverType.RMSPROP:
            return new RmsPropSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, db, persist, nSolverCount, nSolverRank, shareNet, getws, setws);

        default:
            throw new NotImplementedException("The solver " + solverParam.type.ToString() + " is not implemented yet!");
    }
}
1970 }
1971
1972#pragma warning disable 1591
1973
/// <summary>
/// The OutputCollection groups the error and accuracy output data collections.
/// </summary>
public class OutputCollection
{
    OutputDataCollection m_colErrors = new OutputDataCollection();
    OutputDataCollection m_colAccuracies = new OutputDataCollection();

    /// <summary>
    /// The constructor.
    /// </summary>
    public OutputCollection()
    {
    }

    /// <summary>
    /// Returns the collection of error output data.
    /// </summary>
    public OutputDataCollection Errors => m_colErrors;

    /// <summary>
    /// Returns the collection of accuracy output data.
    /// </summary>
    public OutputDataCollection Accuracies => m_colAccuracies;
}
1993
/// <summary>
/// The OutputDataCollection manages a named list of OutputData items, each holding
/// a running windowed average.
/// </summary>
public class OutputDataCollection : IEnumerable<OutputData>
{
    List<OutputData> m_rgItems = new List<OutputData>();

    /// <summary>
    /// The constructor.
    /// </summary>
    public OutputDataCollection()
    {
    }

    /// <summary>
    /// Returns the underlying list of OutputData items.
    /// </summary>
    public List<OutputData> Data => m_rgItems;

    /// <summary>
    /// Returns the number of items in the collection.
    /// </summary>
    public int Count => m_rgItems.Count;

    /// <summary>
    /// Get/set the item at a given index.
    /// </summary>
    /// <param name="nIdx">Specifies the index of the item.</param>
    /// <returns>The item at the index is returned.</returns>
    public OutputData this[int nIdx]
    {
        get => m_rgItems[nIdx];
        set => m_rgItems[nIdx] = value;
    }

    /// <summary>
    /// Add a value into the item with a matching name, creating the item when it
    /// does not yet exist.
    /// </summary>
    /// <param name="nTotal">Specifies the averaging window total passed to OutputData.Add.</param>
    /// <param name="strName">Specifies the name of the item.</param>
    /// <param name="nIdx">Specifies the index used when a new item must be created.</param>
    /// <param name="dfVal">Specifies the value to fold into the item's average.</param>
    public void Add(int nTotal, string strName, int nIdx, double dfVal)
    {
        OutputData item = Find(strName);

        if (item == null)
        {
            item = new OutputData(strName, nIdx);
            m_rgItems.Add(item);
        }

        item.Add(nTotal, dfVal);
    }

    /// <summary>
    /// Find the item with a matching name.
    /// </summary>
    /// <param name="strName">Specifies the name to look for.</param>
    /// <returns>The matching item is returned, or null when not found.</returns>
    public OutputData Find(string strName)
    {
        return m_rgItems.FirstOrDefault(d => d.Name == strName);
    }

    /// <summary>
    /// Returns the typed enumerator over the items.
    /// </summary>
    public IEnumerator<OutputData> GetEnumerator() => m_rgItems.GetEnumerator();

    /// <summary>
    /// Returns the non-generic enumerator over the items.
    /// </summary>
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
2052
/// <summary>
/// The OutputData holds a named value maintained as an exponentially weighted
/// running average.
/// </summary>
public class OutputData
{
    string m_strName;
    double m_dfValue = 0;
    int m_nIdx;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="strName">Specifies the name of the output.</param>
    /// <param name="nIdx">Specifies the index of the output.</param>
    public OutputData(string strName, int nIdx)
    {
        m_strName = strName;
        m_nIdx = nIdx;
    }

    /// <summary>
    /// Returns the index of the output.
    /// </summary>
    public int Index => m_nIdx;

    /// <summary>
    /// Returns the name of the output.
    /// </summary>
    public string Name => m_strName;

    /// <summary>
    /// Get/set the current averaged value.
    /// </summary>
    public double Value
    {
        get => m_dfValue;
        set => m_dfValue = value;
    }

    /// <summary>
    /// Fold a new sample into the running average with weight 1/nTotal.
    /// </summary>
    /// <param name="nTotal">Specifies the averaging window size; the new sample receives weight 1/nTotal.</param>
    /// <param name="dfVal">Specifies the new sample value.</param>
    public void Add(int nTotal, double dfVal)
    {
        double dfRatio = 1.0 / (double)nTotal;
        m_dfValue = (m_dfValue * (1.0 - dfRatio)) + (dfRatio * dfVal);
    }
}
2087
2088#pragma warning restore 1591
2089}
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK(bool b, string str)
Test a flag for true.
Definition: Log.cs:227
bool IsEnabled
Returns whether or not the Log is enabled.
Definition: Log.cs:50
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
Definition: Log.cs:42
void FAIL(string str)
Causes a failure which throws an exception with the desciptive text.
Definition: Log.cs:394
double Progress
Get/set the progress associated with the Log.
Definition: Log.cs:147
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Definition: Log.cs:239
void WriteError(Exception e)
Write an error as output.
Definition: Log.cs:130
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
Definition: Log.cs:299
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
Definition: Log.cs:263
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
Definition: Log.cs:287
The ProjectEx class manages a project containing the solver description, model description,...
Definition: ProjectEx.cs:15
string? SolverDescription
Get/set the solver description script used by the Project.
Definition: ProjectEx.cs:726
int ID
Returns the ID of the Project in the database.
Definition: ProjectEx.cs:533
string? ModelDescription
Get/set the model description script used by the Project.
Definition: ProjectEx.cs:757
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
The Utility class provides general utility funtions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
The BBox class processes the NormalizedBBox data used with SSD.
Definition: BBoxUtility.cs:22
void Dispose()
Clean up all resources.
Definition: BBoxUtility.cs:43
float ComputeAP(List< Tuple< float, int > > rgTp, int nNumPos, List< Tuple< float, int > > rgFp, ApVersion apVersion, out List< float > rgPrec, out List< float > rgRec)
Compute the average precision given true positive and false positive vectors.
Definition: BBoxUtility.cs:69
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
T GetData(int nIdx)
Returns the data at a given flat index within the Blob.
Definition: Blob.cs:1893
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
Definition: Blob.cs:402
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:969
static ulong ConvertByteSizeToCount(ulong ulSizeInBytes)
Converts the byte size into the number of items in the base data type of float or double.
Definition: CudaDnn.cs:2438
The CustomForwardBackArgs provide the arguments to the OnCustomForwardBack event within the Solver St...
Definition: EventArgs.cs:609
double LocalLoss
Get/set the local loss of the pass.
Definition: EventArgs.cs:655
bool FwdPassNanFree
Get/set whether or a NAN was detected in the forward pass.
Definition: EventArgs.cs:646
The DebugInformation contains information used to help debug the Layers of a Net while it is training...
string DetectFirstNaN(out string strType)
Searches for the first NaN within any of the Layers.
string DetectLastNaN(out string strType)
Searches for the last NaN within any of the Layers.
The GetBytesArgs is passed along to the SnapshotArgs::OnGetWeights and SnapshotArgs::OnGetState event...
Definition: EventArgs.cs:392
byte[] Data
Get/set the data as an array of bytes.
Definition: EventArgs.cs:406
The GetIterationArgs is sent bubbled up to the solver when a layer needs to know the curret training ...
Definition: EventArgs.cs:748
void SetIteration(Phase p, int nIteration)
The SetIteration method is used to set the iteration and the phase.
Definition: EventArgs.cs:764
The GradientsReadyArgs is sent to the Solver::OnGradientsReady event which fires at the end of each S...
Definition: EventArgs.cs:734
Connects Layer's together into a direct acrylic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1445
List< string > blob_names
Returns the blob names.
Definition: Net.cs:1987
List< double > blob_loss_weights
Returns the collection of blob loss weights.
Definition: Net.cs:2069
string name
Returns the network name.
Definition: Net.cs:1971
List< int > output_blob_indices
Returns a list of the output Blob indexes.
Definition: Net.cs:2217
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
Definition: EventArgs.cs:416
bool Forced
Get/set whether or not the snapshot was forced or not.
Definition: EventArgs.cs:580
bool SingleStep
Get/set the Solver single step.
Definition: EventArgs.cs:571
bool IncludeWeights
Get/set whether or not to include the weights in the snapshot.
Definition: EventArgs.cs:553
bool Scheduled
Get/set whether or not the snapshot is a regular scheduled snapshot (e.g. not an improved accuracy or...
Definition: EventArgs.cs:589
bool IncludeState
Get/set whether or not to include the Solver state in the snapshot.
Definition: EventArgs.cs:562
EventHandler< GetBytesArgs > OnGetState
Specifies the OnGetState event which fires when the SnapshotArgs::UpdateState method is called.
Definition: EventArgs.cs:444
bool UpdateDatabase
Get/set whether or not to update the database (default = true).
Definition: EventArgs.cs:598
EventHandler< GetBytesArgs > OnGetWeights
Specifies the OnGetWeights event which fires when the SnapshotArgs::UpdateWeights method is called.
Definition: EventArgs.cs:437
The TestArgs are passed to the Solver::OnTest event.
Definition: EventArgs.cs:169
double Accuracy
Get/set the accuracy for the test run. When overriding the testing, the override should set the accur...
Definition: EventArgs.cs:205
The TestResultArgs are passed to the Solver::OnTestResults event.
Definition: EventArgs.cs:116
bool AccuracyValid
Get/set the accuracy valid flag. When not valid, the OnTestResults event is ignored.
Definition: EventArgs.cs:156
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Definition: EventArgs.cs:143
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
Definition: EventArgs.cs:216
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
Definition: EventArgs.cs:264
The WorkspaceArgs are passed to both the Layer::OnSetWorkspace and Layer::OnGetWorkspace events.
Definition: EventArgs.cs:17
long WorkspaceData
Get/set the handle to workspace data in GPU memory.
Definition: EventArgs.cs:36
ulong WorkspaceSizeInBytes
Get/set the workspace memory size in bytes.
Definition: EventArgs.cs:45
The Database class manages the actual connection to the physical database using Entity Framworks from...
Definition: Database.cs:23
Specifies the parameters use to create a Net
Definition: NetParameter.cs:18
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
int ProjectID
Specifies the ID of the project that created this net param (if any).
Definition: NetParameter.cs:80
int solver_rank
Specifies the rank of the solver using this network.
int solver_count
Specifies the number of solvers used in a multi-gpu training session.
NetParameter Clone(bool bCloneLayers=true, int? nSolverCount=null, int? nSolverRank=null)
Creates a new copy of this instance of the parameter.
Specifies the NetState which includes the phase, level and stage for which a given Net is to run unde...
Definition: NetState.cs:19
Phase phase
Specifies the Phase of the NetState.
Definition: NetState.cs:63
void MergeFrom(NetState ns)
Merges another NetState with this instance.
Definition: NetState.cs:98
The SolverParameter is a parameter for the solver, specifying the train and test networks.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
bool debug_info
If true, print information about the state of the net that may help with debugging learning problems.
NetParameter train_net_param
Inline train net param, possibly combined with one or more test nets.
List< NetState > test_state
The states for the train/test nets. Must be unspecified or specified once per net.
SolverType
Defines the type of solver.
string lr_policy
The learning rate decay policy.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
ApVersion ap_version
Specifies the AP Version to use for average precision when using Single-Shot Detection (SSD) - (defau...
long random_seed
If non-negative, the seed with which the Solver will initialize the caffe random number generator – u...
int average_loss
Display the loss averaged over the last average_loss iterations.
int test_interval
The number of iterations between two testing phases.
bool output_average_results
Specifies to average loss results before they are output - this can be faster when there are a lot of...
int iter_size
Accumulate gradients over 'iter_size' x 'batch_size' instances.
string DebugString()
Returns a debug string for the SolverParameter.
EvaluationType
Defines the evaluation method used in the SSD algorithm.
bool snapshot_after_train
If false, don't save a snapshot after training finishes.
bool snapshot_include_weights
Specifies whether or not the snapshot includes the trained weights. The default = true.
bool test_compute_loss
Test the compute loss.
SolverParameter()
The SolverParameter constructor.
EvaluationType eval_type
Specifies the evaluation type to use when using Single-Shot Detection (SSD) - (default = NONE,...
bool test_initialization
If true, run an initial test pass before the first iteration, ensuring memory availability and printi...
List< NetParameter > test_net_param
Inline test net params.
int display
The number of iterations between displaying info. If display = 0, no info will be displayed.
bool snapshot_diff
Whether to snapshot diff in the results or not. Snapshotting diff will help debugging but the final p...
bool snapshot_include_state
Specifies whether or not the snapshot includes the solver state. The default = false....
bool show_per_class_result
Specifies whether or not to display results per class when using Single-Shot Detection (SSD) - (defau...
int accuracy_average_window
Specifies the window over which to average the accuracies (default = 0 which ignores averaging).
int snapshot
Specifies the snapshot interval.
SolverType type
Specifies the solver type.
NetState train_state
The states for the train/test nets. Must be unspecified or specified once per net.
Use AdaDelta Solver which has gradient based optimization like SGD.
Use AdaGrad Solver based optimization like SGD that tries to find rarely seen features.
Use Adam Solver which uses gradient based optimization like SGD that includes 'adaptive momentum esti...
Definition: AdamSolver.cs:22
Use AdamW Solver which uses gradient based optimization like Adam with a decoupled weight decay.
Definition: AdamWSolver.cs:23
Use Nesterov's accelerated gradient Solver, which is similar to SGD, but the error gradient is comput...
Use RmsProp Solver which uses gradient based optimization like SGD.
Stochastic Gradient Descent solver with momentum updates weights by a linear combination of the negat...
Definition: SGDSolver.cs:22
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
List< Net< T > > m_rgTestNets
Specifies the testing Nets.
Definition: Solver.cs:48
int TrainingIterations
Returns the current training iterations remaining.
Definition: Solver.cs:708
void InitTestNets()
Initializes the Net used by the Solver for testing.
Definition: Solver.cs:579
EventHandler< CustomForwardBackArgs< T > > OnCustomForwardBack
The OnCustomForwardBack allows for overriding the forward/backward operations within the solver.
Definition: Solver.cs:155
int m_nSolverCount
Specifies the Solver count in a multi-GPU training session.
Definition: Solver.cs:82
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
Definition: Solver.cs:218
double m_dfSmoothedLoss
Specifies the smoothed loss protected for derived classes to use.
Definition: Solver.cs:70
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
Definition: Solver.cs:40
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
Definition: Solver.cs:134
List< double > m_rgLosses
Specifies the Losses used to calculate the smoothed Loss.
Definition: Solver.cs:60
abstract byte[] SnapshotSolverState()
Save the current solver state.
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1889
double smoothed_loss
Returns the smoothed loss.
Definition: Solver.cs:1213
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
Definition: Solver.cs:1097
int iter
Returns the current training iteration.
Definition: Solver.cs:1245
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
Definition: Solver.cs:32
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
Definition: Solver.cs:1115
int MaximumIteration
Returns the maximum training iterations.
Definition: Solver.cs:700
double? m_dfIterAccuracy
Specifies the iteration accuracy calculated when a blob exists with the name 'accuracy'.
Definition: Solver.cs:74
Solver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The Solver constructor.
Definition: Solver.cs:181
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
Definition: Solver.cs:130
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is set extra debugging information describing the state o...
Definition: Solver.cs:361
SolverParameter.SolverType type
Returns the type of solver.
Definition: Solver.cs:1253
Net< T > net
Returns the main training Net.
Definition: Solver.cs:1229
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
Definition: Solver.cs:236
object Tag
Returns a generic tag associated with the Solver.
Definition: Solver.cs:422
double TestDetection(int nIterationOverride=-1, int nTestNetId=0)
Run an SSD detection test on a given test Net by running it through its iterations.
Definition: Solver.cs:1396
bool? is_root_solver
Returns whether or not this is the root solver.
Definition: Solver.cs:1309
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
Definition: Solver.cs:227
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
Definition: Solver.cs:352
int m_nIter
Specifies the current iteration.
Definition: Solver.cs:52
Net< T > TrainingNet
Returns the training Net used by the solver.
Definition: Solver.cs:445
double m_dfLearningRateOverride
Optionally, specifies a learning rate override (default = 0, which ignores this setting).
Definition: Solver.cs:94
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
Definition: Solver.cs:138
void InitTrainNet(Net< T > shareNet=null)
Initializes the Net used by the solver for training.
Definition: Solver.cs:488
abstract void RestoreSolverState(byte[] rgState)
Restore a solver state.
void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss=0)
Update the averaged loss value.
Definition: Solver.cs:1826
void Init(SolverParameter p, Net< T > shareNet=null)
Initializes the Solver.
Definition: Solver.cs:454
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stop training upon detecti...
Definition: Solver.cs:382
int solver_count
Returns the solver count in a multi-GPU session.
Definition: Solver.cs:1290
SolverParameter parameter
Returns the SolverParameter used.
Definition: Solver.cs:1221
bool forceSnapshot
Returns whether or not a snapshot has been forced.
Definition: Solver.cs:1261
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot weight update method.
Definition: Solver.cs:301
EventHandler OnAborted
The OnAborted event fires after aborting a training cycle.
Definition: Solver.cs:122
List< Net< T > > test_nets
Returns the testing Nets.
Definition: Solver.cs:1237
IXPersist< T > m_persist
Specifies the persistence object used to save weight and solver states.
Definition: Solver.cs:90
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
Definition: Solver.cs:292
EventHandler< WorkspaceArgs > OnGetWorkspace
Specifies the OnGetWorkspace event that fires when the getWorkspace() function is called by a layer t...
Definition: Solver.cs:159
double TestClassification(int nIterationOverride=-1, int nTestNetId=0)
Run a test on a given test Net by running it through its iterations.
Definition: Solver.cs:1597
void Reset()
Reset the iterations of the net.
Definition: Solver.cs:478
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
Definition: Solver.cs:818
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
Definition: Solver.cs:1322
string LabelQueryEpochs
Return the label query epochs for the active datasource.
Definition: Solver.cs:684
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
Definition: Solver.cs:142
EventHandler< GradientsReadyArgs > OnGradientsReady
The OnGradientsReady event fires after the gradients of a Solver are ready for distribution to other ...
Definition: Solver.cs:126
EventHandler< WorkspaceArgs > OnSetWorkspace
Specifies the OnSetWorkspace event that fires when the setWorkspace() function is called by a layer t...
Definition: Solver.cs:163
int? TestingIterations
Returns the current testing iterations remaining.
Definition: Solver.cs:724
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1930
bool forceTest
Returns whether or not a test has been forced.
Definition: Solver.cs:1275
int TrainingIterationOverride
Get/set the training iteration override.
Definition: Solver.cs:1179
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Definition: Solver.cs:150
Net< T > m_net
Specifies the training Net.
Definition: Solver.cs:44
bool WeightsUpdated
Get/set when the weights have been updated.
Definition: Solver.cs:413
int m_nCurrentStep
Specifies the current step.
Definition: Solver.cs:56
int solver_rank
Returns this Solver's rank in a multi-GPU session.
Definition: Solver.cs:1298
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed Nan (and Infinity) detection is pero...
Definition: Solver.cs:395
Log m_log
Specifies the Log for output.
Definition: Solver.cs:36
SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
The GetSnapshotArgs method fills out a snapshot args structure.
Definition: Solver.cs:1159
virtual void dispose()
Override that allows discarding of resources (GPU and Host) used by this Solver.
Definition: Solver.cs:317
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
Definition: Solver.cs:146
int TestingIterationOverride
Get/set the testing iteration override.
Definition: Solver.cs:1188
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
Definition: Solver.cs:744
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
Definition: Solver.cs:118
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
Definition: Solver.cs:668
AutoResetEvent CompletedEvent
Returns an auto reset event that is set upon training completion.
Definition: Solver.cs:1197
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
CudaDnn< T > Cuda
Returns the CudaDnn instance used by the Solver.
Definition: Solver.cs:660
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
Definition: Solver.cs:404
int m_nSolverRank
Specifies the Solver rank of this solver, where rank == 0 is the root Solver.
Definition: Solver.cs:86
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Definition: Solver.cs:676
Net< T > TestingNet
Returns the testing Net used by the solver.
Definition: Solver.cs:431
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
Definition: Solver.cs:373
int CurrentIteration
Returns the current training iteration.
Definition: Solver.cs:692
The IXDatabaseBase interface defines the general interface to the in-memory database.
Definition: Interfaces.cs:444
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:187
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
SNAPSHOT_WEIGHT_UPDATE_METHOD
Defines the snapshot weight update method.
Definition: Interfaces.cs:181
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12