MyCaffe  1.11.8.27
Deep learning software for Windows C# programmers.
ProjectEx.cs
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.IO;
7
8namespace MyCaffe.basecode
9{
14 public class ProjectEx
15 {
// Wrapped project descriptor; most properties of this class read/write through it.
16 ProjectDescriptor m_project;
// Results state (iterations, accuracy, error, solver state, weights).
17 StateDescriptor m_state;
// Parsed model description; null until ModelDescription is set to a non-empty value.
18 RawProto m_protoModel = null;
// Parsed solver description; null until SolverDescription is set to a non-empty value.
19 RawProto m_protoSolver = null;
// Flags supplied to the descriptor-based constructor.
20 bool m_bExistTest = false;
21 bool m_bExistTrain = false;
// Backing field of DatasetAdjusted.
22 bool m_bDatasetAdjusted = false;
// SaveImagesToFile value used by setDatasetFromProto when no previous source descriptor exists.
23 bool m_bDefaultSaveImagesToFile = true;
// Backing field of the stage property.
24 Stage m_stage = Stage.NONE;
// Original project ID; OriginalID falls back to ID while this is not positive.
25 int m_nOriginalProjectID = 0;
26
// Override events for the model and solver; no raise site is visible in this portion of the file.
30 public event EventHandler<OverrideProjectArgs> OnOverrideModel;
34 public event EventHandler<OverrideProjectArgs> OnOverrideSolver;
35
/// <summary>
/// Creates a new, empty project with the given name.
/// </summary>
/// <param name="strName">Project name; the results state is named "&lt;name&gt; results".</param>
/// <param name="strDsName">Optional dataset name attached to the project.</param>
public ProjectEx(string strName, string strDsName = null)
{
    // The project owns a dataset descriptor created from the (possibly null) name.
    m_project = new ProjectDescriptor(strName);
    m_project.Dataset = new descriptors.DatasetDescriptor(strDsName);

    // Results state shares the project's owner.
    string strStateName = strName + " results";
    m_state = new StateDescriptor(0, strStateName, m_project.Owner);
}
47
/// <summary>
/// Wraps an existing project descriptor and, optionally, its results state.
/// </summary>
/// <param name="prj">Project descriptor to wrap.</param>
/// <param name="state">Results state; when null a new state named "&lt;project name&gt; results" is created.</param>
/// <param name="bExistTrain">Stored in m_bExistTrain.</param>
/// <param name="bExistTest">Stored in m_bExistTest.</param>
/// <param name="bQueryModel">When false, ModelName is parsed from the text of prj.ModelDescription.</param>
/// <param name="bQuerySolver">When false, SolverName is parsed from the text of prj.SolverDescription.</param>
57 public ProjectEx(ProjectDescriptor prj, StateDescriptor state = null, bool bExistTrain = false, bool bExistTest = false, bool bQueryModel = true, bool bQuerySolver = true)
58 {
59 m_project = prj;
60
// Fall back to a fresh results state owned by the project owner.
61 if (state == null)
62 state = new StateDescriptor(0, prj.Name + " results", m_project.Owner);
63
64 m_state = state;
65
// NOTE(review): the statement run when bQueryModel is true is not visible in this listing.
66 if (bQueryModel)
68 else
69 m_project.ModelName = getModelName(prj.ModelDescription);
70
// NOTE(review): the statement run when bQuerySolver is true is not visible in this listing.
71 if (bQuerySolver)
73 else
74 m_project.SolverName = getSolverType(prj.SolverDescription);
75
76 m_bExistTest = bExistTest;
77 m_bExistTrain = bExistTrain;
78 }
79
{
    // Returns true when any model layer has the type "AnnotatedData" (case-insensitive).
    // NOTE(review): assumes the model description has been set (m_protoModel non-null).
    RawProtoCollection col = m_protoModel.FindChildren("layer");
    foreach (RawProto layer in col)
    {
        RawProto type = layer.FindChild("type");

        // A layer without a 'type' entry cannot be an annotated data layer; skip it
        // rather than dereferencing null.
        if (type == null || type.Value == null)
            continue;

        if (string.Equals(type.Value, "annotateddata", StringComparison.OrdinalIgnoreCase))
            return true;
    }

    return false;
}
96
/// <summary>
/// Extracts the value that follows the first token-initial occurrence of a key in
/// prototxt-style text.
/// </summary>
/// <param name="str">Text to scan; may be null.</param>
/// <param name="strTarget">Key to find, including its colon (e.g. "name:").</param>
/// <param name="strDefault">Value returned when the key or its value is missing.</param>
/// <returns>
/// For a quoted value, the text between the quotes (trimmed). Otherwise, the text up to
/// the next whitespace or quote. Returns <paramref name="strDefault"/> when the key is
/// absent or its value is empty.
/// </returns>
private string parse(string str, string strTarget, string strDefault = "UNKNOWN")
{
    if (str == null)
        return strDefault;

    // Find the first occurrence that starts a token: at the very beginning of the text,
    // or preceded by whitespace (so "name:" does not match inside "layer_name:").
    int nPos1 = 0;
    while (true)
    {
        if (nPos1 >= str.Length)
            return strDefault;

        nPos1 = str.IndexOf(strTarget, nPos1, StringComparison.Ordinal);
        if (nPos1 < 0)
            return strDefault;

        if (nPos1 == 0 || char.IsWhiteSpace(str[nPos1 - 1]))
            break;

        nPos1++;
    }

    nPos1 += strTarget.Length;

    // Skip whitespace between the key and its value.
    while (nPos1 < str.Length && char.IsWhiteSpace(str[nPos1]))
        nPos1++;

    string strName;

    if (nPos1 < str.Length && str[nPos1] == '\"')
    {
        // Quoted value: read up to the closing quote (or end of line). Names with spaces
        // come back whole, and an empty "" value does not pull in the next token.
        int nStart = nPos1 + 1;
        int nEnd = str.IndexOfAny(new char[] { '\"', '\n', '\r' }, nStart);
        strName = (nEnd < 0) ? str.Substring(nStart) : str.Substring(nStart, nEnd - nStart);
        strName = strName.Trim();
    }
    else
    {
        // Unquoted value: read up to the next whitespace or quote.
        int nEnd = str.IndexOfAny(new char[] { ' ', '\n', '\r', '\"', '\t' }, nPos1);
        strName = (nEnd < 0) ? str.Substring(nPos1) : str.Substring(nPos1, nEnd - nPos1);
    }

    // Key at the end of the text, or an empty quoted value: report the default.
    if (strName.Length == 0)
        return strDefault;

    return strName;
}
136
/// <summary>
/// Returns the network name declared in a model description ('name:'), or parse's default when none is found.
/// </summary>
private string getModelName(string strDesc) => parse(strDesc, "name:");
142
/// <summary>
/// Returns the solver type declared in a solver description ('type:'), or "SGD" when none is found.
/// </summary>
private string getSolverType(string strDesc) => parse(strDesc, "type:", "SGD");
148
/// <summary>
/// Reads the data-source names from the 'Data' layers of a parsed proto and stores them
/// on the project's dataset descriptors.
/// </summary>
/// <param name="proto">Parsed model or solver proto whose 'layer' children are scanned.</param>
/// <remarks>
/// Sources of primary layers ('primary_data' absent or true) are stored on Dataset.
/// Sources of secondary layers ('primary_data: false') are stored on DatasetTarget, which
/// is created as "&lt;dataset name&gt;_tgt" when needed. The SaveImagesToFile flag of an
/// existing source is kept; otherwise m_bDefaultSaveImagesToFile is used.
/// </remarks>
private void setDatasetFromProto(RawProto proto)
{
    RawProtoCollection col = proto.FindChildren("layer");
    string strSrcTest = null;
    string strSrcTrain = null;
    string strSrcTest2 = null;
    string strSrcTrain2 = null;

    foreach (RawProto rp in col)
    {
        // NOTE(review): only "Data" layers are read here, whereas setDatasetToProto also
        // writes "AnnotatedData" layers - confirm the asymmetry is intended.
        RawProto protoType = rp.FindChild("type");
        if (protoType == null || protoType.Value != "Data")
            continue;

        RawProto protoParam = rp.FindChild("data_param");
        if (protoParam == null)
            continue;

        bool bPrimary = true;
        RawProto protoPrimary = protoParam.FindChild("primary_data");
        if (protoPrimary != null)
            bPrimary = bool.Parse(protoPrimary.Value);

        RawProto protoSrc = protoParam.FindChild("source");
        if (protoSrc == null)
            continue;

        RawProto protoInclude = rp.FindChild("include");
        if (protoInclude == null)
            continue;

        RawProto protoPhase = protoInclude.FindChild("phase");
        if (protoPhase == null)
            continue;

        if (protoPhase.Value == "TRAIN")
        {
            if (bPrimary)
                strSrcTrain = protoSrc.Value;
            else
                strSrcTrain2 = protoSrc.Value;
        }
        else if (protoPhase.Value == "TEST")
        {
            if (bPrimary)
                strSrcTest = protoSrc.Value;
            else
                strSrcTest2 = protoSrc.Value;
        }
    }

    if (strSrcTest != null)
    {
        bool bSaveImagesToFile = (m_project.Dataset.TestingSource != null) ? m_project.Dataset.TestingSource.SaveImagesToFile : m_bDefaultSaveImagesToFile;
        m_project.Dataset.TestingSource = new SourceDescriptor(strSrcTest, bSaveImagesToFile);
    }

    if (strSrcTrain != null)
    {
        bool bSaveImagesToFile = (m_project.Dataset.TrainingSource != null) ? m_project.Dataset.TrainingSource.SaveImagesToFile : m_bDefaultSaveImagesToFile;
        m_project.Dataset.TrainingSource = new SourceDescriptor(strSrcTrain, bSaveImagesToFile);
    }

    if (strSrcTest2 != null || strSrcTrain2 != null)
    {
        if (m_project.DatasetTarget == null)
            m_project.DatasetTarget = new DatasetDescriptor(m_project.Dataset.Name + "_tgt");

        // Secondary sources belong to the target dataset, so they are written there and
        // must not replace the primary Dataset sources set above.
        if (strSrcTest2 != null)
        {
            bool bSaveImagesToFile = (m_project.DatasetTarget.TestingSource != null) ? m_project.DatasetTarget.TestingSource.SaveImagesToFile : m_bDefaultSaveImagesToFile;
            m_project.DatasetTarget.TestingSource = new SourceDescriptor(strSrcTest2, bSaveImagesToFile);
        }

        if (strSrcTrain2 != null)
        {
            bool bSaveImagesToFile = (m_project.DatasetTarget.TrainingSource != null) ? m_project.DatasetTarget.TrainingSource.SaveImagesToFile : m_bDefaultSaveImagesToFile;
            m_project.DatasetTarget.TrainingSource = new SourceDescriptor(strSrcTrain2, bSaveImagesToFile);
        }
    }
}
231
/// <summary>
/// Writes the project's data-source names into the 'source' entries of the
/// 'Data' / 'AnnotatedData' layers of a parsed proto.
/// </summary>
/// <param name="proto">Parsed proto whose 'layer' children are updated in place.</param>
/// <remarks>
/// Primary layers receive the Dataset sources. Layers with 'primary_data: false' receive
/// the DatasetTarget sources. A layer's source is left alone when the project has no
/// name for it.
/// </remarks>
private void setDatasetToProto(RawProto proto)
{
    RawProtoCollection rgLayers = proto.FindChildren("layer");

    string strSrcTest = m_project.Dataset.TestingSourceName;
    string strSrcTrain = m_project.Dataset.TrainingSourceName;
    string strSrcTest2 = null;
    string strSrcTrain2 = null;

    if (m_project.DatasetTarget != null)
    {
        strSrcTest2 = m_project.DatasetTarget.TestingSourceName;
        strSrcTrain2 = m_project.DatasetTarget.TrainingSourceName;
    }

    foreach (RawProto layer in rgLayers)
    {
        RawProto protoType = layer.FindChild("type");
        bool bIsDataLayer = protoType != null && (protoType.Value == "Data" || protoType.Value == "AnnotatedData");
        if (!bIsDataLayer)
            continue;

        RawProto protoParam = layer.FindChild("data_param");
        if (protoParam == null)
            continue;

        // Layers marked 'primary_data: false' are fed from the target dataset.
        RawProto protoPrimary = protoParam.FindChild("primary_data");
        bool bPrimary = (protoPrimary == null) || bool.Parse(protoPrimary.Value);

        RawProto protoSrc = protoParam.FindChild("source");
        if (protoSrc == null)
            continue;

        RawProto protoInclude = layer.FindChild("include");
        RawProto protoPhase = (protoInclude == null) ? null : protoInclude.FindChild("phase");
        if (protoPhase == null)
            continue;

        string strNewSrc = null;
        if (protoPhase.Value == "TRAIN")
            strNewSrc = bPrimary ? strSrcTrain : strSrcTrain2;
        else if (protoPhase.Value == "TEST")
            strNewSrc = bPrimary ? strSrcTest : strSrcTest2;

        if (strNewSrc != null)
            protoSrc.Value = strNewSrc;
    }
}
296
/// <summary>
/// Gets or sets the dataset-adjusted flag (a plain stored value).
/// </summary>
public bool DatasetAdjusted
{
    get
    {
        return m_bDatasetAdjusted;
    }
    set
    {
        m_bDatasetAdjusted = value;
    }
}
305
/// <summary>
/// Returns the name of the custom trainer configured in the solver, if any.
/// </summary>
/// <param name="strProperties">Receives the solver's 'custom_trainer_properties' value, or "" when that entry is absent.</param>
/// <returns>The solver's 'custom_trainer' value, or null when it is absent or empty.</returns>
312 public string GetCustomTrainer(out string strProperties)
313 {
// NOTE(review): the statement executed when no solver is loaded is not visible in this listing.
314 if (m_protoSolver == null)
316
317 strProperties = "";
318
319 RawProto rp = m_protoSolver.FindChild("custom_trainer");
320 if (rp == null)
321 return null;
322
323 if (rp.Value == null || rp.Value.Length == 0)
324 return null;
325
// The properties are only read once a trainer name is known.
326 RawProto rprop = m_protoSolver.FindChild("custom_trainer_properties");
327 if (rprop != null)
328 strProperties = rprop.Value;
329
330 return rp.Value;
331 }
332
/// <summary>
/// Returns the phase named by a layer's 'include { phase: ... }' rule.
/// </summary>
/// <param name="rp">Layer proto.</param>
/// <returns>
/// Phase.TEST or Phase.TRAIN when the include rule names one of them (case-insensitive);
/// otherwise Phase.NONE.
/// </returns>
private Phase getPhase(RawProto rp)
{
    RawProto rpInc = rp.FindChild("include");
    if (rpInc == null)
        return Phase.NONE;

    RawProto rpPhase = rpInc.FindChild("phase");
    if (rpPhase == null || rpPhase.Value == null)
        return Phase.NONE;

    // Ordinal, culture-independent comparison. Under the Turkish culture, ToUpper()
    // maps 'i' to a dotted capital I, so "train" would never match "TRAIN".
    string strPhase = rpPhase.Value;

    if (string.Equals(strPhase, Phase.TEST.ToString(), StringComparison.OrdinalIgnoreCase))
        return Phase.TEST;

    if (string.Equals(strPhase, Phase.TRAIN.ToString(), StringComparison.OrdinalIgnoreCase))
        return Phase.TRAIN;

    return Phase.NONE;
}
353
/// <summary>
/// Returns the batch size configured on the first data-style layer of the model for a phase.
/// </summary>
/// <param name="phase">Phase whose layers are searched; Phase.NONE searches every layer.</param>
/// <returns>
/// The 'batch_size' of the first matching layer that has a batch_data_param, data_param or
/// memory_data_param (checked in that order). Returns 0 when that parameter block has no
/// batch_size, or when no such layer exists. int.Parse throws if the value is not an integer.
/// </returns>
359 public int GetBatchSize(Phase phase)
360 {
// NOTE(review): the statement executed when no model is loaded is not visible in this listing.
361 if (m_protoModel == null)
363
364 RawProtoCollection col = m_protoModel.FindChildren("layer");
365
366 foreach (RawProto rp1 in col)
367 {
// A layer without an include phase reports Phase.NONE and so only matches a NONE query.
368 Phase p = getPhase(rp1);
369
370 if (p == phase || phase == Phase.NONE)
371 {
372 RawProto rp = rp1.FindChild("batch_data_param");
373
374 if (rp == null)
375 rp = rp1.FindChild("data_param");
376
377 if (rp == null)
378 rp = rp1.FindChild("memory_data_param");
379
// The search stops at the first layer that has one of the parameter blocks.
380 if (rp != null)
381 {
382 rp = rp.FindChild("batch_size");
383
384 if (rp == null)
385 return 0;
386
387 return int.Parse(rp.Value);
388 }
389 }
390 }
391
392 return 0;
393 }
394
/// <summary>
/// Looks up a setting inside a named child block of the model's layers.
/// </summary>
/// <param name="phase">Phase whose layers are searched; Phase.NONE searches every layer.</param>
/// <param name="strLayer">Name of the child entry looked up on each layer.</param>
/// <param name="strParam">Name of the setting inside that child entry.</param>
/// <returns>
/// null when no matching layer has the <paramref name="strLayer"/> entry, or when the first
/// layer that has it lacks <paramref name="strParam"/>.
/// </returns>
402 public double? GetLayerSetting(Phase phase, string strLayer, string strParam)
403 {
// NOTE(review): the statement executed when no model is loaded is not visible in this listing.
404 if (m_protoModel == null)
406
407 RawProtoCollection col = m_protoModel.FindChildren("layer");
408
409 foreach (RawProto rp1 in col)
410 {
411 Phase p = getPhase(rp1);
412
413 if (p == phase || phase == Phase.NONE)
414 {
415 RawProto rp = rp1.FindChild(strLayer);
416
417 if (rp != null)
418 {
// NOTE(review): the statement that returns the found value (after the null check below) is not visible in this listing.
419 rp = rp.FindChild(strParam);
420
421 if (rp == null)
422 return null;
423
425 }
426 }
427 }
428
429 return null;
430 }
431
/// <summary>
/// Returns the raw text value of a top-level solver setting.
/// </summary>
/// <param name="strParam">Name of the solver entry.</param>
/// <returns>The entry's value, or null when the solver has no such entry.</returns>
437 public string GetSolverSetting(string strParam)
438 {
// NOTE(review): the statement executed when no solver is loaded is not visible in this listing.
439 if (m_protoSolver == null)
441
442 RawProto proto = m_protoSolver.FindChild(strParam);
443 if (proto == null)
444 return null;
445
446 return proto.Value;
447 }
448
/// <summary>
/// Returns a solver setting parsed as a number.
/// </summary>
/// <param name="strParam">Name of the solver entry.</param>
/// <returns>The parsed value, or null when the entry is missing or BaseParameter.TryParse rejects it.</returns>
public double? GetSolverSettingAsNumeric(string strParam)
{
    string strVal = GetSolverSetting(strParam);
    double dfVal;

    if (strVal != null && BaseParameter.TryParse(strVal, out dfVal))
        return dfVal;

    return null;
}
466
/// <summary>
/// Returns a solver setting as an integer (the numeric value cast to int).
/// </summary>
/// <param name="strParam">Name of the solver entry.</param>
/// <returns>The truncated value, or null when GetSolverSettingAsNumeric returns null.</returns>
public int? GetSolverSettingAsInt(string strParam)
{
    double? dfVal = GetSolverSettingAsNumeric(strParam);
    return dfVal.HasValue ? (int)dfVal.Value : (int?)null;
}
480
/// <summary>
/// Returns a solver setting parsed as a boolean.
/// </summary>
/// <param name="strParam">Name of the solver entry.</param>
/// <returns>
/// The parsed value, or null when the entry is missing or is not "true"/"false"
/// (case-insensitive). This matches GetSolverSettingAsNumeric, which also returns null
/// rather than throwing on unparsable text.
/// </returns>
public bool? GetSolverSettingAsBool(string strParam)
{
    string strVal = GetSolverSetting(strParam);
    if (strVal == null)
        return null;

    bool bVal;
    if (!bool.TryParse(strVal, out bVal))
        return null;

    return bVal;
}
494
{
    // Pass-through to the wrapped project's settings object.
    get
    {
        return m_project.Settings;
    }
    set
    {
        m_project.Settings = value;
    }
}
503
/// <summary>
/// Gets or sets the project name stored on the wrapped project descriptor.
/// </summary>
public string Name
{
    get
    {
        return m_project.Name;
    }
    set
    {
        m_project.Name = value;
    }
}
512
/// <summary>
/// Gets the ID of the wrapped project descriptor.
/// </summary>
public int ID
{
    get
    {
        return m_project.ID;
    }
}
520
/// <summary>
/// Gets or sets the original project ID; reads fall back to <see cref="ID"/> while no positive original ID is set.
/// </summary>
public int OriginalID
{
    get { return (m_nOriginalProjectID > 0) ? m_nOriginalProjectID : ID; }
    set { m_nOriginalProjectID = value; }
}
538
/// <summary>
/// Gets or sets the owner of the wrapped project descriptor.
/// </summary>
public string Owner
{
    get
    {
        return m_project.Owner;
    }
    set
    {
        m_project.Owner = value;
    }
}
547
/// <summary>
/// Gets the Active flag of the wrapped project descriptor.
/// </summary>
public bool Active
{
    get
    {
        return m_project.Active;
    }
}
555
{
    // Classifies the project by the solver's 'custom_trainer' setting.
    get
    {
        // Without a solver there is no custom trainer to classify.
        if (m_protoSolver == null)
            return TRAINING_CATEGORY.NONE;

        string strCustomTrainer = GetSolverSetting("custom_trainer");
        if (string.IsNullOrEmpty(strCustomTrainer))
            return TRAINING_CATEGORY.NONE;

        switch (strCustomTrainer)
        {
            case "RL.Trainer":
                return TRAINING_CATEGORY.REINFORCEMENT;

            case "RNN.Trainer":
                return TRAINING_CATEGORY.RECURRENT;

            case "Dual.Trainer":
                return TRAINING_CATEGORY.DUAL;

            default:
                // Any other named trainer is treated as a generic custom trainer.
                return TRAINING_CATEGORY.CUSTOM;
        }
    }
}
582
{
    // Backed by the m_stage field.
    get
    {
        return m_stage;
    }
    set
    {
        m_stage = value;
    }
}
591
{
    // Pass-through to the project settings (read back as a double).
    get
    {
        return (double)m_project.Settings.SuperBoostProbability;
    }
    set
    {
        m_project.Settings.SuperBoostProbability = value;
    }
}
600
{
    // Read from the project parameter "UseTrainingSourceForTesting" (false when absent, per the Find default).
    get
    {
        return m_project.Parameters.Find("UseTrainingSourceForTesting", false);
    }
}
608
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.EnableLabelBalancing;
    }
}
617
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.EnableLabelBoosting;
    }
}
626
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.EnableRandomInputSelection;
    }
}
635
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.EnablePairInputSelection;
    }
}
645
/// <summary>
/// Gets the GPU override string of the wrapped project descriptor.
/// </summary>
public string GpuOverride
{
    get
    {
        return m_project.GpuOverride;
    }
}
653
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.ImageDbLoadMethod;
    }
}
662
/// <summary>
/// Gets the image database load limit from the project settings.
/// </summary>
public int ImageLoadLimit
{
    get
    {
        return m_project.Settings.ImageDbLoadLimit;
    }
}
670
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.ImageDbAutoRefreshScheduledUpdateInMs;
    }
}
678
683 {
// NOTE(review): the accessor(s) of this property are not visible in this listing.
685 }
686
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.SnapshotWeightUpdateMethod;
    }
}
697
{
    // Read-only pass-through to the project settings.
    get
    {
        return m_project.Settings.SnapshotLoadMethod;
    }
}
705
/// <summary>
/// Gets or sets the solver description.
/// </summary>
/// <remarks>
/// The getter returns the text of the parsed solver proto, or null when none is set.
/// The setter:
/// <list type="bullet">
/// <item>stores the raw text on the project and parses it;</item>
/// <item>synchronizes dataset sources with the parsed proto;</item>
/// <item>takes the solver name from the proto's 'type' entry, falling back to a text scan for 'type:'.</item>
/// </list>
/// </remarks>
public string SolverDescription
{
    get
    {
        if (m_protoSolver == null)
            return null;

        return m_protoSolver.ToString();
    }
    set
    {
        m_project.SolverName = getSolverType(value);
        m_project.SolverDescription = value;
        m_protoSolver = null;

        if (string.IsNullOrEmpty(value))
            return;

        m_protoSolver = RawProto.Parse(value);

        if (m_project.Dataset != null)
        {
            // A dataset with a name pushes its sources into the proto; an unnamed one
            // adopts the sources found in the proto.
            if (!string.IsNullOrEmpty(m_project.Dataset.Name))
                setDatasetToProto(m_protoSolver);
            else
                setDatasetFromProto(m_protoSolver);
        }

        RawProto rpType = m_protoSolver.FindChild("type");
        if (rpType != null)
            m_project.SolverName = rpType.Value;
    }
}
736
/// <summary>
/// Gets or sets the model description.
/// </summary>
/// <remarks>
/// The getter returns the text of the parsed model proto, or null when none is set.
/// The setter:
/// <list type="bullet">
/// <item>stores the raw text on the project and parses it;</item>
/// <item>synchronizes dataset sources with the parsed proto;</item>
/// <item>takes the model name from the proto's 'name' entry, falling back to a text scan for 'name:'.</item>
/// </list>
/// </remarks>
public string ModelDescription
{
    get
    {
        if (m_protoModel == null)
            return null;

        return m_protoModel.ToString();
    }
    set
    {
        m_project.ModelName = getModelName(value);
        m_project.ModelDescription = value;
        m_protoModel = null;

        if (string.IsNullOrEmpty(value))
            return;

        m_protoModel = RawProto.Parse(value);

        if (m_project.Dataset != null)
        {
            // A dataset with a name pushes its sources into the proto; an unnamed one
            // adopts the sources found in the proto.
            if (!string.IsNullOrEmpty(m_project.Dataset.Name))
                setDatasetToProto(m_protoModel);
            else
                setDatasetFromProto(m_protoModel);
        }

        RawProto rpName = m_protoModel.FindChild("name");
        if (rpName != null)
            m_project.ModelName = rpName.Value;
    }
}
767
{
    // Read-only pass-through to the project's group.
    get
    {
        return m_project.Group;
    }
}
775
{
    // Read-only pass-through to the dataset's model group.
    get
    {
        return m_project.Dataset.ModelGroup;
    }
}
783
{
    // Read-only pass-through to the dataset's dataset group.
    get
    {
        return m_project.Dataset.DatasetGroup;
    }
}
791
{
    // Read-only pass-through to the project's parameter collection.
    get
    {
        return m_project.Parameters;
    }
}
799
{
    // Pass-through to the project's total iteration count.
    get
    {
        return m_project.TotalIterations;
    }
    set
    {
        m_project.TotalIterations = value;
    }
}
808
/// <summary>
/// Gets the HasResults flag of the project's results state.
/// </summary>
public bool HasResults
{
    get
    {
        return m_state.HasResults;
    }
}
816
/// <summary>
/// Gets or sets the iteration count stored on the results state.
/// </summary>
public int Iterations
{
    get
    {
        return m_state.Iterations;
    }
    set
    {
        m_state.Iterations = value;
    }
}
825
/// <summary>
/// Gets or sets the accuracy stored on the results state.
/// </summary>
public double BestAccuracy
{
    get
    {
        return m_state.Accuracy;
    }
    set
    {
        m_state.Accuracy = value;
    }
}
834
/// <summary>
/// Gets or sets the error stored on the results state.
/// </summary>
public double BestError
{
    get
    {
        return m_state.Error;
    }
    set
    {
        m_state.Error = value;
    }
}
843
/// <summary>
/// Gets or sets the serialized solver state stored on the results state.
/// </summary>
public byte[] SolverState
{
    get
    {
        return m_state.State;
    }
    set
    {
        m_state.State = value;
    }
}
852
/// <summary>
/// Gets or sets the serialized weights stored on the results state.
/// </summary>
public byte[] WeightsState
{
    get
    {
        return m_state.Weights;
    }
    set
    {
        m_state.Weights = value;
    }
}
861
/// <summary>
/// Gets the name of the project's dataset, or null when the project has no dataset.
/// </summary>
public string DatasetName
{
    get
    {
        return (m_project.Dataset == null) ? null : m_project.Dataset.Name;
    }
}
875
{
    // Read-only pass-through to the project's dataset descriptor.
    get
    {
        return m_project.Dataset;
    }
}
883
{
    // Read-only pass-through to the project's target dataset descriptor (may be null).
    get
    {
        return m_project.DatasetTarget;
    }
}
894
{
    // Persisted as the project parameter "TargetDatasetID"; a missing or non-integer
    // value reads as 0.
    get
    {
        ParameterDescriptor p = m_project.Parameters.Find("TargetDatasetID");
        int nID;

        if (p != null && int.TryParse(p.Value, out nID))
            return nID;

        return 0;
    }

    // Updates the existing parameter, or adds it when the project does not have one yet.
    set
    {
        ParameterDescriptor p = m_project.Parameters.Find("TargetDatasetID");
        string strVal = value.ToString();

        if (p != null)
            p.Value = strVal;
        else
            m_project.Parameters.Add(new ParameterDescriptor(0, "TargetDatasetID", strVal));
    }
}
922
{
    // Value supplied as bExistTest to the descriptor-based constructor.
    get
    {
        return m_bExistTest;
    }
}
930
{
    // Value supplied as bExistTrain to the descriptor-based constructor.
    get
    {
        return m_bExistTrain;
    }
}
938
{
    // Read-only pass-through to the project's analysis items.
    get
    {
        return m_project.AnalysisItems;
    }
}
946
/// <summary>
/// Gets the model name recorded on the wrapped project descriptor.
/// </summary>
public string ModelName
{
    get
    {
        return m_project.ModelName;
    }
}
954
/// <summary>
/// Gets the solver name recorded on the wrapped project descriptor.
/// </summary>
public string SolverType
{
    get
    {
        return m_project.SolverName;
    }
}
962
/// <summary>
/// Sets (or adds) a top-level entry in the parsed solver and refreshes the stored solver text.
/// </summary>
/// <param name="strVar">Name of the solver entry.</param>
/// <param name="strVal">New value for the entry.</param>
/// <returns>false when no solver has been loaded; otherwise true.</returns>
public bool SetSolverVariable(string strVar, string strVal)
{
    if (m_protoSolver == null)
        return false;

    RawProto protoVar = m_protoSolver.FindChild(strVar);
    if (protoVar == null)
        m_protoSolver.Children.Add(new RawProto(strVar, strVal));
    else
        protoVar.Value = strVal;

    // Keep the project's description text in sync with the edited proto.
    m_project.SolverDescription = m_protoSolver.ToString();
    return true;
}
987
/// <summary>
/// Reads a solver description from a file (UTF-8, BOM-aware) and assigns it to <see cref="SolverDescription"/>.
/// </summary>
/// <param name="strFile">Path of the solver description file.</param>
public void LoadSolverFile(string strFile)
{
    string strText = File.ReadAllText(strFile);
    SolverDescription = strText;
}
999
/// <summary>
/// Reads a model description from a file (UTF-8, BOM-aware) and assigns it to <see cref="ModelDescription"/>.
/// </summary>
/// <param name="strFile">Path of the model description file.</param>
public void LoadModelFile(string strFile)
{
    string strText = File.ReadAllText(strFile);
    ModelDescription = strText;
}
1011
/// <summary>
/// Builds a running version of this project's model.
/// </summary>
/// <remarks>
/// Delegates to the static overload using the model description text stored on the
/// project descriptor; all other arguments are forwarded unchanged.
/// </remarks>
public RawProto CreateModelForRunning(string strName, int nNum, int nChannels, int nHeight, int nWidth, out RawProto protoTransform, Stage stage = Stage.NONE, bool bSkipLossLayer = false)
{
    string strModelDesc = m_project.ModelDescription;
    return CreateModelForRunning(strModelDesc, strName, nNum, nChannels, nHeight, nWidth, out protoTransform, stage, bSkipLossLayer);
}
1028
/// <summary>
/// Builds a training version of a model description.
/// </summary>
/// <param name="strModelDescription">Model description text to convert.</param>
/// <param name="strName">Name given to any data layer that has to be added.</param>
/// <param name="bCaffeFormat">When true, added data layers use BGR color order; otherwise RGB.</param>
/// <returns>The adjusted model proto.</returns>
/// <remarks>
/// <list type="bullet">
/// <item>Layers whose include phase is neither TRAIN nor TEST are removed.</item>
/// <item>Layers that exclude TRAIN or TEST are removed.</item>
/// <item>Missing TRAIN/TEST data layers are added (batch size 16, empty source).</item>
/// <item>The last 'Softmax' layer is converted into a TRAIN-only 'SoftmaxWithLoss' layer.</item>
/// <item>An 'Accuracy' layer is appended when no TRAIN/TEST accuracy layer exists.</item>
/// </list>
/// When anything was added or converted, or the description uses 'input_dim', the result
/// is rebuilt as a new root holding the model name followed by the layers.
/// </remarks>
public static RawProto CreateModelForTraining(string strModelDescription, string strName, bool bCaffeFormat = false)
{
    RawProto proto = RawProto.Parse(strModelDescription);

    // Older descriptions list their layers under 'layers' instead of 'layer'.
    string strLayers = "layer";
    RawProtoCollection rgLayers = proto.FindChildren("layer");
    if (rgLayers.Count == 0)
    {
        rgLayers = proto.FindChildren("layers");
        strLayers = "layers";
    }

    bool bDirty = false;
    RawProtoCollection rgRemove = new RawProtoCollection();
    RawProto protoSoftmax = null;
    RawProto protoName = proto.FindChild("name");
    int nTrainDataLayerIdx = -1;
    int nTestDataLayerIdx = -1;
    int nAccuracyLayerIdx = -1;
    int nIdx = -1;

    foreach (RawProto layer in rgLayers)
    {
        nIdx++;

        // A layer without a type cannot be classified; leave it as is.
        RawProto type = layer.FindChild("type");
        if (type == null || type.Value == null)
            continue;

        RawProto include = layer.FindChild("include");
        RawProto exclude = layer.FindChild("exclude");
        string strType = type.Value.ToLowerInvariant();
        bool bRemove = false;

        if (strType == "softmax")
            protoSoftmax = layer;

        if (include != null)
        {
            RawProto phase = include.FindChild("phase");
            if (phase != null)
            {
                if (phase.Value != "TEST" && phase.Value != "TRAIN")
                {
                    // Layers restricted to any other phase are not part of training.
                    bRemove = true;
                }
                else if (strType == "data" || strType == "tokenizeddata")
                {
                    if (phase.Value == "TRAIN")
                        nTrainDataLayerIdx = nIdx;
                    else
                        nTestDataLayerIdx = nIdx;
                }
                else if (strType == "accuracy")
                {
                    nAccuracyLayerIdx = nIdx;
                }
            }
        }

        if (!bRemove && exclude != null)
        {
            RawProto phase = exclude.FindChild("phase");
            if (phase != null && (phase.Value == "TEST" || phase.Value == "TRAIN"))
                bRemove = true;
        }

        if (bRemove)
            rgRemove.Add(layer);
    }

    // Add the missing data layers. The indexes above were taken before any insertion,
    // so the TEST layer is placed first and the TRAIN layer is then inserted at the front.
    if (nTestDataLayerIdx < 0)
    {
        string strProto = getDataLayerProto(strLayers, strName, bCaffeFormat, 16, "", Phase.TEST);
        RawProto protoData = RawProto.Parse(strProto).Children[0];

        // Put the TEST data layer right after the TRAIN data layer, including when that
        // layer is the very first one (index 0).
        if (nTrainDataLayerIdx >= 0)
            rgLayers.Insert(nTrainDataLayerIdx + 1, protoData);
        else
            rgLayers.Insert(0, protoData);

        bDirty = true;
    }

    if (nTrainDataLayerIdx < 0)
    {
        string strProto = getDataLayerProto(strLayers, strName, bCaffeFormat, 16, "", Phase.TRAIN);
        RawProto protoData = RawProto.Parse(strProto).Children[0];
        rgLayers.Insert(0, protoData);
        bDirty = true;
    }

    // Drop the rejected layers from the parsed proto AND from rgLayers. rgLayers is used
    // below to rebuild the model, so it must not keep them either.
    foreach (RawProto layer in rgRemove)
    {
        proto.RemoveChild(layer);
        rgLayers.Remove(layer);
    }

    // Turn the last Softmax layer into a TRAIN-only SoftmaxWithLoss fed by the labels.
    if (protoSoftmax != null)
    {
        RawProto type = protoSoftmax.FindChild("type");
        if (type != null)
            type.Value = "SoftmaxWithLoss";

        protoSoftmax.Children.Add(new RawProto("bottom", "label"));
        protoSoftmax.Children.Add(new RawProto("loss_weight", "1", null, RawProto.TYPE.NUMERIC));

        string strInclude = "include { phase: TRAIN }";
        protoSoftmax.Children.Add(RawProto.Parse(strInclude).Children[0]);

        string strLoss = "loss_param { normalization: VALID }";
        protoSoftmax.Children.Add(RawProto.Parse(strLoss).Children[0]);
        bDirty = true;
    }

    // Append an Accuracy layer fed by the first bottom of the last layer.
    if (nAccuracyLayerIdx < 0 && rgLayers.Count > 0)
    {
        RawProto last = rgLayers[rgLayers.Count - 1];
        RawProtoCollection colBtm = last.FindChildren("bottom");
        string strBottom = (colBtm.Count > 0) ? colBtm[0].Value : null;

        if (strBottom != null)
        {
            string strProto = getAccuracyLayerProto(strLayers, strBottom);
            RawProto protoData = RawProto.Parse(strProto).Children[0];
            rgLayers.Add(protoData);
            bDirty = true;
        }
    }

    if (bDirty || proto.FindChildren("input_dim").Count > 0)
    {
        // Rebuild the root from the layer list, keeping the model name (when present) first.
        if (protoName != null)
            rgLayers.Insert(0, protoName);

        proto = new RawProto("root", null, rgLayers);
    }

    return proto;
}
1188
/// <summary>
/// Builds the prototxt text of an IMAGEDB 'Data' layer producing 'data' and 'label' tops.
/// </summary>
/// <param name="strLayer">Layer keyword ("layer" or "layers").</param>
/// <param name="strName">Layer name.</param>
/// <param name="bCaffeFormat">When true the color order is BGR; otherwise RGB.</param>
/// <param name="nBatchSize">Batch size written to data_param.</param>
/// <param name="strSrc">Data source name written to data_param.</param>
/// <param name="phase">Phase written to the include rule.</param>
/// <returns>The layer definition as a single line of text.</returns>
private static string getDataLayerProto(string strLayer, string strName, bool bCaffeFormat, int nBatchSize, string strSrc, Phase phase)
{
    string strColorOrder = bCaffeFormat ? "BGR" : "RGB";

    StringBuilder sb = new StringBuilder();
    sb.Append(strLayer);
    sb.Append(" { name: \"").Append(strName).Append("\" type: \"Data\" top: \"data\" top: \"label\"");
    sb.Append(" include { phase: ").Append(phase.ToString()).Append(" }");
    sb.Append(" transform_param { scale: 1 mirror: True use_imagedb_mean: True color_order: ").Append(strColorOrder).Append(" }");
    sb.Append(" data_param { source: \"").Append(strSrc).Append("\" batch_size: ").Append(nBatchSize.ToString()).Append(" backend: IMAGEDB enable_random_selection: True }");
    sb.Append(" }");

    return sb.ToString();
}
1195
/// <summary>
/// Builds the prototxt text of a TEST-phase top-1 'Accuracy' layer comparing
/// <paramref name="strBottom"/> with 'label'.
/// </summary>
/// <param name="strLayer">Layer keyword ("layer" or "layers").</param>
/// <param name="strBottom">Name of the prediction blob.</param>
/// <returns>The layer definition as a single line of text.</returns>
private static string getAccuracyLayerProto(string strLayer, string strBottom)
{
    return string.Format("{0} {{ name: \"accuracy\" type: \"Accuracy\" bottom: \"{1}\" bottom: \"label\" top: \"accuracy\" include {{ phase: TEST }} accuracy_param {{ top_k: 1 }} }}", strLayer, strBottom);
}
1200
/// <summary>
/// Collects the phase/stage pairs of a layer's rule entries of the given kind.
/// </summary>
/// <param name="proto">Layer proto.</param>
/// <param name="strType">Rule entry name, e.g. "include".</param>
/// <returns>The collected pairs; empty when the layer has no such entries.</returns>
private static PhaseStageCollection getPhases(RawProto proto, string strType)
{
    RawProtoCollection rgRules = proto.FindChildren(strType);

    if (rgRules != null && rgRules.Count > 0)
        return getPhases(rgRules);

    return new PhaseStageCollection();
}
1211
/// <summary>
/// Converts rule entries into phase/stage pairs.
/// </summary>
/// <param name="col">Rule entries (e.g. the layer's 'include' children).</param>
/// <returns>
/// One pair per entry that has a 'phase'. An unrecognized phase maps to Phase.NONE.
/// A missing or unrecognized stage (anything but RL/RNN) maps to Stage.NONE.
/// </returns>
private static PhaseStageCollection getPhases(RawProtoCollection col)
{
    PhaseStageCollection psCol = new PhaseStageCollection();
    Phase[] rgKnownPhases = new Phase[] { Phase.ALL, Phase.RUN, Phase.TEST, Phase.TRAIN };

    foreach (RawProto rule in col)
    {
        RawProto protoPhase = rule.FindChild("phase");

        // Entries without a phase contribute nothing.
        if (protoPhase != null)
        {
            Stage stage = Stage.NONE;
            RawProto protoStage = rule.FindChild("stage");
            if (protoStage != null)
            {
                string strStage = protoStage.Value;
                if (strStage == Stage.RL.ToString())
                    stage = Stage.RL;
                else if (strStage == Stage.RNN.ToString())
                    stage = Stage.RNN;
            }

            Phase phase = Phase.NONE;
            foreach (Phase candidate in rgKnownPhases)
            {
                if (protoPhase.Value == candidate.ToString())
                {
                    phase = candidate;
                    break;
                }
            }

            psCol.Add(phase, stage);
        }
    }

    return psCol;
}
1254
/// <summary>
/// Decides whether a layer is kept in the running model for the given stage.
/// </summary>
/// <param name="layer">Layer proto to inspect.</param>
/// <param name="stage">Stage the running model is built for.</param>
/// <param name="psInclude">Receives the phase/stage pairs of the layer's 'include' rules.</param>
/// <param name="psExclude">Receives the layer's 'exclude' pairs remaining after FindAllWith(stage) and FindAllWith(Phase.RUN).</param>
/// <returns>
/// false when any exclude pair remains. false when the layer has include rules and either
/// none of them has phase NONE/ALL/RUN, or none matches Stage.NONE or <paramref name="stage"/>.
/// Otherwise true.
/// </returns>
private static bool includeLayer(RawProto layer, Stage stage, out PhaseStageCollection psInclude, out PhaseStageCollection psExclude)
{
    psInclude = getPhases(layer, "include");

    // Exclusion rules live under 'exclude' (the key CreateModelForTraining also reads);
    // a misspelled key would silently ignore every exclusion.
    psExclude = getPhases(layer, "exclude").FindAllWith(stage);

    PhaseStageCollection psInclude1 = psInclude.FindAllWith(Stage.NONE);
    PhaseStageCollection psInclude2 = psInclude.FindAllWith(stage);
    PhaseStageCollection psInclude3 = psInclude.FindAllWith(Phase.NONE, Phase.ALL, Phase.RUN);
    psExclude = psExclude.FindAllWith(Phase.RUN);

    if (psExclude.Count > 0)
        return false;

    if (psInclude.Count > 0)
    {
        if (psInclude3.Count == 0 || (psInclude1.Count == 0 && psInclude2.Count == 0))
            return false;
    }

    return true;
}
1276
1277
1291 public static RawProto CreateModelForRunning(string strModelDescription, string strName, int nNum, int nChannels, int nHeight, int nWidth, out RawProto protoTransform, Stage stage = Stage.NONE, bool bSkipLossLayer = false)
1292 {
1293 PhaseStageCollection psInclude;
1294 PhaseStageCollection psExclude;
1295 RawProto proto = RawProto.Parse(strModelDescription);
1296 int nNameIdx = proto.FindChildIndex("name");
1297 int nInputInsertIdx = -1;
1298 int nInputShapeInsertIdx = -1;
1299 bool bNoInput = false;
1300 bool bNameSet = false;
1301
1302 protoTransform = null;
1303
1304 nNameIdx++;
1305 if (nNameIdx < 0)
1306 nNameIdx = 0;
1307
1308 RawProtoCollection rgLayers = proto.FindChildren("layer");
1309 bool bUsesLstm = false;
1310
1311 foreach (RawProto layer in rgLayers)
1312 {
1313 RawProto type = layer.FindChild("type");
1314 if (type != null)
1315 {
1316 string strType = type.Value.ToLower();
1317 if (strType == "lstm")
1318 {
1319 bUsesLstm = true;
1320 break;
1321 }
1322 else if (!bNameSet && (
1323 strType == "data" ||
1324 strType == "annotated_data" ||
1325 strType == "tokenizeddata"))
1326 {
1327 RawProtoCollection tops = layer.FindChildren("top");
1328 if (tops != null && tops.Count > 0)
1329 {
1330 if (strType == "tokenizeddata")
1331 strName = "data";
1332 else
1333 strName = tops[0].Value;
1334 bNameSet = true;
1335 }
1336 }
1337 }
1338 }
1339
1340 List<Tuple<string, int, int, int, int>> rgInputs = new List<Tuple<string, int, int, int, int>>();
1341 rgInputs.Add(new Tuple<string, int, int, int, int>(strName, nNum, nChannels, nHeight, nWidth));
1342
1343 bool bFoundInput = false;
1344 bool bFoundMemoryData = false;
1345
1346 foreach (RawProto layer in rgLayers)
1347 {
1348 RawProto type = layer.FindChild("type");
1349 if (type != null)
1350 {
1351 string strType = type.Value.ToLower();
1352 if (strType == "input")
1353 {
1354 bFoundInput = true;
1355
1356 if (includeLayer(layer, stage, out psInclude, out psExclude))
1357 {
1358 rgInputs.Clear();
1359
1360 RawProtoCollection rgTop = layer.FindChildren("top");
1361 RawProto input_param = layer.FindChild("input_param");
1362 if (input_param != null)
1363 {
1364 RawProtoCollection rgShape = input_param.FindChildren("shape");
1365
1366 if (rgTop.Count == rgShape.Count)
1367 {
1368 for (int i = 0; i < rgTop.Count; i++)
1369 {
1370 if (bUsesLstm && i < 2)
1371 {
1372 RawProtoCollection rgDim = rgShape[i].FindChildren("dim");
1373 if (rgDim.Count > 1)
1374 {
1375 rgDim[1].Value = "1";
1376 }
1377 }
1378
1379 if (rgTop[i].Value.ToLower() != "label")
1380 {
1381 List<int> rgVal = new List<int>();
1382 RawProtoCollection rgDim = rgShape[i].FindChildren("dim");
1383 foreach (RawProto dim in rgDim)
1384 {
1385 rgVal.Add(int.Parse(dim.Value));
1386 }
1387
1388 nNum = (rgVal.Count > 0) ? rgVal[0] : 1;
1389 nChannels = (rgVal.Count > 1) ? rgVal[1] : 1;
1390 nHeight = (rgVal.Count > 2) ? rgVal[2] : 1;
1391 nWidth = (rgVal.Count > 3) ? rgVal[3] : 1;
1392
1393 rgInputs.Add(new Tuple<string, int, int, int, int>(rgTop[i].Value, nNum, nChannels, nHeight, nWidth));
1394 }
1395 }
1396 }
1397 }
1398 }
1399 }
1400 else if (strType == "memorydata")
1401 {
1402 bFoundMemoryData = true;
1403
1404 if (includeLayer(layer, stage, out psInclude, out psExclude))
1405 {
1406 bNoInput = true;
1407 rgInputs.Clear();
1408 }
1409 }
1410 else if (strType == "data")
1411 {
1412 if (rgInputs.Count > 0)
1413 {
1414 RawProtoCollection colTop = layer.FindChildren("top");
1415 if (colTop.Count > 0)
1416 {
1417 rgInputs[0] = new Tuple<string, int, int, int, int>(colTop[0].Value, rgInputs[0].Item2, rgInputs[0].Item3, rgInputs[0].Item4, rgInputs[0].Item5);
1418 break;
1419 }
1420 }
1421 }
1422 else if (strType == "tokenizeddata")
1423 {
1424 if (rgInputs.Count > 0)
1425 {
1426 RawProtoCollection colTop = layer.FindChildren("top");
1427 if (colTop.Count > 0)
1428 {
1429 rgInputs[0] = new Tuple<string, int, int, int, int>("data", rgInputs[0].Item2, rgInputs[0].Item3, rgInputs[0].Item4, rgInputs[0].Item5);
1430 layer.Children.Add<string>("bottom", new List<string>() { "data" });
1431 break;
1432 }
1433 }
1434 }
1435
1436 if (bFoundInput && bFoundMemoryData)
1437 break;
1438 }
1439 }
1440
1441 RawProto input = null;
1442 RawProto input_shape = null;
1443 RawProtoCollection rgInput = null;
1444 RawProtoCollection rgInputShape = null;
1445
1446 if (!bNoInput)
1447 {
1448 rgInput = new RawProtoCollection();
1449 rgInputShape = new RawProtoCollection();
1450
1451 input = proto.FindChild("input");
1452 if (input != null)
1453 {
1454 input.Value = rgInputs[0].Item1;
1455 }
1456 else
1457 {
1458 for (int i = 0; i < rgInputs.Count; i++)
1459 {
1460 input = new RawProto("input", rgInputs[i].Item1, null, RawProto.TYPE.STRING);
1461 rgInput.Add(input);
1462 nInputInsertIdx = nNameIdx;
1463 nNameIdx++;
1464 }
1465 }
1466
1467 input_shape = proto.FindChild("input_shape");
1468 if (input_shape != null)
1469 {
1470 RawProtoCollection colDim = input_shape.FindChildren("dim");
1471
1472 if (colDim.Count > 0)
1473 colDim[0].Value = rgInputs[0].Item2.ToString();
1474
1475 if (colDim.Count > 1)
1476 colDim[1].Value = rgInputs[0].Item3.ToString();
1477
1478 if (colDim.Count > 2)
1479 colDim[2].Value = rgInputs[0].Item4.ToString();
1480
1481 if (colDim.Count > 3)
1482 colDim[3].Value = rgInputs[0].Item5.ToString();
1483 }
1484 else
1485 {
1486 for (int i = 0; i < rgInputs.Count; i++)
1487 {
1488 input_shape = new RawProto("input_shape", "");
1489
1490 nNum = rgInputs[i].Item2;
1491 nChannels = rgInputs[i].Item3;
1492 nHeight = rgInputs[i].Item4;
1493 nWidth = rgInputs[i].Item5;
1494
1495 input_shape.Children.Add(new RawProto("dim", nNum.ToString()));
1496 input_shape.Children.Add(new RawProto("dim", nChannels.ToString()));
1497
1498 if (nHeight > 1 || nWidth > 1)
1499 {
1500 input_shape.Children.Add(new RawProto("dim", nHeight.ToString()));
1501 input_shape.Children.Add(new RawProto("dim", nWidth.ToString()));
1502 }
1503
1504 rgInputShape.Add(input_shape);
1505 nInputShapeInsertIdx = nNameIdx;
1506 }
1507 }
1508 }
1509
1510 RawProto net_name = proto.FindChild("name");
1511 if (net_name != null)
1512 net_name.Value += "-Live";
1513
1514 RawProtoCollection rgRemove = new RawProtoCollection();
1515
1516 List<RawProto> rgProtoSoftMaxLoss = new List<basecode.RawProto>();
1517 RawProto protoSoftMax = null;
1518
1519 foreach (RawProto layer in rgLayers)
1520 {
1521 RawProto type = layer.FindChild("type");
1522 if (type != null)
1523 {
1524 string strType = type.Value.ToLower();
1525 bool bKeepLayer = false;
1526
1527 bool bInclude = includeLayer(layer, stage, out psInclude, out psExclude);
1528
1529 if (strType == "data" || strType == "annotateddata" || strType == "batchdata" || strType == "tokenizeddata")
1530 {
1531 if (psInclude.Find(Phase.TEST, stage) != null)
1532 protoTransform = layer.FindChild("transform_param");
1533 }
1534 else if (strType == "decode")
1535 {
1536 List<RawProto> rgBtm = new List<RawProto>();
1537
1538 foreach (RawProto child in layer.Children)
1539 {
1540 if (child.Name == "bottom")
1541 rgBtm.Add(child);
1542 }
1543
1544 if (rgBtm.Count > 0)
1545 rgBtm.RemoveAt(0);
1546
1547 foreach (RawProto btm in rgBtm)
1548 {
1549 layer.Children.Remove(btm);
1550 }
1551 }
1552
1553 if (!bInclude)
1554 {
1555 rgRemove.Add(layer);
1556 }
1557 else if (psExclude.Find(Phase.RUN, stage) != null)
1558 {
1559 rgRemove.Add(layer);
1560 }
1561 else if (strType == "input")
1562 {
1563 rgRemove.Add(layer);
1564 }
1565 else if (strType == "softmaxwithloss" ||
1566 strType == "softmaxcrossentropy_loss" ||
1567 strType == "softmaxcrossentropyloss" ||
1568 strType == "softmaxcrossentropy2_loss" ||
1569 strType == "softmaxcrossentropy2loss")
1570 {
1571 if (!bSkipLossLayer)
1572 {
1573 rgProtoSoftMaxLoss.Add(layer);
1574 bKeepLayer = true;
1575 }
1576 else
1577 {
1578 rgRemove.Add(layer);
1579 }
1580 }
1581 else if (strType == "memoryloss" ||
1582 strType == "contrastive_loss" ||
1583 strType == "contrastiveloss" ||
1584 strType == "euclidean_loss" ||
1585 strType == "euclideanloss" ||
1586 strType == "hinge_loss" ||
1587 strType == "hingeloss" ||
1588 strType == "infogain_loss" ||
1589 strType == "infogainloss" ||
1590 strType == "multinomiallogistic_loss" ||
1591 strType == "multinomiallogisticloss" ||
1592 strType == "sigmoidcrossentropy_loss" ||
1593 strType == "sigmoidcrossentropyloss" ||
1594 strType == "triplet_loss" ||
1595 strType == "tripletloss" ||
1596 strType == "triplet_loss_simple" ||
1597 strType == "tripletlosssimple")
1598 {
1599 rgRemove.Add(layer);
1600 }
1601 else if (strType == "softmax")
1602 {
1603 protoSoftMax = layer;
1604 }
1605 else if (strType == "labelmapping")
1606 {
1607 rgRemove.Add(layer);
1608 }
1609 else if (strType == "debug")
1610 {
1611 rgRemove.Add(layer);
1612 }
1613 else if (strType == "tokenizeddata")
1614 {
1615 //rgRemove.Add(layer);
1616 }
1617
1618 if (!bKeepLayer && psInclude.FindAllWith(Phase.TEST, Phase.TRAIN).Count > 0 && psInclude.FindAllWith(Phase.RUN).Count == 0)
1619 {
1620 rgRemove.Add(layer);
1621 }
1622 else
1623 {
1624 RawProto max_btm = layer.FindChild("max_bottom_count");
1625 if (max_btm != null)
1626 {
1627 RawProto phase1 = max_btm.FindChild("phase");
1628 RawProto stage1 = max_btm.FindChild("stage");
1629
1630 if (phase1 != null && phase1.Value == "RUN" && (stage1 == null || stage1.Value == stage.ToString() || stage1.Value == Stage.NONE.ToString()))
1631 {
1632 RawProto count = max_btm.FindChild("count");
1633 int nCount = int.Parse(count.Value);
1634
1635 int nBtmIdx = layer.FindChildIndex("bottom");
1636 int nBtmEnd = layer.Children.Count;
1637 List<int> rgRemoveIdx = new List<int>();
1638
1639 for (int i = nBtmIdx; i < layer.Children.Count; i++)
1640 {
1641 if (layer.Children[i].Name != "bottom")
1642 {
1643 nBtmEnd = i;
1644 break;
1645 }
1646 }
1647
1648 for (int i = nBtmEnd - 1; i >= nBtmIdx + nCount; i--)
1649 {
1650 layer.Children.RemoveAt(i);
1651 }
1652 }
1653 }
1654 }
1655 }
1656
1657 RawProto exclude = layer.FindChild("exclude");
1658 if (exclude != null)
1659 {
1660 RawProto phase = exclude.FindChild("phase");
1661 if (phase != null)
1662 {
1663 if (phase.Value == "RUN")
1664 {
1665 if (!rgRemove.Contains(layer))
1666 rgRemove.Add(layer);
1667 }
1668 }
1669 }
1670 }
1671
1672 foreach (RawProto protoSoftMaxLoss in rgProtoSoftMaxLoss)
1673 {
1674 if (protoSoftMax != null)
1675 {
1676 rgRemove.Add(protoSoftMaxLoss);
1677 }
1678 else
1679 {
1680 RawProto type = protoSoftMaxLoss.FindChild("type");
1681 if (type != null)
1682 type.Value = "Softmax";
1683
1684 RawProtoCollection colBtm = protoSoftMaxLoss.FindChildren("bottom");
1685
1686 for (int i = 1; i < colBtm.Count; i++)
1687 {
1688 protoSoftMaxLoss.RemoveChild("bottom", colBtm[i].Value, true);
1689 }
1690 }
1691 }
1692
1693 foreach (RawProto layer in rgRemove)
1694 {
1695 proto.RemoveChild(layer);
1696 }
1697
1698 RawProto layer1 = proto.FindChild("layer");
1699 if (layer1 != null)
1700 {
1701 RawProto btm = layer1.FindChild("bottom");
1702 if (btm != null)
1703 btm.Value = strName;
1704 }
1705
1706 if (input != null && input_shape != null)
1707 {
1708 if (protoTransform != null)
1709 {
1710 RawProto resize = protoTransform.FindChild("resize_param");
1711
1712 if (resize != null)
1713 {
1714 bool bActive = (bool)resize.FindValue("active", typeof(bool));
1715 if (bActive)
1716 {
1717 int nNewHeight = (int)resize.FindValue("height", typeof(int));
1718 int nNewWidth = (int)resize.FindValue("width", typeof(int));
1719
1720 if (rgInputShape[0].Children.Count < 1)
1721 rgInputShape[0].Children.Add(new RawProto("dim", "1"));
1722
1723 if (rgInputShape[0].Children.Count < 2)
1724 rgInputShape[0].Children.Add(new RawProto("dim", "1"));
1725
1726 if (rgInputShape[0].Children.Count < 3)
1727 rgInputShape[0].Children.Add(new RawProto("dim", nNewHeight.ToString()));
1728 else
1729 rgInputShape[0].Children[2] = new RawProto("dim", nNewHeight.ToString());
1730
1731 if (rgInputShape[0].Children.Count < 4)
1732 rgInputShape[0].Children.Add(new RawProto("dim", nNewWidth.ToString()));
1733 else
1734 rgInputShape[0].Children[3] = new RawProto("dim", nNewWidth.ToString());
1735 }
1736 }
1737 }
1738
1739 for (int i = rgInputShape.Count - 1; i >= 0; i--)
1740 {
1741 proto.Children.Insert(0, rgInputShape[i]);
1742 }
1743
1744 for (int i = rgInput.Count - 1; i >= 0; i--)
1745 {
1746 proto.Children.Insert(0, rgInput[i]);
1747 }
1748 }
1749
1750 return proto;
1751 }
1752
/// <summary>
/// Sets the dataset used by the Project, overriding the current dataset used.
/// </summary>
/// <remarks>
/// The model description (when present) is rewritten to point at the new dataset's sources and then
/// handed to any OnOverrideModel handler, which may replace it.  When an OnOverrideSolver handler is
/// attached and a solver description exists, the solver description is handed to that handler as well.
/// </remarks>
/// <param name="dataset">Specifies the new dataset; when null nothing is changed.</param>
public void SetDataset(DatasetDescriptor dataset)
{
    if (dataset == null)
        return;

    m_project.Dataset = dataset;

    string strModelDesc = m_project.ModelDescription;
    if (!string.IsNullOrEmpty(strModelDesc))
    {
        bool bResized;
        string strUpdatedModel = SetDataset(strModelDesc, dataset, out bResized);
        RawProto protoModel = RawProto.Parse(strUpdatedModel);

        EventHandler<OverrideProjectArgs> modelHandler = OnOverrideModel;
        if (modelHandler != null)
        {
            OverrideProjectArgs argsModel = new OverrideProjectArgs(protoModel);
            modelHandler(this, argsModel);
            protoModel = argsModel.Proto;
        }

        ModelDescription = protoModel.ToString();
    }

    string strSolverDesc = m_project.SolverDescription;
    EventHandler<OverrideProjectArgs> solverHandler = OnOverrideSolver;

    // The solver description is only re-written when a handler is present to override it.
    if (!string.IsNullOrEmpty(strSolverDesc) && solverHandler != null)
    {
        OverrideProjectArgs argsSolver = new OverrideProjectArgs(RawProto.Parse(strSolverDesc));
        solverHandler(this, argsSolver);
        SolverDescription = argsSolver.Proto.ToString();
    }
}
1796
/// <summary>
/// Sets the dataset of a model, overriding the current dataset used.
/// </summary>
/// <remarks>
/// Each 'data' layer that has an 'include { phase: ... }' has its data_param source set to the dataset's
/// testing source (phase TEST) or training source (any other phase), and its transform crop_size set to
/// that source's image height.  When both a TEST and a non-TEST batch_size are found, the training batch
/// size is raised to at least the testing batch size.
/// </remarks>
/// <param name="strModelDesc">Specifies the model description (prototxt) to update.</param>
/// <param name="dataset">Specifies the dataset to use.</param>
/// <param name="bResized">Returns whether or not any inner product 'num_output' was changed.</param>
/// <param name="bUpdateOutputs">Optionally, specifies to set the 'num_output' of the last inner product
/// layer preceding each loss layer to the number of training labels (default = false).</param>
/// <returns>The updated model description; the unchanged description when the dataset is named 'MODEL';
/// or null when <paramref name="dataset"/> is null.</returns>
public static string SetDataset(string strModelDesc, DatasetDescriptor dataset, out bool bResized, bool bUpdateOutputs = false)
{
    bResized = false;

    if (dataset == null)
        return null;

    if (dataset.Name == "MODEL")
        return strModelDesc;

    RawProto proto = RawProto.Parse(strModelDesc);
    RawProtoCollection colLayers = proto.FindChildren("layer");

    // Older style models list their layers under 'layers'.
    if (colLayers.Count == 0)
        colLayers = proto.FindChildren("layers");

    List<RawProto> rgLastIp = new List<RawProto>();
    RawProto protoLastIp = null;    // most recent inner product seen since the last loss layer.
    RawProto protoDataTrainBatch = null;
    RawProto protoDataTestBatch = null;

    foreach (RawProto protoChild in colLayers)
    {
        RawProto type = protoChild.FindChild("type");
        if (type == null || type.Value == null)
            continue;

        // Use invariant lower-casing so the comparisons below do not depend on the current culture
        // (e.g. under tr-TR 'I' would otherwise lower to a dotless 'ı').
        string strType = type.Value.ToLowerInvariant();

        if (strType == "data")
        {
            int nCropSize = 0;

            RawProto data_param = protoChild.FindChild("data_param");
            if (data_param != null)
            {
                RawProto batchProto = data_param.FindChild("batch_size");
                RawProto include = protoChild.FindChild("include");
                RawProto phase = (include != null) ? include.FindChild("phase") : null;

                if (phase != null)
                {
                    bool bTest = (phase.Value == "TEST");
                    var src = bTest ? dataset.TestingSource : dataset.TrainingSource;

                    if (bTest)
                        protoDataTestBatch = batchProto;
                    else
                        protoDataTrainBatch = batchProto;

                    RawProto source = data_param.FindChild("source");
                    if (source != null)
                        source.Value = src.Name;
                    else
                        data_param.Children.Add(new RawProto("source", src.Name, null, RawProto.TYPE.STRING));

                    // The layer now reads from 'src' in both cases, so match its image size.
                    nCropSize = src.ImageHeight;
                }
            }

            // Only touch crop_size when a source was matched; otherwise an existing
            // crop_size would be overwritten with 0.
            if (nCropSize > 0)
            {
                RawProto transform_param = protoChild.FindChild("transform_param");
                RawProto crop_size = (transform_param != null) ? transform_param.FindChild("crop_size") : null;

                if (crop_size != null)
                {
                    int nSize = int.Parse(crop_size.Value);

                    if (nCropSize != nSize)
                        crop_size.Value = nCropSize.ToString();
                }
            }
        }
        else if (strType.Contains("loss"))
        {
            // Remember the inner product that feeds this loss (if any).
            if (protoLastIp != null)
            {
                rgLastIp.Add(protoLastIp);
                protoLastIp = null;
            }
        }
        else if (strType == "inner_product" || strType == "innerproduct")
        {
            protoLastIp = protoChild;
        }
    }

    if (protoDataTestBatch != null && protoDataTrainBatch != null)
    {
        int nTestSize = int.Parse(protoDataTestBatch.Value);
        int nTrainSize = int.Parse(protoDataTrainBatch.Value);

        if (nTrainSize < nTestSize)
            protoDataTrainBatch.Value = nTestSize.ToString();
    }

    if (bUpdateOutputs)
    {
        foreach (RawProto lastIp in rgLastIp)
        {
            RawProto protoParam = lastIp.FindChild("inner_product_param");
            if (protoParam == null)
                continue;

            RawProto protoNumOut = protoParam.FindChild("num_output");
            if (protoNumOut == null)
                continue;

            int nNumOut = dataset.TrainingSource.Labels.Count;

            if (nNumOut > 0)
            {
                protoNumOut.Value = nNumOut.ToString();
                bResized = true;
            }
        }
    }

    return proto.ToString();
}
1948
/// <summary>
/// Searches for a given parameter within a given layer, optionally for a certain Phase.
/// </summary>
/// <remarks>
/// Layers are matched on their 'type' (and 'name' when <paramref name="strLayerName"/> is given).  When
/// <paramref name="phaseMatch"/> is not NONE, the first matching layer whose 'include' phase equals it wins;
/// otherwise the first matching layer without an 'include' is used.  Layers missing a 'type' (or a 'name'
/// when a name filter is given) are skipped.
/// </remarks>
/// <param name="strModelDescription">Specifies the model description to search.</param>
/// <param name="strLayerName">Optionally, specifies the layer name to match (null matches any name).</param>
/// <param name="strLayerType">Specifies the layer type to match (exact, case-sensitive).</param>
/// <param name="strParam">Optionally, specifies the parameter child (e.g. 'data_param') to look in.</param>
/// <param name="strField">Specifies the field whose value is returned.</param>
/// <param name="phaseMatch">Optionally, specifies the phase to match.</param>
/// <returns>The field value, or null when no matching layer is found.</returns>
public static string FindLayerParameter(string strModelDescription, string strLayerName, string strLayerType, string strParam, string strField, Phase phaseMatch = Phase.NONE)
{
    RawProto proto = RawProto.Parse(strModelDescription);
    RawProtoCollection rgLayers = proto.FindChildren("layer");
    RawProto firstFound = null;

    foreach (RawProto layer in rgLayers)
    {
        RawProto type = layer.FindChild("type");
        if (type == null || type.Value != strLayerType)
            continue;

        if (strLayerName != null)
        {
            RawProto name = layer.FindChild("name");
            if (name == null || name.Value != strLayerName)
                continue;
        }

        RawProto include = (phaseMatch != Phase.NONE) ? layer.FindChild("include") : null;

        if (include == null)
        {
            // No phase requested (or no include present): fall back to the first match.
            if (firstFound == null)
                firstFound = layer;

            continue;
        }

        // An explicit phase match takes priority over any earlier fallback.
        RawProto phase = include.FindChild("phase");
        if (phase != null && phase.Value == phaseMatch.ToString())
        {
            firstFound = layer;
            break;
        }
    }

    if (firstFound == null)
        return null;

    if (strParam != null)
    {
        RawProto child = firstFound.FindChild(strParam);
        if (child != null)
            firstFound = child;
    }

    return firstFound.FindValue(strField);
}
2019
/// <summary>
/// Disables the testing interval so that no test passes are run.
/// </summary>
/// <remarks>
/// Sets 'test_interval' to 0, sets 'test_initialization' to False and removes 'test_iter' from the
/// solver description.  The solver description is only re-written when one of these actually changed.
/// </remarks>
/// <returns>true when the solver description was changed, otherwise false.</returns>
public bool DisableTesting()
{
    // Force parse the proto if not already parsed.
    if (m_protoSolver == null)
    {
        string strProto = SolverDescription;
        SolverDescription = strProto;
    }

    // NOTE(review): guards the case where the setter above leaves no parsed solver
    // (e.g. an empty solver description) - confirm against the SolverDescription setter.
    if (m_protoSolver == null)
        return false;

    bool bSet = false;
    RawProto protoTestIter = m_protoSolver.FindChild("test_iter");
    RawProto protoTestInterval = m_protoSolver.FindChild("test_interval");
    RawProto protoTestInit = m_protoSolver.FindChild("test_initialization");

    if (protoTestInterval != null && protoTestInterval.Value != "0")
    {
        protoTestInterval.Value = "0";
        bSet = true;
    }

    // Booleans may be written as 'false' (Caffe) or 'False'; treat both as already disabled.
    if (protoTestInit != null && !string.Equals(protoTestInit.Value, "False", StringComparison.OrdinalIgnoreCase))
    {
        protoTestInit.Value = "False";
        bSet = true;
    }

    if (protoTestIter != null)
    {
        m_protoSolver.RemoveChild(protoTestIter);
        bSet = true;
    }

    if (bSet)
        SolverDescription = m_protoSolver.ToString();

    return bSet;
}
2067
/// <summary>
/// Returns a string representation of the Project.
/// </summary>
/// <remarks>
/// When the project has no name, the name is taken from the first 'name:' (or 'Name:') entry of the
/// model description (quoted values may contain spaces); when that also yields nothing, the project ID
/// is shown instead.
/// </remarks>
/// <returns>A string of the form "Project: {name} -> Dataset: {dataset name}".</returns>
public override string ToString()
{
    string strName = Name;

    if (string.IsNullOrEmpty(strName))
    {
        string strModelDesc = ModelDescription;

        if (!string.IsNullOrEmpty(strModelDesc))
        {
            int nPos = strModelDesc.IndexOf("name:", StringComparison.Ordinal);

            if (nPos < 0)
                nPos = strModelDesc.IndexOf("Name:", StringComparison.Ordinal);

            if (nPos >= 0)
            {
                // Skip the 'name:' tag itself and any whitespace that follows it.
                nPos += 5;
                while (nPos < strModelDesc.Length && char.IsWhiteSpace(strModelDesc[nPos]))
                {
                    nPos++;
                }

                int nEnd;
                if (nPos < strModelDesc.Length && strModelDesc[nPos] == '"')
                {
                    // Quoted value: read up to the closing quote.
                    nPos++;
                    nEnd = strModelDesc.IndexOf('"', nPos);
                }
                else
                {
                    nEnd = strModelDesc.IndexOfAny(new char[] { ' ', '\n', '\r' }, nPos);
                }

                if (nEnd < 0)
                    nEnd = strModelDesc.Length;

                // Substring takes a start index and a length (not an end index).
                strName = strModelDesc.Substring(nPos, nEnd - nPos).Trim();
            }
        }

        if (string.IsNullOrEmpty(strName))
            strName = "(ID = " + m_project.ID.ToString() + ")";
    }

    string strDatasetName = (m_project.Dataset != null) ? m_project.Dataset.Name : null;
    return "Project: " + strName + " -> Dataset: " + strDatasetName;
}
2103 }
2104
/// <summary>
/// The PhaseStageCollection class holds a list of unique (Phase, Stage) pairs.
/// </summary>
class PhaseStageCollection
{
    readonly List<PhaseStage> m_rgPhaseStages = new List<PhaseStage>();

    /// <summary>
    /// The constructor, creates an empty collection.
    /// </summary>
    public PhaseStageCollection()
    {
    }

    /// <summary>
    /// Returns the number of pairs in the collection.
    /// </summary>
    public int Count => m_rgPhaseStages.Count;

    /// <summary>
    /// Adds a (Phase, Stage) pair unless an equal pair is already present.
    /// </summary>
    /// <param name="p">Specifies the phase.</param>
    /// <param name="s">Specifies the stage.</param>
    /// <returns>true when the pair was added, false when it already existed.</returns>
    public bool Add(Phase p, Stage s)
    {
        if (Find(p, s) != null)
            return false;

        m_rgPhaseStages.Add(new PhaseStage(p, s));
        return true;
    }

    /// <summary>
    /// Looks up the pair with the given phase and stage.
    /// </summary>
    /// <param name="p">Specifies the phase.</param>
    /// <param name="s">Specifies the stage.</param>
    /// <returns>The matching pair, or null when not present.</returns>
    public PhaseStage Find(Phase p, Stage s)
    {
        return m_rgPhaseStages.FirstOrDefault(item => item.Phase == p && item.Stage == s);
    }

    /// <summary>
    /// Returns a new collection with every pair whose stage equals <paramref name="stage"/>.
    /// </summary>
    /// <param name="stage">Specifies the stage to select.</param>
    /// <returns>The selected pairs, in their original order.</returns>
    public PhaseStageCollection FindAllWith(Stage stage)
    {
        PhaseStageCollection colResult = new PhaseStageCollection();

        foreach (PhaseStage item in m_rgPhaseStages.Where(x => x.Stage == stage))
        {
            colResult.Add(item.Phase, item.Stage);
        }

        return colResult;
    }

    /// <summary>
    /// Returns a new collection with every pair whose phase is one of <paramref name="phase"/>.
    /// </summary>
    /// <param name="phase">Specifies the phases to select.</param>
    /// <returns>The selected pairs, in their original order.</returns>
    public PhaseStageCollection FindAllWith(params Phase[] phase)
    {
        PhaseStageCollection colResult = new PhaseStageCollection();

        foreach (PhaseStage item in m_rgPhaseStages.Where(x => phase.Contains(x.Phase)))
        {
            colResult.Add(item.Phase, item.Stage);
        }

        return colResult;
    }
}
2166
/// <summary>
/// The PhaseStage class is an immutable pairing of a Phase with a Stage.
/// </summary>
class PhaseStage
{
    readonly Phase m_phase;
    readonly Stage m_stage;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="p">Specifies the phase.</param>
    /// <param name="s">Specifies the stage.</param>
    public PhaseStage(Phase p, Stage s)
    {
        m_phase = p;
        m_stage = s;
    }

    /// <summary>
    /// Returns the phase of the pair.
    /// </summary>
    public Phase Phase => m_phase;

    /// <summary>
    /// Returns the stage of the pair.
    /// </summary>
    public Stage Stage => m_stage;
}
2188}
The BaseParameter class is the base class for all other parameter classes.
static bool TryParse(string strVal, out double df)
Parse double values using the US culture if the decimal separator = '.', then using the native cultur...
static double ParseDouble(string strVal)
Parse double values using the US culture if the decimal separator = '.', then using the native cultur...
The OverrideProjectArgs is passed as an argument to the OnOverrideModel and OnOverrideSolver events f...
Definition: EventArgs.cs:180
RawProto Proto
Get/set the RawProto used.
Definition: EventArgs.cs:196
The ProjectEx class manages a project containing the solver description, model description,...
Definition: ProjectEx.cs:15
bool DisableTesting()
Disables the testing interval so that no test passes are run.
Definition: ProjectEx.cs:2024
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Returns the snapshot weight update favor. The snapshot can favor an improving accuracy,...
Definition: ProjectEx.cs:694
GroupDescriptor ModelGroup
Return the model group descriptor of the group that the Project participates in (if any).
Definition: ProjectEx.cs:780
string GpuOverride
Returns the list of comma separated GPU ID's that are to be used when training this Project.
Definition: ProjectEx.cs:650
int TargetDatasetID
Get/set the dataset ID of the target dataset (if exists), otherwise return 0.
Definition: ProjectEx.cs:899
bool HasResults
Return whether or not the project has results from a training session.
Definition: ProjectEx.cs:813
ProjectEx(ProjectDescriptor prj, StateDescriptor state=null, bool bExistTrain=false, bool bExistTest=false, bool bQueryModel=true, bool bQuerySolver=true)
The ProjectEx constructor.
Definition: ProjectEx.cs:57
string? SolverDescription
Get/set the solver description script used by the Project.
Definition: ProjectEx.cs:710
bool EnableRandomSelection
Returns whether or not random image selection is enabled. When enabled, images are randomly selected ...
Definition: ProjectEx.cs:632
string Name
Get/set the name of the Project.
Definition: ProjectEx.cs:508
ValueDescriptorCollection ProjectPerformanceItems
Return Project performance metrics.
Definition: ProjectEx.cs:943
int ID
Returns the ID of the Project in the database.
Definition: ProjectEx.cs:517
string DatasetName
Return the name of the dataset used.
Definition: ProjectEx.cs:866
SNAPSHOT_LOAD_METHOD SnapshotLoadMethod
Returns the snapshot load method. When loading the best error or accuracy, the snapshot loaded may no...
Definition: ProjectEx.cs:702
bool Active
Returns whether or not the Project is active.
Definition: ProjectEx.cs:552
int ImageLoadLimitRefreshPeriod
Returns the image load limit refresh period in milliseconds.
Definition: ProjectEx.cs:675
bool UseTrainingSourceForTesting
Returns whether or not the Project uses the training data source when testing (default = false).
Definition: ProjectEx.cs:605
int Iterations
Get/set the current number of iterations that the Project has been trained.
Definition: ProjectEx.cs:821
int TotalIterations
Get/set the total number of iterations that the Project has been trained.
Definition: ProjectEx.cs:804
bool ExistTrainResults
Return whether or not training results exist.
Definition: ProjectEx.cs:935
SettingsCaffe Settings
Get/set the Caffe setting to use with the Project.
Definition: ProjectEx.cs:499
static string FindLayerParameter(string strModelDescription, string strLayerName, string strLayerType, string strParam, string strField, Phase phaseMatch=Phase.NONE)
This method searches for a given parameter within a given layer, optionally for a certain Phase.
Definition: ProjectEx.cs:1962
string ModelName
Return the name of the model used by the Project.
Definition: ProjectEx.cs:951
bool RequiresDataCriteria()
Returns whether or not the data criteria is required by the current project model (e....
Definition: ProjectEx.cs:84
ProjectEx(string strName, string strDsName=null)
The ProjectEx constructor.
Definition: ProjectEx.cs:41
IMAGEDB_LOAD_METHOD ImageLoadMethod
Returns the method used to load the images into memory. Loading all images into memory has the highes...
Definition: ProjectEx.cs:659
double ImageLoadLimitRefreshPercent
Returns the image load limit refresh percentage (to update).
Definition: ProjectEx.cs:683
double? GetSolverSettingAsNumeric(string strParam)
Get a setting from the solver descriptor as a double value.
Definition: ProjectEx.cs:454
ParameterDescriptorCollection Parameters
Returns any project parameters that may exist (if any).
Definition: ProjectEx.cs:796
void SetDataset(DatasetDescriptor dataset)
Sets the dataset used by the Project, overriding the current dataset used.
Definition: ProjectEx.cs:1760
DatasetDescriptor Dataset
Return the descriptor of the dataset used.
Definition: ProjectEx.cs:880
static RawProto CreateModelForRunning(string strModelDescription, string strName, int nNum, int nChannels, int nHeight, int nWidth, out RawProto protoTransform, Stage stage=Stage.NONE, bool bSkipLossLayer=false)
Create a model description as a RawProto for running the Project.
Definition: ProjectEx.cs:1291
TRAINING_CATEGORY TrainingCategory
Returns the training category of the project, or NONE if no custom trainer is used.
Definition: ProjectEx.cs:560
int? GetSolverSettingAsInt(string strParam)
Get a setting from the solver descriptor as an integer value.
Definition: ProjectEx.cs:472
string GetCustomTrainer(out string strProperties)
Returns the custom trainer and properties used by the project (if any).
Definition: ProjectEx.cs:312
void LoadSolverFile(string strFile)
Load the solver description from a file.
Definition: ProjectEx.cs:992
int ImageLoadLimit
Returns the image load limit.
Definition: ProjectEx.cs:667
bool ExistTestResults
Return whether or not testing results exist.
Definition: ProjectEx.cs:927
string SolverType
Return the type of the Solver used by the Project.
Definition: ProjectEx.cs:959
string? ModelDescription
Get/set the model description script used by the Project.
Definition: ProjectEx.cs:741
GroupDescriptor ProjectGroup
Return the project group descriptor of the group that the Project resides (if any).
Definition: ProjectEx.cs:772
EventHandler< OverrideProjectArgs > OnOverrideSolver
The OverrideSolver event fires each time the SetDataset function is called.
Definition: ProjectEx.cs:34
double? GetLayerSetting(Phase phase, string strLayer, string strParam)
Returns the setting of a Layer (if found).
Definition: ProjectEx.cs:402
byte[] WeightsState
Get/set the weight state.
Definition: ProjectEx.cs:857
override string ToString()
Returns a string representation of the Project.
Definition: ProjectEx.cs:2072
int GetBatchSize(Phase phase)
Returns the batch size of the project used in a given Phase.
Definition: ProjectEx.cs:359
bool EnablePairSelection
Returns whether or not pair selection is enabled. When using pair selection, images are queried in pa...
Definition: ProjectEx.cs:642
RawProto CreateModelForRunning(string strName, int nNum, int nChannels, int nHeight, int nWidth, out RawProto protoTransform, Stage stage=Stage.NONE, bool bSkipLossLayer=false)
Create a model description as a RawProto for running the Project.
Definition: ProjectEx.cs:1024
string GetSolverSetting(string strParam)
Get a setting from the solver descriptor.
Definition: ProjectEx.cs:437
bool DatasetAdjusted
Get/set whether or not the dataset for the project has been changed.
Definition: ProjectEx.cs:301
int OriginalID
Get/set the original project ID.
Definition: ProjectEx.cs:525
DatasetDescriptor DatasetTarget
Returns the target dataset (if exists) or null if it does not.
Definition: ProjectEx.cs:891
string Owner
Get/set the ID of the Project owner.
Definition: ProjectEx.cs:543
double BestAccuracy
Get/set the best accuracy observed while testing the Project.
Definition: ProjectEx.cs:830
void LoadModelFile(string strFile)
Load the model description from a file.
Definition: ProjectEx.cs:1004
double BestError
Get/set the best error observed while training the Project.
Definition: ProjectEx.cs:839
bool SetSolverVariable(string strVar, string strVal)
Set a given Solver variable in the solver description script.
Definition: ProjectEx.cs:969
bool EnableLabelBalancing
Returns whether or not label balancing is enabled. When enabled, first the label set is randomly sele...
Definition: ProjectEx.cs:614
GroupDescriptor DatasetGroup
Return the dataset group descriptor of the group that the Project participates in (if any).
Definition: ProjectEx.cs:788
bool EnableLabelBoosting
Returns whether or not label boosting is enabled. When using Label boosting, images are selected from...
Definition: ProjectEx.cs:623
static RawProto CreateModelForTraining(string strModelDescription, string strName, bool bCaffeFormat=false)
Create a model description as a RawProto for training the Project.
Definition: ProjectEx.cs:1036
byte[] SolverState
Get/set the solver state.
Definition: ProjectEx.cs:848
double SuperBoostProbability
Get/set the super boost probability used by the Project.
Definition: ProjectEx.cs:596
Stage Stage
Return the stage under which the project was opened.
Definition: ProjectEx.cs:587
EventHandler< OverrideProjectArgs > OnOverrideModel
The OverrideModel event fires each time the SetDataset function is called.
Definition: ProjectEx.cs:30
static string SetDataset(string strModelDesc, DatasetDescriptor dataset, out bool bResized, bool bUpdateOutputs=false)
Sets the dataset of a model, overriding the current dataset used.
Definition: ProjectEx.cs:1807
bool? GetSolverSettingAsBool(string strParam)
Get a setting from the solver descriptor as a boolean value.
Definition: ProjectEx.cs:486
The RawProtoCollection class is a list of RawProto objects.
bool Remove(RawProto p)
Removes a RawProto from the collection.
void RemoveAt(int nIdx)
Removes the RawProto at a given index in the collection.
void Add(RawProto p)
Adds a RawProto to the collection.
void Insert(int nIdx, RawProto p)
Inserts a new RawProto into the collection at a given index.
int Count
Returns the number of items in the collection.
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
TYPE
Defines the type of a RawProto node.
Definition: RawProto.cs:27
string Name
Returns the name of the node.
Definition: RawProto.cs:71
RawProtoCollection Children
Returns a collection of this node's child nodes.
Definition: RawProto.cs:96
string Value
Get/set the value of the node.
Definition: RawProto.cs:79
RawProto FindChild(string strName)
Searches for a given node.
Definition: RawProto.cs:231
override string ToString()
Returns the RawProto as its full prototxt string.
Definition: RawProto.cs:677
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
int FindChildIndex(string strName)
Searches for the index to a given node's child.
Definition: RawProto.cs:247
string FindValue(string strName)
Searches for a value of a node within this node's children.
Definition: RawProto.cs:105
bool RemoveChild(RawProto p)
Removes a given child from this node's children.
Definition: RawProto.cs:188
RawProtoCollection FindChildren(params string[] rgstrName)
Searches for all children with a given name in this node's children.
Definition: RawProto.cs:263
The SettingsCaffe defines the settings used by the MyCaffe CaffeControl.
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot update method.
bool EnableRandomInputSelection
Get/set random image selection. When enabled, images are randomly selected from the entire set,...
SNAPSHOT_LOAD_METHOD SnapshotLoadMethod
Get/set the snapshot load method.
int ImageDbAutoRefreshScheduledUpdateInMs
Get/set the automatic refresh scheduled update period (default = 10000, only applies when ImageDbLoad...
bool EnableLabelBoosting
DEPRECATED: Get/set label boosting. When using Label boosting, images are selected from boosted labe...
double SuperBoostProbability
Get/set the superboost probability used when selecting boosted images.
bool EnablePairInputSelection
Get/set pair image selection. When using pair selection, images are queried in pairs where the first ...
IMAGEDB_LOAD_METHOD ImageDbLoadMethod
Get/set the image database loading method.
double ImageDbAutoRefreshScheduledReplacementPercent
Get/set the automatic refresh scheduled update replacement percentage used on refresh (default = 0....
int ImageDbLoadLimit
Get/set the image database load limit.
bool EnableLabelBalancing
Get/set label balancing. When enabled, first the label set is randomly selected and then the image is...
string Owner
Get/set the owner of the item.
int ID
Get/set the database ID of the item.
string Name
Get/set the name of the item.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
GroupDescriptor DatasetGroup
Returns the dataset group.
SourceDescriptor TrainingSource
Get/set the training data source.
GroupDescriptor ModelGroup
Get/set the dataset model group.
string? TrainingSourceName
Returns the training source name, or null if not specified.
SourceDescriptor TestingSource
Get/set the testing data source.
string? TestingSourceName
Returns the testing source name or null if not specified.
DatasetDescriptor(int nID, string strName, GroupDescriptor grpModel, GroupDescriptor grpDs, SourceDescriptor srcTrain, SourceDescriptor srcTest, string strCreatorName, string strDescription, string strOwner=null, GYM_TYPE gymType=GYM_TYPE.NONE)
The DatasetDescriptor constructor.
The GroupDescriptor class defines a group.
The ParameterDescriptorCollection class contains a list of ParameterDescriptor's.
void Add(ParameterDescriptor p)
Adds a ParameterDescriptor to the collection.
ParameterDescriptor Find(string strName)
Searches for a parameter by name in the collection.
The ParameterDescriptor class describes a parameter in the database.
string Value
Get/set the value of the item.
The ProjectDescriptor class contains all information describing a project, such as its: dataset,...
ValueDescriptorCollection AnalysisItems
Returns the collection of analysis ValueDescriptors of the Project.
string SolverDescription
Get/set the solver description script.
GroupDescriptor Group
Get/set the project group.
string ModelDescription
Get/set the model description script.
DatasetDescriptor DatasetTarget
Get/set the secondary 'target' dataset (if used).
int TotalIterations
Get/set the total iterations.
virtual string GpuOverride
Get/set the GPU ID's to use as an override.
SettingsCaffe Settings
Get/set the settings of the Project.
bool Active
Returns whether or not the project is active.
ParameterDescriptorCollection Parameters
Returns the collection of parameters of the Project.
DatasetDescriptor Dataset
Get/set the dataset used.
The SourceDescriptor class contains all information describing a data source.
bool SaveImagesToFile
Gets whether or not the images are saved to the file system (true), or directly to the database (fals...
The StateDescriptor class contains the information related to the state of a project including the sol...
double Accuracy
Returns the accuracy observed while testing.
byte[] State
Get/set the state of a Solver in training.
bool? HasResults
Returns whether or not the state has results (e.g. it has been trained at least to some degree).
byte[] Weights
Get/set the weights of a trained Net.
double Error
Specifies the error observed while training.
int Iterations
Specifies the number of iterations run.
The ValueDescriptorCollection class contains a list of ValueDescriptor's.
The descriptors namespace contains all descriptor used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:42
SNAPSHOT_LOAD_METHOD
Defines the snapshot load method.
Definition: Interfaces.cs:185
IMAGEDB_LOAD_METHOD
Defines how to load the images into the image database.
Definition: Interfaces.cs:135
SNAPSHOT_WEIGHT_UPDATE_METHOD
Defines the snapshot weight update method.
Definition: Interfaces.cs:162
TRAINING_CATEGORY
Defines the category of training.
Definition: Interfaces.cs:15
Stage
Specifies the stage under which to run a custom trainer.
Definition: Interfaces.cs:69
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12