MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
SolverParameter.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.ComponentModel;
using MyCaffe.basecode;
using MyCaffe.common;
using MyCaffe.param.ui;

namespace MyCaffe.param
{
    /// <summary>
    /// The SolverParameter is a parameter for the solver, specifying the train and test networks.
    /// </summary>
    [Serializable]
    [TypeConverter(typeof(ExpandableObjectConverter))]
    public class SolverParameter : BaseParameter
    {
        NetParameter m_paramNet = null;
        NetParameter m_paramTrainNet = null;
        List<NetParameter> m_rgTestNets = new List<NetParameter>();
        NetState m_stateTrain = null;
        List<NetState> m_rgStateTest = new List<NetState>();
        List<int> m_rgTestIter = new List<int>() { 300 };
        int m_nTestInterval = 1000;
        bool m_bTestComputeLoss = false;
        bool m_bTestInitialization = true;
        double m_dfBaseLR = 0.01;
        int m_nDisplay = 0;
        int m_nAverageLoss = 1;
        int m_nMaxIter = 500000;
        int m_nIterSize = 1;
        string m_strLrPolicy = "step";
        double m_dfGamma = 0.1;
        double m_dfPower;
        double m_dfMomentum = 0.0;
        double m_dfWeightDecay = 0.0005;
        string m_strRegularizationType = "L2";
        int m_nStepSize = 100000;
        List<int> m_rgnStepValue = new List<int>();
        double m_dfClipGradients = -1;
        int m_nSnapshot = 10000;
        string m_strSnapshotPrefix = "";
        bool m_bSnapshotDiff = false;
        SnapshotFormat m_snapshotFormat = SnapshotFormat.BINARYPROTO;
        int m_nDeviceID = 1;
        long m_lRandomSeed = -1;
        SolverType m_solverType = SolverType.SGD;
        int m_lbfgs_corrections = 100;
        double m_dfDelta = 1e-8;
        double m_dfMomentum2 = 0.999;
        double m_dfRmsDecay = 0.95;
        double m_dfAdamWDecay = 0.1;
        bool m_bDebugInfo = false;
        bool m_bSnapshotAfterTrain = false;
        string m_strCustomTrainer = null;
        string m_strCustomTrainerProperties = null;
        bool m_bOutputAverageResults = false;
        bool m_bSnapshotIncludeWeights = true;
        bool m_bSnapshotIncludeState = true;
        int m_nAverageAccuracyWindow = 0;
        bool m_bEnableClipGradientOutput = false;

        // SSD Parameters
        EvaluationType m_evalType = EvaluationType.CLASSIFICATION;
        ApVersion m_apVersion = ApVersion.INTEGRAL;
        bool m_bShowPerClassResult = false;

        /// <summary>
        /// Defines the evaluation method used in the SSD algorithm.
        /// </summary>
        public enum EvaluationType
        {
            CLASSIFICATION,
            DETECTION
        }

        /// <summary>
        /// Defines the format of each snapshot.
        /// </summary>
        public enum SnapshotFormat
        {
            BINARYPROTO = 1
        }

        /// <summary>
        /// Defines the type of solver.
        /// </summary>
        public enum SolverType
        {
            SGD = 0,
            NESTEROV = 1,
            ADAGRAD = 2,
            RMSPROP = 3,
            ADADELTA = 4,
            ADAM = 5,
            LBFGS = 6,
            ADAMW = 7,
#pragma warning disable 1591
            _MAX = 8
#pragma warning restore 1591
        }

        /// <summary>
        /// Defines the learning rate policy to use.
        /// </summary>
        public enum LearningRatePolicyType
        {
            FIXED,
            STEP,
            EXP,
            INV,
            MULTISTEP,
            POLY,
            SIGMOID
        }

        /// <summary>
        /// Defines the regularization type. When enabled, weight_decay is used.
        /// </summary>
        public enum RegularizationType
        {
            NONE,
            L1,
            L2
        }

        /// <summary>
        /// The SolverParameter constructor.
        /// </summary>
        public SolverParameter()
            : base()
        {
        }

        /// <summary>
        /// Creates a new copy of the SolverParameter.
        /// </summary>
        public SolverParameter Clone()
        {
            SolverParameter p = new SolverParameter();
            // NOTE: the member-by-member copy performed here was elided from the original listing.

            return p;
        }

        [Description("Specifies to average loss results before they are output - this can be faster when there are a lot of results in a cycle.")]
        public bool output_average_results
        {
            get { return m_bOutputAverageResults; }
            set { m_bOutputAverageResults = value; }
        }

        [Description("Specifies the custom trainer (if any) used by an external process to provide customized training.")]
        public string custom_trainer
        {
            get { return m_strCustomTrainer; }
            set { m_strCustomTrainer = value; }
        }

        [Description("Specifies the custom trainer properties (if any) used by an external process to provide the properties for a customized training.")]
        [Browsable(true)]
        [EditorAttribute(typeof(DictionaryParamEditor), typeof(System.Drawing.Design.UITypeEditor))]
        public string custom_trainer_properties
        {
            get { return m_strCustomTrainerProperties; }
            set { m_strCustomTrainerProperties = Utility.Replace(value, ' ', "[sp]"); }
        }
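
        // Usage sketch (the key names below are hypothetical): the setter replaces each
        // space with the '[sp]' token so the value survives prototxt serialization, e.g.
        //   solver.custom_trainer_properties = "MiniBatch=1; UseAcceleration=True";
        //   // stored internally as "MiniBatch=1;[sp]UseAcceleration=True"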

        [Browsable(false)]
        public NetParameter net_param
        {
            get { return m_paramNet; }
            set { m_paramNet = value; }
        }

        [Browsable(false)]
        public NetParameter train_net_param
        {
            get { return m_paramTrainNet; }
            set { m_paramTrainNet = value; }
        }

        [Browsable(false)]
        public List<NetParameter> test_net_param
        {
            get { return m_rgTestNets; }
            set { m_rgTestNets = value; }
        }

        [Browsable(false)]
        public NetState train_state
        {
            get { return m_stateTrain; }
            set { m_stateTrain = value; }
        }

        [Browsable(false)]
        public List<NetState> test_state
        {
            get { return m_rgStateTest; }
            set { m_rgStateTest = value; }
        }

        [Category("Iterations")]
        [Description("Specifies the number of iterations for each test.")]
        public List<int> test_iter
        {
            get { return m_rgTestIter; }
            set { m_rgTestIter = value; }
        }

        [Category("Iterations")]
        [Description("Specifies the number of iterations between two testing phases.")]
        public int test_interval
        {
            get { return m_nTestInterval; }
            set { m_nTestInterval = value; }
        }

        [Description("Specifies whether or not to compute the loss during testing.")]
        public bool test_compute_loss
        {
            get { return m_bTestComputeLoss; }
            set { m_bTestComputeLoss = value; }
        }

        [Category("Iterations")]
        [Description("If true, run an initial test pass before the first iteration, ensuring memory availability and printing the starting value of the loss.")]
        public bool test_initialization
        {
            get { return m_bTestInitialization; }
            set { m_bTestInitialization = value; }
        }

        [Description("Specifies the base learning rate (default = 0.01).")]
        public double base_lr
        {
            get { return m_dfBaseLR; }
            set { m_dfBaseLR = value; }
        }

        [Category("Iterations")]
        [Description("Specifies the number of iterations between displaying information. If display == 0, no information will be displayed.")]
        public int display
        {
            get { return m_nDisplay; }
            set { m_nDisplay = value; }
        }

        [Description("Specifies the loss averaged over the last 'average_loss' iterations.")]
        public int average_loss
        {
            get { return m_nAverageLoss; }
            set { m_nAverageLoss = value; }
        }

        [Category("Iterations")]
        [Description("Specifies the maximum number of iterations.")]
        public int max_iter
        {
            get { return m_nMaxIter; }
            set { m_nMaxIter = value; }
        }

        [Category("Iterations")]
        [Description("Specifies to accumulate gradients over 'iter_size' x 'batch_size' instances.")]
        public int iter_size
        {
            get { return m_nIterSize; }
            set { m_nIterSize = value; }
        }
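
        // Worked example: with iter_size = 4 and a network batch size of 32, the solver
        // accumulates gradients over 4 forward/backward passes before each weight update,
        // giving an effective batch size of 4 x 32 = 128.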

        [Category("Learning Policy")]
        [DisplayName("lr_policy")]
        [Description("Specifies the learning rate decay policy. \n 'fixed' - always return base_lr. \n 'step' - return base_lr * gamma ^ (floor(iter/step)). \n 'exp' - return base_lr * gamma ^ iter. \n 'inv' - return base_lr * (1 + gamma * iter) ^ (-power)." +
            "\n 'multistep' - similar to 'step' but allows non-uniform steps defined by stepvalue. \n 'poly' - the effective learning rate follows a polynomial decay, to be zero by the max_iter, returns base_lr * (1 - iter/max_iter) ^ (power)." +
            "\n 'sigmoid' - the effective learning rate follows a sigmoid decay, returns base_lr * (1/(1 + exp(-gamma * (iter - stepsize)))).")]
        public LearningRatePolicyType LearningRatePolicy
        {
            get
            {
                switch (m_strLrPolicy)
                {
                    case "fixed":
                        return LearningRatePolicyType.FIXED;

                    case "step":
                        return LearningRatePolicyType.STEP;

                    case "exp":
                        return LearningRatePolicyType.EXP;

                    case "inv":
                        return LearningRatePolicyType.INV;

                    case "multistep":
                        return LearningRatePolicyType.MULTISTEP;

                    case "sigmoid":
                        return LearningRatePolicyType.SIGMOID;

                    case "poly":
                        return LearningRatePolicyType.POLY;

                    default:
                        throw new Exception("Unknown learning rate policy '" + m_strLrPolicy + "'.");
                }
            }
            set
            {
                switch (value)
                {
                    case LearningRatePolicyType.FIXED:
                        m_strLrPolicy = "fixed";
                        break;

                    case LearningRatePolicyType.STEP:
                        m_strLrPolicy = "step";
                        break;

                    case LearningRatePolicyType.EXP:
                        m_strLrPolicy = "exp";
                        break;

                    case LearningRatePolicyType.INV:
                        m_strLrPolicy = "inv";
                        break;

                    case LearningRatePolicyType.MULTISTEP:
                        m_strLrPolicy = "multistep";
                        break;

                    case LearningRatePolicyType.SIGMOID:
                        m_strLrPolicy = "sigmoid";
                        break;

                    case LearningRatePolicyType.POLY:
                        m_strLrPolicy = "poly";
                        break;

                    default:
                        throw new Exception("Unknown learning rate policy '" + value.ToString() + "'.");
                }
            }
        }
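
        // Worked example for the 'step' policy: with base_lr = 0.01, gamma = 0.1 and
        // stepsize = 100000, iteration 250000 uses
        //   0.01 * 0.1 ^ floor(250000 / 100000) = 0.01 * 0.01 = 0.0001.
        // A sketch of a 'multistep' schedule, which instead drops the rate at explicit
        // iterations (the step values below are illustrative only):
        //   solver.LearningRatePolicy = LearningRatePolicyType.MULTISTEP;
        //   solver.stepvalue = new List<int>() { 30000, 60000 };
        //   solver.gamma = 0.1; // multiply the rate by 0.1 at each step value.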

        [Browsable(false)]
        public string lr_policy
        {
            get { return m_strLrPolicy; }
            set { m_strLrPolicy = value; }
        }

        [Category("Learning Policy")]
        [Description("Specifies the 'gamma' parameter to compute the 'step', 'exp', 'inv', and 'sigmoid' learning policy (default = 0.1).")]
        public double gamma
        {
            get { return m_dfGamma; }
            set { m_dfGamma = value; }
        }

        [Category("Learning Policy")]
        [Description("Specifies the 'power' parameter to compute the 'inv' and 'poly' learning policy.")]
        public double power
        {
            get { return m_dfPower; }
            set { m_dfPower = value; }
        }

        [Category("Solver - Not AdaGrad or RMSProp")]
        [Description("Specifies the momentum value - used by all solvers EXCEPT the 'AdaGrad' and 'RMSProp' solvers. For these latter solvers, momentum should = 0.")]
        public double momentum
        {
            get { return m_dfMomentum; }
            set { m_dfMomentum = value; }
        }

        [Description("Specifies the weight decay (default = 0.0005).")]
        public double weight_decay
        {
            get { return m_dfWeightDecay; }
            set { m_dfWeightDecay = value; }
        }

        [DisplayName("regularization_type")]
        [Description("Specifies the regularization type (default = L2). The regularization types supported are 'L1' and 'L2', controlled by weight decay.")]
        public RegularizationType Regularization
        {
            get
            {
                switch (m_strRegularizationType)
                {
                    case "NONE":
                        return RegularizationType.NONE;

                    case "L1":
                        return RegularizationType.L1;

                    case "L2":
                        return RegularizationType.L2;

                    default:
                        throw new Exception("Unknown regularization type '" + m_strRegularizationType + "'");
                }
            }
            set
            {
                switch (value)
                {
                    case RegularizationType.NONE:
                        m_strRegularizationType = "NONE";
                        break;

                    case RegularizationType.L1:
                        m_strRegularizationType = "L1";
                        break;

                    case RegularizationType.L2:
                        m_strRegularizationType = "L2";
                        break;

                    default:
                        throw new Exception("Unknown regularization type '" + value.ToString() + "'");
                }
            }
        }

        [Description("Specifies the regularization type (default = 'L2'). The regularization types supported are 'L1' and 'L2', controlled by weight decay.")]
        [Browsable(false)]
        public string regularization_type
        {
            get { return m_strRegularizationType; }
            set { m_strRegularizationType = value; }
        }

        [Category("Learning Policy")]
        [Description("Specifies the stepsize for the learning rate policy 'step'.")]
        public int stepsize
        {
            get { return m_nStepSize; }
            set { m_nStepSize = value; }
        }

        [Category("Learning Policy")]
        [Description("Specifies the step values for the learning rate policy 'multistep'.")]
        public List<int> stepvalue
        {
            get { return m_rgnStepValue; }
            set { m_rgnStepValue = value; }
        }

        [Description("Set 'clip_gradients' to >= 0 to clip parameter gradients to that L2 norm, whenever their actual L2 norm is larger.")]
        public double clip_gradients
        {
            get { return m_dfClipGradients; }
            set { m_dfClipGradients = value; }
        }
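
        // Clipping rescales the whole gradient vector g when its L2 norm exceeds the
        // threshold c = clip_gradients:
        //   if ||g||_2 > c :  g <- g * (c / ||g||_2)
        // which preserves the update direction while capping its magnitude at c.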

        [Description("Optionally, enables/disables status output when gradients are clipped (default = false).")]
        public bool enable_clip_gradient_status
        {
            get { return m_bEnableClipGradientOutput; }
            set { m_bEnableClipGradientOutput = value; }
        }

        [Category("Snapshot")]
        [Description("Specifies the snapshot interval.")]
        public int snapshot
        {
            get { return m_nSnapshot; }
            set { m_nSnapshot = value; }
        }

        [Description("Specifies the snapshot prefix.")]
        [Browsable(false)]
        public string snapshot_prefix
        {
            get { return m_strSnapshotPrefix; }
            set { m_strSnapshotPrefix = value; }
        }

        [Category("Snapshot")]
        [Description("Specifies whether or not to snapshot the diff in the results. Snapshotting the diff may help debugging but the final snapshot data size will be much larger.")]
        public bool snapshot_diff
        {
            get { return m_bSnapshotDiff; }
            set { m_bSnapshotDiff = value; }
        }

        [Description("Specifies the snapshot format.")]
        [Browsable(false)]
        public SnapshotFormat snapshot_format
        {
            get { return m_snapshotFormat; }
            set { m_snapshotFormat = value; }
        }

        [Category("Snapshot")]
        [Description("Specifies whether or not the snapshot includes the trained weights. The default = 'true'.")]
        public bool snapshot_include_weights
        {
            get { return m_bSnapshotIncludeWeights; }
            set { m_bSnapshotIncludeWeights = value; }
        }

        [Category("Snapshot")]
        [Description("Specifies whether or not the snapshot includes the solver state. The default = 'true'. Including the solver state will slow down the time of each snapshot.")]
        public bool snapshot_include_state
        {
            get { return m_bSnapshotIncludeState; }
            set { m_bSnapshotIncludeState = value; }
        }

        [Description("Specifies the device ID that will be used when run on the GPU.")]
        [Browsable(false)]
        public int device_id
        {
            get { return m_nDeviceID; }
            set { m_nDeviceID = value; }
        }

        [Description("If non-negative, the seed with which the Solver will initialize the caffe random number generator -- useful for reproducible results. Otherwise (and by default), initialize using a seed derived from the system clock.")]
        public long random_seed
        {
            get { return m_lRandomSeed; }
            set { m_lRandomSeed = value; }
        }

        [Category("Solver")]
        [Description("Specifies the solver type. \n" +
            " SGD - stochastic gradient descent with momentum updates weights by a linear combination of the negative gradient and the previous weight update. \n" +
            " NESTEROV - Nesterov's accelerated gradient, similar to SGD, but the error gradient is computed on the weights with added momentum. \n" +
            " ADADELTA - Gradient based optimization like SGD, see M. Zeiler 'ADADELTA: An adaptive learning rate method', arXiv preprint, 2012. \n" +
            " ADAGRAD - Gradient based optimization like SGD that tries to find rarely seen features, see J. Duchi, E. Hazan, and Y. Singer, 'Adaptive subgradient methods for online learning and stochastic optimization', The Journal of Machine Learning Research, 2011. \n" +
            " ADAM - Gradient based optimization like SGD that includes 'adaptive momentum estimation' and can be thought of as a generalization of AdaGrad, see D. Kingma, J. Ba, 'Adam: A method for stochastic optimization', Int'l Conference for Learning Representations, 2015. \n" +
            " RMSPROP - Gradient based optimization like SGD, see T. Tieleman, and G. Hinton, 'RMSProp: Divide the gradient by a running average of its recent magnitude', COURSERA: Neural Networks for Machine Learning. Technical Report, 2012. \n" +
            " LBFGS - Gradient based optimization based on minFunc, see Mark Schmidt 'minFunc'.")]
        public SolverType type
        {
            get { return m_solverType; }
            set { m_solverType = value; }
        }
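
        // Configuration sketch for the Adam solver (illustrative values): as in the
        // Caffe Adam solver, 'momentum' plays the role of beta1, 'momentum2' of beta2,
        // and 'delta' of epsilon in the Adam update.
        //   solver.type = SolverType.ADAM;
        //   solver.momentum = 0.9;     // beta1
        //   solver.momentum2 = 0.999;  // beta2
        //   solver.delta = 1e-8;       // epsilon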

        [Category("Solver - Ada, Adam and RMSProp")]
        [Description("Specifies the numerical stability for 'RMSProp', 'AdaGrad', 'AdaDelta' and 'Adam' solvers (default = 1e-08).")]
        public double delta
        {
            get { return m_dfDelta; }
            set { m_dfDelta = value; }
        }

        [Category("Solver - Adam")]
        [Description("Specifies an additional momentum property used by the 'Adam' and 'AdamW' solvers (default = 0.999).")]
        public double momentum2
        {
            get { return m_dfMomentum2; }
            set { m_dfMomentum2 = value; }
        }

        [Category("Solver - RMSProp")]
        [Description("Specifies the 'RMSProp' decay value used by the 'RMSProp' solver (default = 0.95). MeanSquare(t) = 'rms_decay' * MeanSquare(t-1) + (1 - 'rms_decay') * SquareGradient(t). The 'rms_decay' is only used by the 'RMSProp' solver.")]
        public double rms_decay
        {
            get { return m_dfRmsDecay; }
            set { m_dfRmsDecay = value; }
        }

        [Category("Solver - AdamW")]
        [Description("Specifies the 'AdamW' detached weight decay value used by the 'AdamW' solver (default = 0.1).")]
        public double adamw_decay
        {
            get { return m_dfAdamWDecay; }
            set { m_dfAdamWDecay = value; }
        }

        [Category("Debug")]
        [Description("If true, print information about the state of the net that may help with debugging learning problems.")]
        public bool debug_info
        {
            get { return m_bDebugInfo; }
            set { m_bDebugInfo = value; }
        }

        [Category("Solver - L-BFGS")]
        [Description("Specifies the number of 'L-BFGS' corrections used by the L-BFGS solver.")]
        public int lbgfs_corrections
        {
            get { return m_lbfgs_corrections; }
            set { m_lbfgs_corrections = value; }
        }

        [Category("Snapshot")]
        [Description("If false, don't save a snapshot after training finishes.")]
        public bool snapshot_after_train
        {
            get { return m_bSnapshotAfterTrain; }
            set { m_bSnapshotAfterTrain = value; }
        }

        [Category("SSD")]
        [Description("Specifies the evaluation type to use when using Single-Shot Detection (SSD) - (default = CLASSIFICATION).")]
        public EvaluationType eval_type
        {
            get { return m_evalType; }
            set { m_evalType = value; }
        }

        [Category("SSD")]
        [Description("Specifies the AP Version to use for average precision when using Single-Shot Detection (SSD) - (default = INTEGRAL).")]
        public ApVersion ap_version
        {
            get { return m_apVersion; }
            set { m_apVersion = value; }
        }

        [Category("SSD")]
        [Description("Specifies whether or not to display results per class when using Single-Shot Detection (SSD) - (default = false).")]
        public bool show_per_class_result
        {
            get { return m_bShowPerClassResult; }
            set { m_bShowPerClassResult = value; }
        }

        [Description("Specifies the window over which to average the accuracies (default = 0, which ignores the averaging).")]
        public int accuracy_average_window
        {
            get { return m_nAverageAccuracyWindow; }
            set { m_nAverageAccuracyWindow = value; }
        }

        public override RawProto ToProto(string strName)
        {
            RawProtoCollection rgChildren = new RawProtoCollection();

            if (net_param != null)
                rgChildren.Add(net_param.ToProto("net_param"));

            if (train_net_param != null)
                rgChildren.Add(train_net_param.ToProto("train_net_param"));

            foreach (NetParameter np in test_net_param)
            {
                rgChildren.Add(np.ToProto("test_net_param"));
            }

            if (train_state != null)
                rgChildren.Add(train_state.ToProto("train_state"));

            foreach (NetState ns in test_state)
            {
                rgChildren.Add(ns.ToProto("test_state"));
            }

            rgChildren.Add<int>("test_iter", test_iter);
            rgChildren.Add("test_interval", test_interval.ToString());
            rgChildren.Add("test_compute_loss", test_compute_loss.ToString());
            rgChildren.Add("test_initialization", test_initialization.ToString());
            rgChildren.Add("base_lr", base_lr.ToString());
            rgChildren.Add("display", display.ToString());
            rgChildren.Add("average_loss", average_loss.ToString());
            rgChildren.Add("max_iter", max_iter.ToString());

            if (iter_size != 1)
                rgChildren.Add("iter_size", iter_size.ToString());

            rgChildren.Add("lr_policy", lr_policy);

            if (lr_policy == "step" || lr_policy == "exp" || lr_policy == "inv" || lr_policy == "sigmoid")
                rgChildren.Add("gamma", gamma.ToString());

            if (lr_policy == "inv" || lr_policy == "poly")
                rgChildren.Add("power", power.ToString());

            rgChildren.Add("momentum", momentum.ToString());
            rgChildren.Add("weight_decay", weight_decay.ToString());
            rgChildren.Add("regularization_type", regularization_type);

            if (lr_policy == "step")
                rgChildren.Add("stepsize", stepsize.ToString());

            if (lr_policy == "multistep")
                rgChildren.Add<int>("stepvalue", stepvalue);

            if (clip_gradients >= 0)
            {
                rgChildren.Add("clip_gradients", clip_gradients.ToString());
                rgChildren.Add("enable_clip_gradient_status", enable_clip_gradient_status.ToString());
            }

            rgChildren.Add("snapshot", snapshot.ToString());

            if (snapshot_prefix.Length > 0)
                rgChildren.Add("snapshot_prefix", snapshot_prefix);

            if (snapshot_diff != false)
                rgChildren.Add("snapshot_diff", snapshot_diff.ToString());

            rgChildren.Add("snapshot_format", snapshot_format.ToString());
            rgChildren.Add("device_id", device_id.ToString());

            if (random_seed >= 0)
                rgChildren.Add("random_seed", random_seed.ToString());

            rgChildren.Add("type", type.ToString());

            if (type == SolverType.RMSPROP || type == SolverType.ADAGRAD || type == SolverType.ADADELTA || type == SolverType.ADAM || type == SolverType.ADAMW)
                rgChildren.Add("delta", delta.ToString());

            if (type == SolverType.ADAM || type == SolverType.ADAMW)
                rgChildren.Add("momentum2", momentum2.ToString());

            if (type == SolverType.RMSPROP)
                rgChildren.Add("rms_decay", rms_decay.ToString());

            if (type == SolverType.ADAMW)
                rgChildren.Add("adamw_decay", adamw_decay.ToString());

            if (type == SolverType.LBFGS)
                rgChildren.Add("lbgfs_corrections", lbgfs_corrections.ToString());

            if (debug_info != false)
                rgChildren.Add("debug_info", debug_info.ToString());

            if (snapshot_after_train != false)
                rgChildren.Add("snapshot_after_train", snapshot_after_train.ToString());

            if (!string.IsNullOrEmpty(custom_trainer))
                rgChildren.Add("custom_trainer", custom_trainer);

            if (!string.IsNullOrEmpty(custom_trainer_properties))
                rgChildren.Add("custom_trainer_properties", custom_trainer_properties);

            if (output_average_results != false)
                rgChildren.Add("output_average_results", output_average_results.ToString());

            rgChildren.Add("snapshot_include_weights", snapshot_include_weights.ToString());
            rgChildren.Add("snapshot_include_state", snapshot_include_state.ToString());

            // SSD Parameters
            rgChildren.Add("eval_type", eval_type.ToString().ToLower());

            if (ap_version == ApVersion.ELEVENPOINT)
                rgChildren.Add("ap_version", "11point");
            else
                rgChildren.Add("ap_version", ap_version.ToString().ToLower());

            rgChildren.Add("show_per_class_result", show_per_class_result.ToString());
            rgChildren.Add("accuracy_average_window", accuracy_average_window.ToString());

            return new RawProto(strName, "", rgChildren);
        }
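
        // Usage sketch: serializing a configured SolverParameter to prototxt text.
        // (This assumes RawProto.ToString() renders the prototxt representation.)
        //   SolverParameter solver = new SolverParameter();
        //   solver.base_lr = 0.001;
        //   RawProto proto = solver.ToProto("solver");
        //   string strPrototxt = proto.ToString();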

        /// <summary>
        /// Parses a new SolverParameter from a RawProto.
        /// </summary>
        public static SolverParameter FromProto(RawProto rp)
        {
            string strVal;
            SolverParameter p = new SolverParameter();

            RawProto rpNetParam = rp.FindChild("net_param");
            if (rpNetParam != null)
                p.net_param = NetParameter.FromProto(rpNetParam);

            RawProto rpTrainNetParam = rp.FindChild("train_net_param");
            if (rpTrainNetParam != null)
                p.train_net_param = NetParameter.FromProto(rpTrainNetParam);

            RawProtoCollection rgpTn = rp.FindChildren("test_net_param");
            foreach (RawProto rpTest in rgpTn)
            {
                p.test_net_param.Add(NetParameter.FromProto(rpTest));
            }

            RawProto rpTrainState = rp.FindChild("train_state");
            if (rpTrainState != null)
                p.train_state = NetState.FromProto(rpTrainState);

            RawProtoCollection rgpNs = rp.FindChildren("test_state");
            foreach (RawProto rpNs in rgpNs)
            {
                p.test_state.Add(NetState.FromProto(rpNs));
            }

            p.test_iter = rp.FindArray<int>("test_iter");

            if ((strVal = rp.FindValue("test_interval")) != null)
                p.test_interval = int.Parse(strVal);

            if ((strVal = rp.FindValue("test_compute_loss")) != null)
                p.test_compute_loss = bool.Parse(strVal);

            if ((strVal = rp.FindValue("test_initialization")) != null)
                p.test_initialization = bool.Parse(strVal);

            if ((strVal = rp.FindValue("base_lr")) != null)
                p.base_lr = ParseDouble(strVal);

            if ((strVal = rp.FindValue("display")) != null)
                p.display = int.Parse(strVal);

            if ((strVal = rp.FindValue("average_loss")) != null)
                p.average_loss = int.Parse(strVal);

            if ((strVal = rp.FindValue("max_iter")) != null)
                p.max_iter = int.Parse(strVal);

            if ((strVal = rp.FindValue("iter_size")) != null)
                p.iter_size = int.Parse(strVal);

            if ((strVal = rp.FindValue("lr_policy")) != null)
                p.lr_policy = strVal;

            if ((strVal = rp.FindValue("gamma")) != null)
                p.gamma = ParseDouble(strVal);

            if ((strVal = rp.FindValue("power")) != null)
                p.power = ParseDouble(strVal);

            if ((strVal = rp.FindValue("momentum")) != null)
                p.momentum = ParseDouble(strVal);

            if ((strVal = rp.FindValue("weight_decay")) != null)
                p.weight_decay = ParseDouble(strVal);

            if ((strVal = rp.FindValue("regularization_type")) != null)
                p.regularization_type = strVal;

            if ((strVal = rp.FindValue("stepsize")) != null)
                p.stepsize = int.Parse(strVal);

            p.stepvalue = rp.FindArray<int>("stepvalue");

            if ((strVal = rp.FindValue("clip_gradients")) != null)
                p.clip_gradients = ParseDouble(strVal);

            if ((strVal = rp.FindValue("enable_clip_gradient_status")) != null)
                p.enable_clip_gradient_status = bool.Parse(strVal);

            if ((strVal = rp.FindValue("snapshot")) != null)
                p.snapshot = int.Parse(strVal);

            if ((strVal = rp.FindValue("snapshot_prefix")) != null)
                p.snapshot_prefix = strVal;

            if ((strVal = rp.FindValue("snapshot_diff")) != null)
                p.snapshot_diff = bool.Parse(strVal);

            if ((strVal = rp.FindValue("snapshot_format")) != null)
            {
                switch (strVal)
                {
                    case "BINARYPROTO":
                    case "HDF5": // HDF5 is not supported; mapped to BINARYPROTO.
                        p.snapshot_format = SnapshotFormat.BINARYPROTO;
                        break;

                    default:
                        throw new Exception("Unknown 'snapshot_format' value: " + strVal);
                }
            }

            if ((strVal = rp.FindValue("device_id")) != null)
                p.device_id = int.Parse(strVal);

            if ((strVal = rp.FindValue("random_seed")) != null)
                p.random_seed = long.Parse(strVal);

            if ((strVal = rp.FindValue("type")) != null)
            {
                string strVal1 = strVal.ToLower();

                switch (strVal1)
                {
                    case "sgd":
                        p.type = SolverType.SGD;
                        break;

                    case "nesterov":
                        p.type = SolverType.NESTEROV;
                        break;

                    case "adagrad":
                        p.type = SolverType.ADAGRAD;
                        break;

                    case "adadelta":
                        p.type = SolverType.ADADELTA;
                        break;

                    case "adam":
                        p.type = SolverType.ADAM;
                        break;

                    case "adamw":
                        p.type = SolverType.ADAMW;
                        break;

                    case "rmsprop":
                        p.type = SolverType.RMSPROP;
                        break;

                    case "lbfgs":
                    case "lbgfs": // also accept the historical misspelling.
                        p.type = SolverType.LBFGS;
                        break;

                    default:
                        throw new Exception("Unknown solver 'type' value: " + strVal);
                }
            }

            if ((strVal = rp.FindValue("delta")) != null)
                p.delta = ParseDouble(strVal);

            if ((strVal = rp.FindValue("momentum2")) != null)
                p.momentum2 = ParseDouble(strVal);

            if ((strVal = rp.FindValue("rms_decay")) != null)
                p.rms_decay = ParseDouble(strVal);

            if ((strVal = rp.FindValue("adamw_decay")) != null)
                p.adamw_decay = ParseDouble(strVal);

            if ((strVal = rp.FindValue("debug_info")) != null)
                p.debug_info = bool.Parse(strVal);

            if ((strVal = rp.FindValue("lbgfs_corrections")) != null)
                p.lbgfs_corrections = int.Parse(strVal);

            if ((strVal = rp.FindValue("snapshot_after_train")) != null)
                p.snapshot_after_train = bool.Parse(strVal);

            if ((strVal = rp.FindValue("custom_trainer")) != null)
                p.custom_trainer = strVal;

            if ((strVal = rp.FindValue("custom_trainer_properties")) != null)
                p.custom_trainer_properties = strVal;

            if ((strVal = rp.FindValue("output_average_results")) != null)
                p.output_average_results = bool.Parse(strVal);

            if ((strVal = rp.FindValue("snapshot_include_weights")) != null)
                p.snapshot_include_weights = bool.Parse(strVal);

            if ((strVal = rp.FindValue("snapshot_include_state")) != null)
                p.snapshot_include_state = bool.Parse(strVal);

            if ((strVal = rp.FindValue("eval_type")) != null)
            {
                strVal = strVal.ToLower();

                switch (strVal)
                {
                    case "classification":
                        p.eval_type = EvaluationType.CLASSIFICATION;
                        break;

                    case "detection":
                        p.eval_type = EvaluationType.DETECTION;
                        break;

                    default:
                        throw new Exception("Unknown eval_type '" + strVal + "'!");
                }
            }

            if ((strVal = rp.FindValue("ap_version")) != null)
            {
                strVal = strVal.ToLower();

                switch (strVal)
                {
                    case "11point":
                        p.ap_version = ApVersion.ELEVENPOINT;
                        break;

                    case "maxintegral":
                        p.ap_version = ApVersion.MAXINTEGRAL;
                        break;

                    case "integral":
                        p.ap_version = ApVersion.INTEGRAL;
                        break;

                    default:
                        throw new Exception("Unknown ap_version '" + strVal + "'!");
                }
            }

            if ((strVal = rp.FindValue("show_per_class_result")) != null)
                p.show_per_class_result = bool.Parse(strVal);

            if ((strVal = rp.FindValue("accuracy_average_window")) != null)
                p.accuracy_average_window = int.Parse(strVal);

            return p;
        }
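
        // Round-trip sketch: parsing prototxt text back into a SolverParameter.
        // (This assumes the static RawProto.Parse(string) helper used elsewhere in MyCaffe.)
        //   RawProto proto = RawProto.Parse(strPrototxt);
        //   SolverParameter solver = SolverParameter.FromProto(proto);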

        public string DebugString()
        {
            return m_solverType.ToString();
        }
    }
}