MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MyCaffeControl.cs
1using System;
2using System.Collections.Generic;
4using System.Diagnostics;
5using System.Linq;
6using System.Text;
7using System.Threading;
8using System.Drawing;
9using System.Threading.Tasks;
10using System.IO;
11using MyCaffe.basecode;
13using MyCaffe.db.image;
14using MyCaffe.solvers;
15using MyCaffe.common;
16using MyCaffe.param;
17using MyCaffe.data;
18using MyCaffe.layers;
19using System.Globalization;
20using System.Reflection;
21using System.Security.Cryptography;
22using System.Net;
24
28namespace MyCaffe
29{
34 public partial class MyCaffeControl<T> : Component, IXMyCaffeState<T>, IXMyCaffe<T>, IXMyCaffeNoDb<T>, IXMyCaffeExtension<T>, IDisposable
35 {
43 protected Log m_log;
47 protected IXDatabaseBase m_db = null;
51 protected bool m_bDbOwner = true;
59 protected AutoResetEvent m_evtForceSnapshot;
63 protected AutoResetEvent m_evtForceTest;
67 protected ManualResetEvent m_evtPause;
75 protected ProjectEx m_project = null;
79 protected DatasetDescriptor m_dataSet = null;
83 protected string m_strCudaPath = null;
87 protected List<int> m_rgGpu;
88 CudaDnn<T> m_cuda;
89 Solver<T> m_solver;
90 Net<T> m_net;
91 bool m_bOwnRunNet = true;
92 MemoryStream m_msWeights = new MemoryStream();
93 Guid m_guidUser;
94 PersistCaffe<T> m_persist;
95 BlobShape m_inputShape = null;
96 Phase m_lastPhaseRun = Phase.NONE;
97 long m_hCopyBuffer = 0;
98 string m_strStage = null;
99 bool m_bLoadLite = false;
100 string m_strSolver = null; // Used with LoadLite.
101 string m_strModel = null; // Used with LoadLite.
102 ManualResetEvent m_evtSyncUnload = new ManualResetEvent(false);
103 ManualResetEvent m_evtSyncMain = new ManualResetEvent(false);
104 ConnectInfo m_dsCi = null;
105 bool m_bEnableVerboseStatus = false;
106 T[] m_rgRunData = null;
107 BlobShape m_loadToRunShape = null;
108
112 public event EventHandler<SnapshotArgs> OnSnapshot;
116 public event EventHandler<TrainingIterationArgs<T>> OnTrainingIteration;
120 public event EventHandler<TestingIterationArgs<T>> OnTestingIteration;
121
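The three events above surface solver callbacks to the host application. A minimal wiring sketch follows; it assumes an already-constructed MyCaffeControl<float> instance named 'mycaffe', and the handler bodies are illustrative only.

// Hedged example: hook the control's events before training starts.
mycaffe.OnTrainingIteration += (sender, args) =>
{
    // Inspect or log per-iteration training progress here.
};
mycaffe.OnTestingIteration += (sender, args) =>
{
    // Inspect or log the results of each testing pass here.
};
mycaffe.OnSnapshot += (sender, args) =>
{
    // Persist the snapshot state carried by 'args' here.
};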
122
136 public MyCaffeControl(SettingsCaffe settings, Log log, CancelEvent evtCancel, AutoResetEvent evtSnapshot = null, AutoResetEvent evtForceTest = null, ManualResetEvent evtPause = null, List<int> rgGpuId = null, string strCudaPath = "", bool bCreateCudaDnn = false, ConnectInfo ci = null)
137 {
138 m_dsCi = ci;
139 m_guidUser = Guid.NewGuid();
140
141 InitializeComponent();
142
143 if (evtCancel == null)
144 throw new ArgumentNullException("The cancel event must be specified!");
145
146 if (evtSnapshot == null)
147 evtSnapshot = new AutoResetEvent(false);
148
149 if (evtForceTest == null)
150 evtForceTest = new AutoResetEvent(false);
151
152 if (evtPause == null)
153 evtPause = new ManualResetEvent(false);
154
155 m_log = log;
156 m_settings = settings;
157 m_evtCancel = evtCancel;
158 m_evtForceSnapshot = evtSnapshot;
159 m_evtForceTest = evtForceTest;
160 m_evtPause = evtPause;
161
162 if (rgGpuId == null)
163 {
164 m_rgGpu = new List<int>();
165 string[] rgstrGpuId = settings.GpuIds.Split(',');
166
167 foreach (string str in rgstrGpuId)
168 {
169 string strGpuId = str.Trim(' ', '\t', '\n', '\r');
170 m_rgGpu.Add(int.Parse(strGpuId));
171 }
172 }
173 else
174 {
175 m_rgGpu = Utility.Clone<int>(rgGpuId);
176 }
177
178 if (m_rgGpu.Count == 0)
179 m_rgGpu.Add(0);
180
181 m_strCudaPath = strCudaPath;
182 m_persist = new common.PersistCaffe<T>(m_log, false);
183
184 if (bCreateCudaDnn)
185 m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, false);
186 }
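A minimal construction sketch based on the constructor above (a sketch only: the log name and GPU id are placeholders, and a double-precision instance would use MyCaffeControl<double> instead):

// Hedged example: construct the control with default settings on GPU 0.
// SettingsCaffe, Log and CancelEvent come from the MyCaffe.basecode/MyCaffe.common assemblies.
SettingsCaffe settings = new SettingsCaffe();
settings.GpuIds = "0";                      // comma-separated GPU ids, parsed by the constructor above.
Log log = new Log("MyCaffe");               // the log name is a placeholder.
CancelEvent evtCancel = new CancelEvent();  // required - the constructor throws if this is null.

MyCaffeControl<float> mycaffe = new MyCaffeControl<float>(settings, log, evtCancel);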
187
191 public void dispose()
192 {
193 if (m_evtSyncMain.WaitOne(0))
194 return;
195
196 m_evtSyncMain.Set();
197
198 try
199 {
200 if (m_evtCancel != null)
201 m_evtCancel.Set();
202
203 if (m_hCopyBuffer != 0)
204 {
205 try
206 {
207 m_cuda.FreeHostBuffer(m_hCopyBuffer);
208 }
209 catch
210 {
211 }
212
213 m_hCopyBuffer = 0;
214 }
215
216 Unload(true, true);
217
218 if (m_cuda != null)
219 {
220 try
221 {
222 m_cuda.Dispose();
223 }
224 catch
225 {
226 }
227
228 m_cuda = null;
229 }
230
231 if (m_msWeights != null)
232 {
233 m_msWeights.Dispose();
234 m_msWeights = null;
235 }
236
237 if (m_dataTransformer != null)
238 {
239 m_dataTransformer.Dispose();
240 m_dataTransformer = null;
241 }
242 }
243 finally
244 {
245 m_evtSyncMain.Reset();
246 }
247 }
248
252 public static FileVersionInfo Version
253 {
254 get
255 {
256 string strLocation = Assembly.GetExecutingAssembly().Location;
257 return FileVersionInfo.GetVersionInfo(strLocation);
258 }
259 }
260
264 public ConnectInfo DatasetConnectInfo
265 {
266 get { return m_dsCi; }
267 }
268
272 public string CurrentStage
273 {
274 get { return m_strStage; }
275 }
276
285 public MyCaffeControl<T> Clone(int nGpuID)
286 {
287 SettingsCaffe s = m_settings.Clone();
288 s.GpuIds = nGpuID.ToString();
289
290 MyCaffeControl<T> mycaffe = new MyCaffeControl<T>(s, m_log, m_evtCancel, null, null, null, null, m_strCudaPath);
291
292 if (m_bLoadLite)
293 mycaffe.LoadLite(Phase.TRAIN, m_strSolver, m_strModel, null);
294 else
295 mycaffe.Load(Phase.TRAIN, m_project, null, null, false, m_db, (m_db == null) ? false : true, true, m_strStage);
296
297 Net<T> netSrc = GetInternalNet(Phase.TRAIN);
298 Net<T> netDst = mycaffe.GetInternalNet(Phase.TRAIN);
299
300 m_log.CHECK_EQ(netSrc.learnable_parameters.Count, netDst.learnable_parameters.Count, "The src and dst networks do not have the same number of learnable parameters!");
301
302 for (int i = 0; i < netSrc.learnable_parameters.Count; i++)
303 {
304 Blob<T> bSrc = netSrc.learnable_parameters[i];
305 Blob<T> bDst = netDst.learnable_parameters[i];
306
307 mycaffe.m_hCopyBuffer = bDst.CopyFrom(bSrc, false, false, mycaffe.m_hCopyBuffer);
308 }
309
310 return mycaffe;
311 }
312
317 public void CopyGradientsFrom(MyCaffeControl<T> src)
318 {
319 Net<T> netSrc = src.GetInternalNet(Phase.TRAIN);
320 Net<T> netDst = GetInternalNet(Phase.TRAIN);
321
322 m_log.CHECK_EQ(netSrc.learnable_parameters.Count, netDst.learnable_parameters.Count, "The src and dst networks do not have the same number of learnable parameters!");
323
324 for (int i = 0; i < netSrc.learnable_parameters.Count; i++)
325 {
326 Blob<T> bSrc = netSrc.learnable_parameters[i];
327 Blob<T> bDst = netDst.learnable_parameters[i];
328
329 m_hCopyBuffer = bDst.CopyFrom(bSrc, true, false, m_hCopyBuffer);
330 }
331 }
332
337 public void CopyWeightsFrom(MyCaffeControl<T> src)
338 {
339 Net<T> netSrc = src.GetInternalNet(Phase.TRAIN);
340 Net<T> netDst = GetInternalNet(Phase.TRAIN);
341
342 m_log.CHECK_EQ(netSrc.learnable_parameters.Count, netDst.learnable_parameters.Count, "The src and dst networks do not have the same number of learnable parameters!");
343
344 for (int i = 0; i < netSrc.learnable_parameters.Count; i++)
345 {
346 Blob<T> bSrc = netSrc.learnable_parameters[i];
347 Blob<T> bDst = netDst.learnable_parameters[i];
348
349 m_hCopyBuffer = bDst.CopyFrom(bSrc, false, false, m_hCopyBuffer);
350 }
351 }
352
359 public double ApplyUpdate(int nIteration)
360 {
361 return m_solver.ApplyUpdate(nIteration);
362 }
363
367 public bool EnableTesting
368 {
369 get { return m_solver.EnableTesting; }
370 set { m_solver.EnableTesting = value; }
371 }
372
376 public bool EnableVerboseStatus
377 {
378 get { return m_bEnableVerboseStatus; }
379 set { m_bEnableVerboseStatus = value; }
380 }
381
387 public void Unload(bool bUnloadImageDb = true, bool bIgnoreExceptions = false)
388 {
389 if (m_solver == null && m_net == null)
390 return;
391
392 if (m_evtSyncUnload.WaitOne(0))
393 return;
394
395 m_evtSyncUnload.Set();
396
397 try
398 {
399 if (m_solver != null)
400 {
401 m_solver.Dispose();
402 m_solver = null;
403 }
404
405 if (m_net != null)
406 {
407 if (m_bOwnRunNet)
408 m_net.Dispose();
409 m_net = null;
410 }
411
412 if (m_db != null && bUnloadImageDb)
413 {
414 if (m_bDbOwner)
415 {
416 if (m_dataSet != null)
418
419 IDisposable idisp = m_db as IDisposable;
420 if (idisp != null)
421 idisp.Dispose();
422 }
423
424 m_db = null;
425 }
426
427 m_project = null;
428 }
429 catch (Exception excpt)
430 {
431 if (!bIgnoreExceptions)
432 throw excpt;
433 }
434 finally
435 {
436 m_evtSyncUnload.Reset();
437 }
438 }
439
448 public bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
449 {
450 Net<T> net = GetInternalNet(Phase.TRAIN);
451 net.ReInitializeParameters(target, rgstrLayers);
452 return m_solver.ForceOnTrainingIterationEvent();
453 }
454
459 public void SetOnTestOverride(EventHandler<TestArgs> onTest)
460 {
461 m_solver.OnTest += onTest;
462 }
463
468 public void SetOnTrainingStartOverride(EventHandler onTrainingStart)
469 {
470 m_solver.OnStart += onTrainingStart;
471 }
472
477 public void SetOnTestingStartOverride(EventHandler onTestingStart)
478 {
479 m_solver.OnTestStart += onTestingStart;
480 }
481
482
487 public void AddCancelOverrideByName(string strEvtCancel)
488 {
489 m_evtCancel.AddCancelOverride(strEvtCancel);
490 }
491
496 public void AddCancelOverride(CancelEvent evtCancel)
497 {
499 }
500
505 public void RemoveCancelOverrideByName(string strEvtCancel)
506 {
507 m_evtCancel.RemoveCancelOverride(strEvtCancel);
508 }
509
514 public void RemoveCancelOverride(CancelEvent evtCancel)
515 {
517 }
518
525 public bool EnableBlobDebugging
526 {
527 get { return (m_solver == null) ? false : m_solver.EnableBlobDebugging; }
528 set
529 {
530 if (m_solver != null)
531 m_solver.EnableBlobDebugging = value;
532 }
533 }
534
541 public bool EnableBreakOnFirstNaN
542 {
543 get { return (m_solver == null) ? false : m_solver.EnableBreakOnFirstNaN; }
544 set
545 {
546 if (m_solver != null)
547 m_solver.EnableBreakOnFirstNaN = value;
548 }
549 }
550
554 public bool EnableDetailedNanDetection
555 {
556 get { return (m_solver == null) ? false : m_solver.EnableDetailedNanDetection; }
557 set
558 {
559 if (m_solver != null)
560 m_solver.EnableDetailedNanDetection = value;
561 }
562 }
563
570 public bool EnableLayerDebugging
571 {
572 get { return (m_solver == null) ? false : m_solver.EnableLayerDebugging; }
573 set
574 {
575 if (m_solver != null)
576 m_solver.EnableLayerDebugging = value;
577 }
578 }
579
586 public bool EnableSingleStep
587 {
588 get { return (m_solver == null) ? false : m_solver.EnableSingleStep; }
589 set
590 {
591 if (m_solver != null)
592 m_solver.EnableSingleStep = value;
593 }
594 }
595
600 {
601 get { return m_dataTransformer; }
602 }
603
608 {
609 get { return m_settings; }
610 }
611
615 public CudaDnn<T> Cuda
616 {
617 get { return m_cuda; }
618 }
619
623 public Log Log
624 {
625 get { return m_log; }
626 }
627
632 {
633 get { return m_persist; }
634 }
635
640 {
641 get { return m_db; }
642 }
643
648 {
649 get { return m_evtCancel; }
650 }
651
655 public List<int> ActiveGpus
656 {
657 get { return m_rgGpu; }
658 }
659
666 public string ActiveLabelCounts
667 {
668 get { return m_solver.ActiveLabelCounts; }
669 }
670
677 public string LabelQueryHitPercents
678 {
679 get { return m_solver.LabelQueryHitPercents; }
680 }
681
688 public string LabelQueryEpochs
689 {
690 get { return m_solver.LabelQueryEpochs; }
691 }
692
696 public string CurrentDevice
697 {
698 get
699 {
700 int nId = m_cuda.GetDeviceID();
701 return "GPU #" + nId.ToString() + " " + m_cuda.GetDeviceName(nId);
702 }
703 }
704
709 {
710 get { return m_project; }
711 }
712
716 public int CurrentIteration
717 {
718 get { return m_solver.CurrentIteration; }
719 }
720
724 public int MaximumIteration
725 {
726 get { return m_solver.MaximumIteration; }
727 }
728
733 public int GetDeviceCount()
734 {
735 return m_cuda.GetDeviceCount();
736 }
737
743 public string GetDeviceName(int nDeviceID)
744 {
745 return m_cuda.GetDeviceName(nDeviceID);
746 }
747
752 {
753 get { return m_lastPhaseRun; }
754 }
755
765 protected NetParameter createNetParameterForRunning(ProjectEx p, out TransformationParameter transform_param)
766 {
767 return createNetParameterForRunning(p.Dataset, p.ModelDescription, out transform_param, p.Stage);
768 }
769
781 protected NetParameter createNetParameterForRunning(DatasetDescriptor ds, string strModel, out TransformationParameter transform_param, Stage stage = Stage.NONE)
782 {
783 BlobShape shape = datasetToShape(ds);
784 return createNetParameterForRunning(shape, strModel, out transform_param, stage);
785 }
786
798 protected NetParameter createNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage = Stage.NONE)
799 {
800 NetParameter param = CreateNetParameterForRunning(shape, strModel, out transform_param, stage);
801 m_inputShape = shape;
802 return param;
803 }
804
819 protected NetParameter createNetParameterForRunning(SimpleDatum sdMean, string strModel, out TransformationParameter transform_param, out int nC, out int nH, out int nW, Stage stage = Stage.NONE)
820 {
821 nC = 0;
822 nH = 0;
823 nW = 0;
824
825 if (sdMean != null)
826 {
827 nC = sdMean.Channels;
828 nH = sdMean.Height;
829 nW = sdMean.Width;
830 }
831 else
832 {
833 RawProto protoModel = RawProto.Parse(strModel);
834 RawProtoCollection layers = protoModel.FindChildren("layer");
835
836 foreach (RawProto layer in layers)
837 {
838 RawProto type = layer.FindChild("type");
839 if (type != null && type.Value == "Input")
840 {
841 RawProto input_param = layer.FindChild("input_param");
842 if (input_param != null)
843 {
844 RawProto shape1 = input_param.FindChild("shape");
845 if (shape1 != null)
846 {
847 RawProtoCollection rgDim = shape1.FindChildren("dim");
848 int nNum = (rgDim.Count > 0) ? int.Parse(rgDim[0].Value) : 1;
849 nC = (rgDim.Count > 1) ? int.Parse(rgDim[1].Value) : 1;
850 nH = (rgDim.Count > 2) ? int.Parse(rgDim[2].Value) : 1;
851 nW = (rgDim.Count > 3) ? int.Parse(rgDim[3].Value) : 1;
852 break;
853 }
854 }
855 }
856 }
857 }
858
859 if (nC == 0 && nH == 0 && nW == 0)
860 throw new Exception("Could not dicern the shape to use for no 'sdMean' parameter was supplied and the model does not contain an 'Input' layer!");
861
862 BlobShape shape = new BlobShape(1, nC, nH, nW);
863 NetParameter param = CreateNetParameterForRunning(shape, strModel, out transform_param, stage);
864 m_inputShape = shape;
865
866 return param;
867 }
868
882 public static NetParameter CreateNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage = Stage.NONE, bool bSkipLossLayer = false, bool bMaintainBatchSize = false)
883 {
884 int nNum = (bMaintainBatchSize) ? shape.dim[0] : 1;
885 int nImageChannels = shape.dim[1];
886 int nImageHeight = (shape.dim.Count > 2) ? shape.dim[2] : 1;
887 int nImageWidth = (shape.dim.Count > 3) ? shape.dim[3] : 1;
888
889 transform_param = null;
890
891 RawProto protoTransform = null;
892 bool bSkipTransformParam = false;
893 RawProto protoModel = ProjectEx.CreateModelForRunning(strModel, "data", nNum, nImageChannels, nImageHeight, nImageWidth, out protoTransform, out bSkipTransformParam, stage, bSkipLossLayer);
894
895 if (!bSkipTransformParam)
896 {
897 if (protoTransform != null)
898 transform_param = TransformationParameter.FromProto(protoTransform);
899 else
900 transform_param = new param.TransformationParameter();
901
902 if (transform_param.resize_param != null && transform_param.resize_param.Active)
903 {
904 shape.dim[2] = (int)transform_param.resize_param.height;
905 shape.dim[3] = (int)transform_param.resize_param.width;
906 }
907 }
908
909 NetParameter np = NetParameter.FromProto(protoModel);
910
911 string strInput = "";
912 foreach (LayerParameter layer in np.layer)
913 {
914 layer.PrepareRunModel();
915
916 string strInput1 = layer.PrepareRunModelInputs();
917 if (!string.IsNullOrEmpty(strInput1))
918 strInput += strInput1;
919 }
920
921 if (!string.IsNullOrEmpty(strInput))
922 {
923 RawProto proto = RawProto.Parse(strInput);
924 Dictionary<string, BlobShape> rgInput = NetParameter.InputFromProto(proto);
925
926 if (rgInput.Count > 0)
927 {
928 np.input = new List<string>();
929 np.input_dim = new List<int>();
930 np.input_shape = new List<BlobShape>();
931
932 foreach (KeyValuePair<string, BlobShape> kv in rgInput)
933 {
934 np.input.Add(kv.Key);
935 np.input_shape.Add(kv.Value);
936 }
937 }
938 }
939
940
941 np.ProjectID = 0;
942 np.state.phase = Phase.RUN;
943
944 return np;
945 }
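A rough usage sketch of the static helper above (the model descriptor string 'strModel' and the 1x1x28x28 input shape are placeholder assumptions):

// Hedged example: convert a training model description into a RUN-phase NetParameter.
// 'strModel' is assumed to hold a valid model prototxt descriptor.
BlobShape shapeInput = new BlobShape(1, 1, 28, 28);
TransformationParameter transformParam;
NetParameter netParam = MyCaffeControl<float>.CreateNetParameterForRunning(shapeInput, strModel, out transformParam);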
946
947 private BlobShape datasetToShape(DatasetDescriptor ds)
948 {
949 int nH = 1;
950 int nW = 1;
951 int nC = 1;
952
953 if (!ds.IsModelData)
954 {
955 nH = ds.TestingSource.Height;
956 nW = ds.TestingSource.Width;
957 nC = ds.TestingSource.Channels;
958 }
959
960 List<int> rgShape = new List<int>() { 1, nC, nH, nW };
961 return new BlobShape(rgShape);
962 }
963
964 private Stage getStage(string strStage)
965 {
966 if (strStage == Stage.RNN.ToString())
967 return Stage.RNN;
968
969 if (strStage == Stage.RL.ToString())
970 return Stage.RL;
971
972 return Stage.NONE;
973 }
974
975 private string addStage(string strModel, Phase phase, string strStage)
976 {
977 if (string.IsNullOrEmpty(strStage))
978 return strModel;
979
980 RawProto proto = RawProto.Parse(strModel);
981 NetParameter param = NetParameter.FromProto(proto);
982
983 param.state.stage.Clear();
984 param.state.phase = phase;
985 param.state.stage.Add(strStage);
986
987 return param.ToProto("root", true).ToString();
988 }
989
995 {
996 if (prj.Dataset.IsModelData || prj.Dataset.IsGym)
997 return;
998
999 DatasetFactory factory = new DatasetFactory();
1000
1001 // Copy the training image mean to the testing source if it does not have a mean.
1002 // NOTE: This will not impact a service-based image database that is already loaded;
1003 // - it must be reloaded.
1004 int nDstID = factory.GetRawImageMeanID(prj.Dataset.TestingSource.ID);
1005 if (nDstID == 0)
1006 {
1007 int nSrcID = factory.GetRawImageMeanID(prj.Dataset.TrainingSource.ID);
1008 if (nSrcID != 0)
1010 }
1011
1012 if (prj.DatasetTarget != null)
1013 {
1014 // Copy the training image mean to the testing source if it does not have a mean.
1015 // NOTE: This will not impact a service-based image database that is already loaded;
1016 // - it must be reloaded.
1017 nDstID = factory.GetRawImageMeanID(prj.DatasetTarget.TestingSource.ID);
1018 if (nDstID == 0)
1019 {
1020 int nSrcID = factory.GetRawImageMeanID(prj.DatasetTarget.TrainingSource.ID);
1021 if (nSrcID != 0)
1023 }
1024 }
1025 }
1026
1027 private bool verifySharedWeights()
1028 {
1029 Net<T> netTest = m_solver.TestingNet;
1030 if (netTest != null)
1031 {
1032 Net<T> netTrain = m_solver.TrainingNet;
1033
1034 if (netTrain.parameters.Count != netTest.parameters.Count)
1035 {
1036 m_log.WriteLine("WARNING: Training net has a different number of parameters than the testing net!");
1037 return false;
1038 }
1039
1040 for (int i = 0; i < netTrain.parameters.Count; i++)
1041 {
1042 if (netTrain.parameters[i].gpu_data != netTest.parameters[i].gpu_data)
1043 {
1044 m_log.WriteLine("WARNING: Training net parameter " + i.ToString() + " is not shared with the testing net!");
1045 return false;
1046 }
1047 }
1048 }
1049
1050 return true;
1051 }
1052
1070 public bool Load(Phase phase, ProjectEx p, DB_LABEL_SELECTION_METHOD? labelSelectionOverride = null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride = null, bool bResetFirst = false, IXDatabaseBase db = null, bool bUseDb = true, bool bCreateRunNet = true, string strStage = null, bool bEnableMemTrace = false)
1071 {
1072 try
1073 {
1074 m_log.Enable = m_bEnableVerboseStatus;
1075
1076 DatasetFactory factory = new DatasetFactory();
1077 m_strStage = strStage;
1078 m_db = db;
1079 m_bDbOwner = false;
1080
1081 if (db != null && bUseDb)
1082 {
1083 if (m_settings.DbVersion != db.GetVersion())
1084 throw new Exception("The database version in the settings (" + m_settings.DbVersion.ToString() + ") must match the database version (" + db.GetVersion().ToString() + ") of the 'db' parameter!");
1085 }
1086
1087 if (m_db == null && bUseDb)
1088 {
1089 switch (m_settings.DbVersion)
1090 {
1091 case DB_VERSION.IMG_V1:
1093 m_log.WriteLine("Loading primary images...", true);
1094 m_bDbOwner = true;
1095 m_log.Enable = true;
1096 ((IXImageDatabase1)m_db).InitializeWithDs1(m_settings, p.Dataset, m_evtCancel.Name);
1097 break;
1098
1099 case DB_VERSION.TEMPORAL:
1100 throw new NotImplementedException("The temporal database is not yet supported.");
1101
1102 default:
1104 m_log.WriteLine("Loading primary images...", true);
1105 m_bDbOwner = true;
1106 m_log.Enable = true;
1107 ((IXImageDatabase2)m_db).InitializeWithDs(m_settings, p.Dataset, m_evtCancel.Name);
1108 break;
1109 }
1110
1111 if (m_evtCancel.WaitOne(0))
1112 return false;
1113
1114 // m_db.UpdateLabelBoosts(p.ID, p.Dataset.TrainingSource.ID);
1115
1116 Tuple<DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD> selMethod = MyCaffeImageDatabase.GetSelectionMethod(p);
1117 DB_LABEL_SELECTION_METHOD lblSel = selMethod.Item1;
1118 DB_ITEM_SELECTION_METHOD imgSel = selMethod.Item2;
1119
1120 if (labelSelectionOverride.HasValue)
1121 lblSel = labelSelectionOverride.Value;
1122
1123 if (itemSelectionOverride.HasValue)
1124 imgSel = itemSelectionOverride.Value;
1125
1126 m_db.SetSelectionMethod(lblSel, imgSel);
1128 m_log.WriteLine("Images loaded.");
1129
1130 if (p.TargetDatasetID > 0)
1131 {
1132 DatasetDescriptor dsTarget = factory.LoadDataset(p.TargetDatasetID);
1133
1134 m_log.WriteLine("Loading target dataset '" + dsTarget.Name + "' images using " + m_settings.DbLoadMethod.ToString() + " loading method.");
1135 string strType = "images";
1136
1137 switch (m_settings.DbVersion)
1138 {
1139 case DB_VERSION.IMG_V1:
1140 ((IXImageDatabase1)m_db).LoadDatasetByID1(dsTarget.ID);
1141 break;
1142
1143 case DB_VERSION.TEMPORAL:
1144 strType = "items";
1145 throw new NotImplementedException("The temporal database is not yet supported.");
1146
1147 default:
1148 ((IXImageDatabase2)m_db).LoadDatasetByID(dsTarget.ID);
1149 break;
1150 }
1151
1153 m_log.WriteLine("Target dataset " + strType + " loaded.");
1154 }
1155
1156 m_log.Enable = m_bEnableVerboseStatus;
1157 }
1158
1159 p.ModelDescription = addStage(p.ModelDescription, phase, strStage);
1160 m_project = p;
1161 m_project.Stage = getStage(m_strStage);
1162
1163 if (m_project == null)
1164 throw new Exception("You must specify a project.");
1165
1166 m_dataSet = m_project.Dataset;
1167
1168 if (m_cuda != null)
1169 m_cuda.Dispose();
1170
1171 m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, bResetFirst, bEnableMemTrace);
1172
1173 m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);
1174
1175 if (phase == Phase.TEST || phase == Phase.TRAIN)
1176 {
1177 m_log.WriteLine("Creating solver...", true);
1178
1179 m_solver = Solver<T>.Create(m_cuda, m_log, p, m_evtCancel, m_evtForceSnapshot, m_evtForceTest, m_db, m_persist, m_rgGpu.Count, 0);
1181 if (p.WeightsState != null || p.SolverState != null)
1182 {
1183 string strSkipBlobType = null;
1184
1185 ParameterDescriptor param = p.Parameters.Find("ModelResized");
1186 if (param != null && param.Value == "True")
1187 strSkipBlobType = BLOB_TYPE.IP_WEIGHT.ToString();
1188
1189 if (p.SolverState != null)
1190 m_solver.Restore(p.WeightsState, p.SolverState, strSkipBlobType);
1191 else
1192 m_solver.TrainingNet.LoadWeights(p.WeightsState, m_persist);
1193 }
1194
1195 m_solver.OnSnapshot += new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1196 m_solver.OnTrainingIteration += new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1197 m_solver.OnTestingIteration += new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1198 m_log.WriteLine("Solver created.", true);
1199
1200 verifySharedWeights();
1201 }
1202
1203 if (m_db is IXImageDatabase1)
1204 {
1205#warning ImageDatabase V1 only
1206 if (phase == Phase.TRAIN && m_db != null)
1207 ((IXImageDatabase1)m_db).UpdateLabelBoosts(p.ID, m_dataSet.TrainingSource.ID);
1208
1209 if (phase == Phase.TEST && m_db != null)
1210 ((IXImageDatabase1)m_db).UpdateLabelBoosts(p.ID, m_dataSet.TestingSource.ID);
1211 }
1212
1213 if (phase == Phase.RUN && !bCreateRunNet)
1214 throw new Exception("You cannot opt out of creating the Run net when using the RUN phase.");
1215
1216 if (p == null || !bCreateRunNet)
1217 return true;
1218
1219 TransformationParameter tp = null;
1220 NetParameter netParam = createNetParameterForRunning(p, out tp);
1221
1222 m_dataTransformer = null;
1223
1224 if (tp != null && !p.Dataset.IsModelData && !p.Dataset.IsGym && m_settings.DbVersion != DB_VERSION.TEMPORAL)
1225 {
1226 SimpleDatum sdMean = (m_db == null) ? null : m_db.QueryItemMean(m_dataSet.TrainingSource.ID);
1227 int nC = m_dataSet.TrainingSource.Channels;
1228 int nH = m_dataSet.TrainingSource.Height;
1229 int nW = m_dataSet.TrainingSource.Width;
1230
1231 if (sdMean != null)
1232 {
1233 m_log.CHECK_EQ(nC, sdMean.Channels, "The mean channel count does not match the datasets channel count.");
1234 m_log.CHECK_EQ(nH, sdMean.Height, "The mean height count does not match the datasets height count.");
1235 m_log.CHECK_EQ(nW, sdMean.Width, "The mean width count does not match the datasets width count.");
1236 }
1237
1238 m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, nC, nH, nW, sdMean);
1239 }
1240
1241 m_log.WriteLine("Creating run net...", true);
1242
1243 if (phase == Phase.RUN)
1244 {
1245 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db);
1246
1247 if (p.WeightsState != null)
1248 {
1249 m_log.WriteLine("Loading run weights...", true);
1250 loadWeights(m_net, p.WeightsState);
1251 }
1252 }
1253 else if (phase == Phase.TEST || phase == Phase.TRAIN)
1254 {
1255 try
1256 {
1257 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db, Phase.RUN, null, m_solver.TrainingNet);
1258 }
1259 catch (Exception excpt)
1260 {
1261 m_log.WriteLine("WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1262 m_net = m_solver.TestingNet;
1263 m_bOwnRunNet = false;
1264 }
1265 }
1266 }
1267 catch (Exception excpt)
1268 {
1269 throw excpt;
1270 }
1271 finally
1272 {
1273 m_log.Enable = true;
1274 }
1275
1276 return true;
1277 }
1278
1298 public bool Load(Phase phase, string strSolver, string strModel, byte[] rgWeights, DB_LABEL_SELECTION_METHOD? labelSelectionOverride = null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride = null, bool bResetFirst = false, IXDatabaseBase db = null, bool bUseDb = true, bool bCreateRunNet = true, string strStage = null, bool bEnableMemTrace = false)
1299 {
1300 try
1301 {
1302 m_log.Enable = m_bEnableVerboseStatus;
1303
1304 m_strStage = strStage;
1305 m_db = db;
1306 m_bDbOwner = false;
1307
1308 RawProto protoSolver = RawProto.Parse(strSolver);
1309 SolverParameter solverParam = SolverParameter.FromProto(protoSolver);
1310
1311 strModel = addStage(strModel, phase, strStage);
1312
1313 RawProto protoModel = RawProto.Parse(strModel);
1314 solverParam.net_param = NetParameter.FromProto(protoModel);
1315
1316 m_dataSet = findDataset(solverParam.net_param);
1317
1318 if (db != null && bUseDb)
1319 {
1320 if (m_settings.DbVersion != db.GetVersion())
1321 throw new Exception("The database version in the settings (" + m_settings.DbVersion.ToString() + ") must match the database version (" + db.GetVersion().ToString() + ") of the 'db' parameter!");
1322 }
1323
1324 if (m_db == null && bUseDb)
1325 {
1326 string strType = "images";
1327 switch (m_settings.DbVersion)
1328 {
1329 case DB_VERSION.IMG_V1:
1331 ((IXImageDatabase1)m_db).InitializeWithDs1(m_settings, m_dataSet, m_evtCancel.Name);
1332 m_log.WriteLine("Loading primary images...", true);
1333 m_log.Enable = true;
1334 break;
1335
1336 case DB_VERSION.IMG_V2:
1339 m_log.WriteLine("Loading primary images...", true);
1340 m_log.Enable = true;
1341 break;
1342
1343 case DB_VERSION.TEMPORAL:
1344 throw new NotImplementedException("The temporal database is not yet implemented!");
1345 }
1346
1347 m_bDbOwner = true;
1348 if (m_evtCancel.WaitOne(0))
1349 return false;
1350
1351 Tuple<DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD> selMethod = MyCaffeImageDatabase.GetSelectionMethod(m_settings);
1352 DB_LABEL_SELECTION_METHOD lblSel = selMethod.Item1;
1353 DB_ITEM_SELECTION_METHOD imgSel = selMethod.Item2;
1354
1355 if (labelSelectionOverride.HasValue)
1356 lblSel = labelSelectionOverride.Value;
1357
1358 if (itemSelectionOverride.HasValue)
1359 imgSel = itemSelectionOverride.Value;
1360
1361 m_db.SetSelectionMethod(lblSel, imgSel);
1363 m_log.WriteLine("Images loaded.", true);
1364
1365 DatasetDescriptor dsTarget = findDataset(solverParam.net_param, m_dataSet);
1366 if (dsTarget != null)
1367 {
1368 m_log.WriteLine("Loading target dataset '" + dsTarget.Name + "' " + strType + " using " + m_settings.DbLoadMethod.ToString() + " loading method.", true);
1369
1370 switch (m_settings.DbVersion)
1371 {
1372 case DB_VERSION.IMG_V1:
1373 ((IXImageDatabase1)m_db).LoadDatasetByID1(dsTarget.ID);
1374 break;
1375
1376 case DB_VERSION.IMG_V2:
1377 ((IXImageDatabase2)m_db).LoadDatasetByID(dsTarget.ID);
1378 break;
1379
1380 case DB_VERSION.TEMPORAL:
1381 throw new NotImplementedException("The temporal database is not yet implemented!");
1382 }
1383
1385 m_log.WriteLine("Target dataset " + strType + " loaded.", true);
1386 }
1387
1388 m_log.Enable = m_bEnableVerboseStatus;
1389 }
1390
1391 m_project = null;
1392
1393 if (m_cuda != null)
1394 m_cuda.Dispose();
1395
1396 m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, bResetFirst, bEnableMemTrace);
1397 m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);
1398
1399 if (phase == Phase.TEST || phase == Phase.TRAIN)
1400 {
1401 m_log.WriteLine("Creating solver...", true);
1402
1403 m_solver = Solver<T>.Create(m_cuda, m_log, solverParam, m_evtCancel, m_evtForceSnapshot, m_evtForceTest, m_db, m_persist, m_rgGpu.Count, 0);
1404
1405 if (rgWeights != null)
1406 {
1407 m_log.WriteLine("Restoring weights...", true);
1408 m_solver.Restore(rgWeights, null);
1409 }
1410
1411 m_solver.OnSnapshot += new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1412 m_solver.OnTrainingIteration += new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1413 m_solver.OnTestingIteration += new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1414 m_log.WriteLine("Solver created.", true);
1415 }
1416
1417 if (!bCreateRunNet)
1418 {
1419 if (phase == Phase.RUN)
1420 throw new Exception("You cannot opt out of creating the Run net when using the RUN phase.");
1421
1422 return true;
1423 }
1424
1425 TransformationParameter tp = null;
1426 NetParameter netParam = createNetParameterForRunning(m_dataSet, strModel, out tp);
1427
1428 m_dataTransformer = null;
1429
1430 if (tp != null)
1431 {
1432 SimpleDatum sdMean = (m_db == null) ? null : m_db.QueryItemMean(m_dataSet.TrainingSource.ID);
1433 int nC = 0;
1434 int nH = 0;
1435 int nW = 0;
1436
1437 if (sdMean != null)
1438 {
1439 nC = sdMean.Channels;
1440 nH = sdMean.Height;
1441 nW = sdMean.Width;
1442 }
1443 else if (m_project != null)
1444 {
1448 }
1449
1450 if (nC == 0 || nH == 0 || nW == 0)
1451 throw new Exception("Unable to size the Data Transformer for there is no Mean or Project to gather the sizing information from.");
1452
1453 m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, nC, nH, nW, sdMean);
1454 }
1455
1456 m_log.WriteLine("Creating run net...", true);
1457
1458 if (phase == Phase.RUN)
1459 {
1460 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db);
1461
1462 if (rgWeights != null)
1463 {
1464 m_log.WriteLine("Loading run weights...", true);
1465 loadWeights(m_net, rgWeights);
1466 }
1467 }
1468 else if (phase == Phase.TEST || phase == Phase.TRAIN)
1469 {
1470 netParam.force_backward = true;
1471
1472 try
1473 {
1474 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db, Phase.RUN, null, m_solver.TrainingNet);
1475 }
1476 catch (Exception excpt)
1477 {
1478 m_log.WriteLine("WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1479 m_net = m_solver.TestingNet;
1480 m_bOwnRunNet = false;
1481 }
1482 }
1483 }
1484 catch (Exception excpt)
1485 {
1486 throw excpt;
1487 }
1488 finally
1489 {
1490 m_log.Enable = true;
1491 }
1492
1493 return true;
1494 }
1495
1512 public bool LoadLite(Phase phase, string strSolver, string strModel, byte[] rgWeights = null, bool bResetFirst = false, bool bCreateRunNet = true, SimpleDatum sdMean = null, string strStage = null, bool bEnableMemTrace = false)
1513 {
1514 try
1515 {
1516 m_log.Enable = m_bEnableVerboseStatus;
1517
1518 m_bLoadLite = true;
1519 m_strSolver = strSolver;
1520 m_strModel = strModel;
1521
1522 m_strStage = strStage;
1523 m_db = null;
1524 m_bDbOwner = false;
1525
1526 RawProto protoSolver = RawProto.Parse(strSolver);
1527 SolverParameter solverParam = SolverParameter.FromProto(protoSolver);
1528
1529 strModel = addStage(strModel, phase, strStage);
1530
1531 RawProto protoModel = RawProto.Parse(strModel);
1532 solverParam.net_param = NetParameter.FromProto(protoModel);
1533
1534 m_dataSet = null;
1535 m_project = null;
1536
1537 if (m_cuda != null)
1538 m_cuda.Dispose();
1539
1540 m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, bResetFirst, bEnableMemTrace);
1541 m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);
1542
1543 if (phase == Phase.TEST || phase == Phase.TRAIN)
1544 {
1545 m_log.WriteLine("Creating solver...", true);
1546
1547 m_solver = Solver<T>.Create(m_cuda, m_log, solverParam, m_evtCancel, m_evtForceSnapshot, m_evtForceTest, m_db, m_persist, m_rgGpu.Count, 0);
1548
1549 if (rgWeights != null)
1550 m_solver.Restore(rgWeights, null);
1551
1552 m_solver.OnSnapshot += new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1553 m_solver.OnTrainingIteration += new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1554 m_solver.OnTestingIteration += new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1555 m_log.WriteLine("Solver created.");
1556 }
1557
1558 if (!bCreateRunNet)
1559 {
1560 if (phase == Phase.RUN)
1561 throw new Exception("You cannot opt out of creating the Run net when using the RUN phase.");
1562
1563 return true;
1564 }
1565
1566 TransformationParameter tp = null;
1567 int nC = 0;
1568 int nH = 0;
1569 int nW = 0;
1570 NetParameter netParam = createNetParameterForRunning(sdMean, strModel, out tp, out nC, out nH, out nW);
1571
1572 m_dataTransformer = null;
1573
1574 if (tp != null)
1575 {
1576 if (nC == 0 || nH == 0 || nW == 0)
1577 throw new Exception("Unable to size the Data Transformer for no Mean image was provided as the 'sdMean' parameter which is used to gather the sizing information.");
1578
1579 m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, nC, nH, nW, sdMean);
1580 }
1581
1582 m_log.WriteLine("Creating run net...", true);
1583
1584 if (phase == Phase.RUN)
1585 {
1586 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, null);
1587
1588 if (rgWeights != null)
1589 {
1590 m_log.WriteLine("Loading run weights...", true);
1591 loadWeights(m_net, rgWeights);
1592 }
1593 }
1594 else if (phase == Phase.TEST || phase == Phase.TRAIN)
1595 {
1596 netParam.force_backward = true;
1597
1598 try
1599 {
1600 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, null, Phase.RUN, null, m_solver.TrainingNet);
1601 }
1602 catch (Exception excpt)
1603 {
1604 m_log.WriteLine("WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1605 m_net = m_solver.TestingNet;
1606 m_bOwnRunNet = false;
1607 }
1608 }
1609 }
1610 catch (Exception excpt)
1611 {
1612 throw excpt;
1613 }
1614 finally
1615 {
1616 m_log.Enable = true;
1617 }
1618
1619 return true;
1620 }
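A rough sketch of using LoadLite without an in-memory database. It assumes 'strSolver' and 'strModel' hold valid solver and model prototxt descriptors, with the model containing an 'Input' layer so the run-net shape can be inferred, and the iteration count is illustrative only.

// Hedged example: load directly from descriptor strings, train, then unload.
mycaffe.LoadLite(Phase.TRAIN, strSolver, strModel);
mycaffe.Train(1000);   // illustrative iteration count.
mycaffe.Unload();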
1621
1637 public void LoadToRun(string strModel, byte[] rgWeights, BlobShape shape, SimpleDatum sdMean = null, TransformationParameter transParam = null, bool bForceBackward = false, bool bConvertToRunNet = true)
1638 {
1639 try
1640 {
1641 m_log.Enable = m_bEnableVerboseStatus;
1642 m_dataSet = null;
1643 m_project = null;
1644 m_loadToRunShape = shape;
1645
1646 if (m_cuda != null)
1647 m_cuda.Dispose();
1648
1649 m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath);
1650 m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);
1651
1652 TransformationParameter tp = null;
1653 NetParameter netParam = null;
1654
1655 if (bConvertToRunNet)
1656 {
1657 netParam = createNetParameterForRunning(shape, strModel, out tp);
1658 }
1659 else
1660 {
1661 netParam = NetParameter.FromProto(RawProto.Parse(strModel));
1662
1663 foreach (LayerParameter layer in netParam.layer)
1664 {
1665 if (layer.transform_param != null)
1666 {
1667 tp = layer.transform_param;
1668 break;
1669 }
1670 }
1671 }
1672
1673 netParam.force_backward = bForceBackward;
1674
1675 if (transParam != null)
1676 tp = transParam;
1677
1678 if (tp != null)
1679 {
1680 if (tp.use_imagedb_mean && sdMean == null)
1681 throw new Exception("The transformer expects an image mean, yet the sdMean parameter is null!");
1682
1683 m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, shape.dim[1], shape.dim[2], shape.dim[3], sdMean);
1684 }
1685 else
1686 {
1687 m_dataTransformer = null;
1688 }
1689
1690 m_log.WriteLine("Creating run net...", true);
1691 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, null, Phase.RUN);
1692
1693 m_log.WriteLine("Loading weights...", true);
1694 loadWeights(m_net, rgWeights);
1695 }
1696 catch (Exception excpt)
1697 {
1698 throw excpt;
1699 }
1700 finally
1701 {
1702 m_log.Enable = true;
1703 }
1704 }
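A rough sketch of preparing a run-only instance with LoadToRun (a sketch only: the 1x3x224x224 input shape and the 'rgWeights' byte array are placeholders; the weights would typically come from a previously trained model):

// Hedged example: load a model for inference only.
BlobShape shapeRun = new BlobShape(1, 3, 224, 224);
mycaffe.LoadToRun(strModel, rgWeights, shapeRun);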
1705
1706 private SimpleDatum getMeanImage(NetParameter p)
1707 {
1708 string strSrc = null;
1709
1710 foreach (LayerParameter lp in p.layer)
1711 {
1712 if (lp.type == LayerParameter.LayerType.TRANSFORM)
1713 {
1715 return null;
1716 }
1717 else if (lp.type == LayerParameter.LayerType.DATA)
1718 {
1719 switch (lp.type)
1720 {
1721 case LayerParameter.LayerType.DATA:
1722 strSrc = lp.data_param.source;
1723 break;
1724 }
1725 }
1726 }
1727
1728 if (strSrc == null)
1729 throw new Exception("Could not find the data source in the model!");
1730
1731 DatasetFactory factory = new DatasetFactory();
1732 SourceDescriptor sd = factory.LoadSource(strSrc);
1733
1734 if (sd == null)
1735 throw new Exception("Could not find the data source '" + strSrc + "' in the database.");
1736
1737 return factory.QueryImageMean(sd.ID);
1738 }
1739
1740 private DatasetDescriptor findDataset(NetParameter p, DatasetDescriptor dsPrimary = null)
1741 {
1742 string strTestSrc = null;
1743 string strTrainSrc = null;
1744
1745 foreach (LayerParameter lp in p.layer)
1746 {
1747 if (lp.type == LayerParameter.LayerType.DATA)
1748 {
1749 string strSrc = null;
1750
1751 switch (lp.type)
1752 {
1753 case LayerParameter.LayerType.DATA:
1754 strSrc = lp.data_param.source;
1755 break;
1756 }
1757
1758 foreach (NetStateRule rule in lp.include)
1759 {
1760 if (rule.phase == Phase.TRAIN)
1761 strTrainSrc = strSrc;
1762 else if (rule.phase == Phase.TEST)
1763 strTestSrc = strSrc;
1764 }
1765 }
1766
1767 if (strTrainSrc != null && strTestSrc != null)
1768 {
1769 if (dsPrimary == null || (strTrainSrc != dsPrimary.TrainingSourceName && strTestSrc != dsPrimary.TestingSourceName))
1770 break;
1771 }
1772 }
1773
1774 if (strTrainSrc == null || strTestSrc == null)
1775 return null;
1776
1777 if (dsPrimary != null && (strTrainSrc == dsPrimary.TrainingSourceName && strTestSrc == dsPrimary.TestingSourceName))
1778 return null;
1779
1780 DatasetFactory factory = new DatasetFactory();
1781 DatasetDescriptor ds = factory.LoadDataset(strTestSrc, strTrainSrc);
1782
1783 if (ds == null)
1784 throw new Exception("The datset sources '" + strTestSrc + "' and '" + strTrainSrc + "' do not exist in the database - do you need to load them?");
1785
1786 return ds;
1787 }
1788
1789 private void loadWeights(Net<T> net, byte[] rgWeights)
1790 {
1791 net.LoadWeights(rgWeights, m_persist);
1792 }
1793
1800 public bool CompareWeights(Net<T> net1, Net<T> net2)
1801 {
1802 if (net1.learnable_parameters.Count != net2.learnable_parameters.Count)
1803 {
1804 m_log.WriteLine("WARNING: The number of learnable parameters differs between the two nets!");
1805 return false;
1806 }
1807
1808 Blob<T> blobWork = new Blob<T>(m_cuda, m_log, false);
1809
1810 try
1811 {
1812 for (int i = 0; i < net1.learnable_parameters.Count; i++)
1813 {
1814 Blob<T> blob1 = net1.learnable_parameters[i];
1815 Blob<T> blob2 = net2.learnable_parameters[i];
1816
1817 if (blob1.Name != blob2.Name)
1818 {
1819 m_log.WriteLine("WARNING: The name of the blobs at index " + i.ToString() + " differ: net1 - " + blob1.Name + " vs. net2 - " + blob2.Name);
1820 return false;
1821 }
1822
1823 if (blob1.shape_string != blob2.shape_string)
1824 {
1825 m_log.WriteLine("WARNING: The shape of the blobs at index " + i.ToString() + " differ: net1 - " + blob1.Name + " " + blob1.shape_string + " vs. net2 - " + blob2.Name + " " + blob2.shape_string);
1826 return false;
1827 }
1828
1829 blobWork.ReshapeLike(blob1);
1830 m_cuda.sub(blob1.count(), blob1.gpu_data, blob2.gpu_data, blobWork.mutable_gpu_data);
1831 double dfSum = Utility.ConvertVal<T>(blobWork.asum_data());
1832 if (dfSum != 0)
1833 {
1834 m_log.WriteLine("WARNING: The data of the blobs at index " + i.ToString() + " differ: net1 - " + blob1.Name + " " + blob1.shape_string + " vs. net2 - " + blob2.Name + " " + blob2.shape_string);
1835 return false;
1836 }
1837 }
1838 }
1839 finally
1840 {
1841 blobWork.Dispose();
1842 }
1843
1844 return true;
1845 }
1846
1847 void m_solver_OnTestingIteration(object sender, TestingIterationArgs<T> e)
1848 {
1849 if (OnTestingIteration != null)
1850 OnTestingIteration(sender, e);
1851 }
1852
1853 void m_solver_OnTrainingIteration(object sender, TrainingIterationArgs<T> e)
1854 {
1855 if (OnTrainingIteration != null)
1856 OnTrainingIteration(sender, e);
1857 }
1858
1859 void m_solver_OnSnapshot(object sender, SnapshotArgs e)
1860 {
1861 if (OnSnapshot != null)
1862 OnSnapshot(sender, e);
1863 }
1864
1876 public void Train(int nIterationOverride = -1, int nTrainingTimeLimitInMinutes = 0, TRAIN_STEP step = TRAIN_STEP.NONE, double dfLearningRateOverride = 0, bool bReset = false)
1877 {
1878 m_lastPhaseRun = Phase.TRAIN;
1879
1880 if (nIterationOverride == -1)
1881 nIterationOverride = m_settings.MaximumIterationOverride;
1882
1883 if (bReset)
1884 m_solver.Reset();
1885
1886 m_solver.TrainingTimeLimitInMinutes = nTrainingTimeLimitInMinutes;
1887 m_solver.TrainingIterationOverride = nIterationOverride;
1889
1890 if (dfLearningRateOverride > 0)
1891 m_solver.LearningRateOverride = dfLearningRateOverride;
1892
1893 try
1894 {
1895 if (m_rgGpu.Count > 1)
1896 {
1897 if (nTrainingTimeLimitInMinutes > 0)
1898 {
1899 m_log.WriteLine("You have a training time-limit of " + nTrainingTimeLimitInMinutes.ToString("N0") + " minutes. Multi-GPU training is not supported when a training time-limit is imposed.");
1900 return;
1901 }
1902
1903 m_log.WriteLine("Starting multi-GPU training on GPUs: " + listToString(m_rgGpu));
1904 NCCL<T> nccl = new NCCL<T>(m_cuda, m_log, m_solver, m_rgGpu[0], 0, null);
1905 nccl.Run(m_rgGpu, m_solver.TrainingIterationOverride);
1906 }
1907 else
1908 {
1909 m_solver.Solve(-1, null, null, step);
1910 }
1911 }
1912 catch (Exception excpt)
1913 {
1914 throw excpt;
1915 }
1916 finally
1917 {
1918 m_solver.LearningRateOverride = 0;
1919 }
1920 }
1921
1922 private string listToString(List<int> rg)
1923 {
1924 string strOut = "";
1925
1926 for (int i = 0; i < rg.Count; i++)
1927 {
1928 strOut += rg[i].ToString();
1929
1930 if (i < rg.Count - 1)
1931 strOut += ", ";
1932 }
1933
1934 return strOut;
1935 }
1936
1942 public double Test(int nIterationOverride = -1)
1943 {
1944 m_lastPhaseRun = Phase.TEST;
1945
1946 if (nIterationOverride == -1)
1947 nIterationOverride = m_settings.TestingIterationOverride;
1948
1949 m_solver.TestingIterationOverride = nIterationOverride;
1950
1951 return m_solver.TestAll();
1952 }
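Train and Test above are typically driven from the host application once a TRAIN-phase load has completed. A minimal sketch, with illustrative iteration counts:

// Hedged example: run training, then score using the testing net.
mycaffe.Train(10000);                  // illustrative iteration count.
double dfResult = mycaffe.Test(100);   // illustrative test iteration count; the value returned comes from the solver's TestAll().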
1953
1962 {
1964 throw new Exception("Custom input is only supported by MODEL based datasets!");
1965
1966 m_lastPhaseRun = Phase.RUN;
1967
1968 UpdateRunWeights(false, false);
1969
1970 double dfThreshold = customInput.GetPropertyAsDouble("Threshold", 0.2);
1971 int nMax = customInput.GetPropertyAsInt("Max", 80);
1972 int nK = customInput.GetPropertyAsInt("K", 1);
1973 PropertySet res;
1974
1975 if (customInput.GetPropertyAsBool("Temporal", false))
1976 {
1977 res = RunModel(customInput);
1978 res.SetProperty("Temporal", "True");
1979 }
1980 else
1981 {
1982 res = Run(customInput, nK, dfThreshold, nMax);
1983 }
1984
1985 return res;
1986 }
1987
1996 {
1997 m_lastPhaseRun = Phase.RUN;
1998
1999 UpdateRunWeights(false, false);
2000
2001 double dfThreshold = customInput.GetPropertyAsDouble("Threshold", 0.2);
2002 int nMax = customInput.GetPropertyAsInt("Max", 80);
2003 int nK = customInput.GetPropertyAsInt("K", 1);
2004
2005 if (customInput.GetProperty("Temporal") == "True")
2006 return RunModelEx(customInput);
2007
2008 throw new Exception("TestManyEx currently only supports temporal testing.");
2009 }
2010
2023 public List<Tuple<SimpleDatum, ResultCollection>> TestMany(int nCount, bool bOnTrainingSet, bool bOnTargetSet = false, DB_ITEM_SELECTION_METHOD imgSelMethod = DB_ITEM_SELECTION_METHOD.RANDOM, int nImageStartIdx = 0, DateTime? dtImageStartTime = null, double? dfThreshold = null)
2024 {
2025 Dictionary<int, int> rgMissedThreshold = new Dictionary<int, int>();
2026
2027 m_lastPhaseRun = Phase.RUN;
2028
2029 UpdateRunWeights(false);
2030
2031 m_log.CHECK_GT(nCount, 0, "You must select at least 1 image to test on!");
2032
2033 Stopwatch sw = new Stopwatch();
2034 DB_LABEL_SELECTION_METHOD? lblSelMethod = null;
2035
2036 if (imgSelMethod == DB_ITEM_SELECTION_METHOD.NONE)
2037 lblSelMethod = DB_LABEL_SELECTION_METHOD.NONE;
2038
2039 Tuple<DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD> sel = m_db.GetSelectionMethod();
2040 if ((sel.Item2 & DB_ITEM_SELECTION_METHOD.BOOST) == DB_ITEM_SELECTION_METHOD.BOOST)
2041 imgSelMethod |= DB_ITEM_SELECTION_METHOD.BOOST;
2042
2043 int nSrcId = (bOnTrainingSet) ? m_dataSet.TrainingSource.ID : m_dataSet.TestingSource.ID;
2044 string strSrc = (bOnTrainingSet) ? m_dataSet.TrainingSourceName : m_dataSet.TestingSourceName;
2045 string strSet = (bOnTrainingSet) ? "training" : "test";
2046 int nCorrectCount = 0;
2047 Dictionary<int, int> rgCorrectCounts = new Dictionary<int, int>();
2048 Dictionary<int, int> rgLabelTotals = new Dictionary<int, int>();
2049 Dictionary<int, Dictionary<int, int>> rgDetectedCounts = new Dictionary<int, Dictionary<int, int>>();
2050
2051 if (bOnTargetSet && m_project.DatasetTarget != null)
2052 {
2053 nSrcId = (bOnTrainingSet) ? m_project.DatasetTarget.TrainingSource.ID : m_project.DatasetTarget.TestingSource.ID;
2054 strSrc = (bOnTrainingSet) ? m_project.DatasetTarget.TrainingSourceName : m_project.DatasetTarget.TestingSourceName;
2055 strSet = (bOnTrainingSet) ? "target training" : "target test";
2056 }
2057
2058 sw.Start();
2059
2060 m_log.WriteHeader("Test Many (" + nCount.ToString() + ") - on " + strSet + " '" + strSrc + "'");
2061
2062 LabelMappingParameter labelMapping = null;
2063 Net<T> net = m_solver.TestingNet;
2064 if (net == null)
2065 net = m_solver.TrainingNet;
2066
2067 AccuracyParameter accuracyParam = null;
2068 foreach (Layer<T> layer in net.layers)
2069 {
2070 if (layer.type == LayerParameter.LayerType.LABELMAPPING)
2071 {
2072 labelMapping = layer.layer_param.labelmapping_param;
2073 break;
2074 }
2075 else if (layer.type == LayerParameter.LayerType.ACCURACY ||
2076 layer.type == LayerParameter.LayerType.ACCURACY_DECODE ||
2077 layer.type == LayerParameter.LayerType.ACCURACY_ENCODING)
2078 {
2079 accuracyParam = layer.layer_param.accuracy_param;
2080 }
2081 }
2082
2083 if (nImageStartIdx < 0)
2084 nImageStartIdx = 0;
2085
2086 List<SimpleDatum> rgImg = null;
2087 if (dtImageStartTime.HasValue && dtImageStartTime.Value > DateTime.MinValue)
2088 {
2089 m_log.WriteLine("INFO: Starting test many at images with time " + dtImageStartTime.Value.ToString() + " or later...");
2090 rgImg = m_db.GetItemsFromTime(nSrcId, dtImageStartTime.Value, nCount);
2091 if (nCount > rgImg.Count)
2092 nCount = rgImg.Count;
2093
2094 if (nCount == 0)
2095 throw new Exception("No images found after time '" + dtImageStartTime.Value.ToString() + "'. Make sure to use the LOAD_ALL image loading method when running TestMany after a specified time.");
2096 }
2097
2098 List<Tuple<SimpleDatum, ResultCollection>> rgrgResults = new List<Tuple<SimpleDatum, ResultCollection>>();
2099 int nTotalCount = 0;
2100 int nMidPoint = 0;
2101 bool bPad = true;
2102 int nImgCount = nCount;
2103
2104 Blob<T> blobData = null;
2105
2106 try
2107 {
2108 SimpleDatum sd = null;
2109 List<int> rgOriginalRunNetInputShape = null;
2110
2111 if (m_net.input_blobs != null && m_net.input_blobs.Count > 0)
2112 rgOriginalRunNetInputShape = Utility.Clone<int>(m_net.input_blobs[0].shape());
2113
2114 for (int i = 0; i < nCount; i++)
2115 {
2116 if (m_evtCancel.WaitOne(0))
2117 {
2118 m_log.WriteLine("Test Many aborted!");
2119 return null;
2120 }
2121
2122
2123 sd = (rgImg != null) ? rgImg[i] : m_db.QueryItem(nSrcId, nImageStartIdx + i, lblSelMethod, imgSelMethod, null, m_settings.ItemDbLoadDataCriteria, m_settings.ItemDbLoadDebugData);
2124
2125 if (sd.Width != m_dataSet.TrainingSource.Width || sd.Height != m_dataSet.TrainingSource.Height)
2126 {
2127 m_log.WriteLine("WARNING: Image size mismatch! Current image size " + sd.Width.ToString() + " x " + sd.Height.ToString() + " does not match the dataset image size " + m_dataSet.TrainingSource.Width.ToString() + " x " + m_dataSet.TrainingSource.Height.ToString() + "!");
2128 continue;
2129 }
2130
2131 m_dataTransformer.TransformLabel(sd);
2132
2133 if (!sd.GetDataValid(false))
2134 {
2135 Trace.WriteLine("You should not be here.");
2136 throw new Exception("NO DATA!");
2137 }
2138
2139 // Creating the data blob masks the images (when enabled) during the data transform.
2140 blobData = CreateDataBlob(sd, blobData, bPad);
2141
2142 List<int> rgIgnoreLabels = null;
2143 if (accuracyParam != null && accuracyParam.ignore_labels != null)
2144 rgIgnoreLabels = accuracyParam.ignore_labels;
2145
2146 List<ResultCollection> rgrgResults1 = Run(blobData, false, false, int.MaxValue, rgIgnoreLabels);
2147 ResultCollection rgResults = rgrgResults1[0];
2148
2149 // If masked, mask the sd to match the actual input in the blobData.
2150 m_dataTransformer.MaskImage(sd);
2151 rgrgResults.Add(new Tuple<SimpleDatum, ResultCollection>(sd, rgResults));
2152
2153 if (rgResults.ResultType == ResultCollection.RESULT_TYPE.MULTIBOX)
2154 {
2155 Dictionary<int, List<Result>> rgLabeledResults = new Dictionary<int, List<Result>>();
2156 Dictionary<int, int> rgLabeledOrder = new Dictionary<int, int>();
2157
2158 int nIdx = 0;
2159 foreach (Result result in rgResults.ResultsSorted)
2160 {
2161 if (!rgLabeledResults.ContainsKey(result.Label))
2162 {
2163 rgLabeledResults.Add(result.Label, new List<Result>());
2164 rgLabeledOrder.Add(result.Label, nIdx);
2165 nIdx++;
2166 }
2167
2168 rgLabeledResults[result.Label].Add(result);
2169 }
2170
2171 List<Tuple<int, List<Result>>> rgBestResults = new List<Tuple<int, List<Result>>>();
2172 List<int> rgDetectedLabels = rgLabeledOrder.OrderBy(p => p.Value).Select(p => p.Key).ToList();
2173
2174 if (sd.annotation_group != null)
2175 {
2176 rgDetectedLabels = rgDetectedLabels.Take(sd.annotation_group.Count).ToList();
2177
2178 for (int j = 0; j < sd.annotation_group.Count; j++)
2179 {
2180 int nExpectedLabel = sd.annotation_group[j].group_label;
2181
2182 if (!rgCorrectCounts.ContainsKey(nExpectedLabel))
2183 rgCorrectCounts.Add(nExpectedLabel, 0);
2184
2185 if (!rgLabelTotals.ContainsKey(nExpectedLabel))
2186 rgLabelTotals.Add(nExpectedLabel, 1);
2187 else
2188 rgLabelTotals[nExpectedLabel]++;
2189
2190 if (rgDetectedLabels.Contains(nExpectedLabel))
2191 {
2192 rgCorrectCounts[nExpectedLabel]++;
2193 nCorrectCount++;
2194 }
2195
2196 nTotalCount++;
2197 }
2198 }
2199 else
2200 {
2201 m_log.WriteLine("WARNING: No annotation data found in image with ID = " + sd.ImageID.ToString());
2202 }
2203 }
2204 else
2205 {
2206 int nDetectedLabel = rgResults.DetectedLabel;
2207 int nExpectedLabel = sd.Label;
2208
2209 if (!dfThreshold.HasValue || rgResults.DetectedLabelOutput >= dfThreshold.Value)
2210 {
2211 if (rgResults.ResultsOriginal.Count % 2 != 0)
2212 nMidPoint = (int)Math.Floor(rgResults.ResultsOriginal.Count / 2.0);
2213
2214
2215 if (labelMapping != null)
2216 {
2217 if (m_dataTransformer.param.label_mapping.Active)
2218 m_log.FAIL("You can use either the LabelMappingLayer or the DataTransformer label_mapping, but not both!");
2219
2220 nExpectedLabel = labelMapping.MapLabel(nExpectedLabel);
2221 }
2222
2223 if (!rgCorrectCounts.ContainsKey(nExpectedLabel))
2224 rgCorrectCounts.Add(nExpectedLabel, 0);
2225
2226 if (!rgLabelTotals.ContainsKey(nExpectedLabel))
2227 rgLabelTotals.Add(nExpectedLabel, 1);
2228 else
2229 rgLabelTotals[nExpectedLabel]++;
2230
2231 if (nExpectedLabel == nDetectedLabel)
2232 {
2233 nCorrectCount++;
2234 rgCorrectCounts[nExpectedLabel]++;
2235 }
2236
2237 if (!rgDetectedCounts.ContainsKey(nExpectedLabel))
2238 rgDetectedCounts.Add(nExpectedLabel, new Dictionary<int, int>());
2239
2240 if (!rgDetectedCounts[nExpectedLabel].ContainsKey(nDetectedLabel))
2241 rgDetectedCounts[nExpectedLabel].Add(nDetectedLabel, 0);
2242
2243 rgDetectedCounts[nExpectedLabel][nDetectedLabel]++;
2244
2245 nTotalCount++;
2246 }
2247 else
2248 {
2249 if (!rgMissedThreshold.ContainsKey(nExpectedLabel))
2250 rgMissedThreshold.Add(nExpectedLabel, 0);
2251
2252 rgMissedThreshold[nExpectedLabel]++;
2253 }
2254 }
2255
2256 double dfPct = ((double)i / (double)nCount);
2257 m_log.Progress = dfPct;
2258
2259 if (sw.ElapsedMilliseconds > 1000)
2260 {
2261 m_log.WriteLine("processing test many at " + dfPct.ToString("P"));
2262 sw.Stop();
2263 sw.Reset();
2264 sw.Start();
2265 }
2266 }
2267
2268 // Resize inputs back to unpadded.
2269 if (rgOriginalRunNetInputShape != null)
2270 {
2271 if (m_net.input_blobs != null && m_net.input_blobs.Count > 0)
2272 m_net.input_blobs[0].Reshape(rgOriginalRunNetInputShape);
2273 }
2274 }
2275 finally
2276 {
2277 if (blobData != null)
2278 {
2279 blobData.Dispose();
2280 blobData = null;
2281 }
2282 }
2283
2284 double dfCorrectPct = (nTotalCount == 0) ? 0 : ((double)nCorrectCount / (double)nTotalCount);
2285
2286 m_log.WriteLine("Test Many Completed.");
2287 m_log.WriteLine(" " + dfCorrectPct.ToString("P") + " correct detections.");
2288 m_log.WriteLine(" " + (nTotalCount - nCorrectCount).ToString("N") + " incorrect detections.");
2289
2290 foreach (KeyValuePair<int, int> kv in rgCorrectCounts.OrderBy(p => p.Key).ToList())
2291 {
2292 nCount = 0;
2293
2294 foreach (KeyValuePair<int, int> kv1 in rgLabelTotals)
2295 {
2296 if (kv1.Key == kv.Key)
2297 {
2298 nCount = kv1.Value;
2299 break;
2300 }
2301 }
2302
2303 if (nCount > 0)
2304 {
2305 string strSecondDetection = "";
2306 if (rgDetectedCounts.ContainsKey(kv.Key))
2307 {
2308 List<KeyValuePair<int, int>> rgDetectedCountsSorted = rgDetectedCounts[kv.Key].OrderByDescending(p => p.Value).ToList();
2309 if (rgDetectedCountsSorted.Count > 1)
2310 {
2311 strSecondDetection = " (secondary detections: " + rgDetectedCountsSorted[1].Key.ToString();
2312
2313 if (rgDetectedCountsSorted.Count > 2)
2314 strSecondDetection += " and " + rgDetectedCountsSorted[2].Key.ToString();
2315
2316 strSecondDetection += ")";
2317 }
2318 }
2319
2320 dfCorrectPct = ((double)kv.Value / (double)nCount);
2321 m_log.WriteLine("Label #" + kv.Key.ToString() + " had " + dfCorrectPct.ToString("P") + " correct detections out of " + nCount.ToString("N0") + " items with this label." + strSecondDetection);
2322 }
2323 }
2324
2325 if (nMidPoint > 0)
2326 {
2327 int nTotalBelow = 0;
2328 int nCorrectBelow = 0;
2329 int nTotalAbove = 0;
2330 int nCorrectAbove = 0;
2331 int nTotalBelowAndAbove = 0;
2332 int nCorrectBelowAndAbove = 0;
2333
2334 List<KeyValuePair<int, int>> rgLabelTotalsList = rgLabelTotals.OrderBy(p => p.Key).ToList();
2335 List<KeyValuePair<int, int>> rgCorrectCountsList = rgCorrectCounts.OrderBy(p => p.Key).ToList();
2336
2337 for (int i = 0; i < rgLabelTotalsList.Count; i++)
2338 {
2339 if (i < nMidPoint)
2340 {
2341 nTotalBelow += rgLabelTotalsList[i].Value;
2342 nCorrectBelow += rgCorrectCountsList[i].Value;
2343 nTotalBelowAndAbove += rgLabelTotalsList[i].Value;
2344 nCorrectBelowAndAbove += rgCorrectCountsList[i].Value;
2345 }
2346 else if (i > nMidPoint)
2347 {
2348 nTotalAbove += rgLabelTotalsList[i].Value;
2349 nCorrectAbove += rgCorrectCountsList[i].Value;
2350 nTotalBelowAndAbove += rgLabelTotalsList[i].Value;
2351 nCorrectBelowAndAbove += rgCorrectCountsList[i].Value;
2352 }
2353 }
2354
2355 dfCorrectPct = (nTotalBelow == 0) ? 0 : nCorrectBelow / (double)nTotalBelow;
2356 m_log.WriteLine("Correct below midpoint of " + nMidPoint.ToString() + " = " + dfCorrectPct.ToString("P"));
2357 dfCorrectPct = (nTotalAbove == 0) ? 0 : nCorrectAbove / (double)nTotalAbove;
2358 m_log.WriteLine("Correct above midpoint of " + nMidPoint.ToString() + " = " + dfCorrectPct.ToString("P"));
2359 dfCorrectPct = (nTotalBelowAndAbove == 0) ? 0 : nCorrectBelowAndAbove / (double)nTotalBelowAndAbove;
2360 m_log.WriteLine("Correct below and above midpoint of " + nMidPoint.ToString() + " = " + dfCorrectPct.ToString("P"));
2361 }
2362
2363 if (rgMissedThreshold.Count > 0)
2364 {
2365 m_log.WriteLine("---Missed Threshold Items---");
2366
2367 int nTotal = 0;
2368 foreach (KeyValuePair<int, int> kv in rgMissedThreshold)
2369 {
2370 m_log.WriteLine("Expected Label " + kv.Key.ToString() + ": " + kv.Value.ToString() + " items missed threshold (" + ((double)kv.Value/nImgCount).ToString("P") + ").");
2371 nTotal += kv.Value;
2372 }
2373
2374 m_log.WriteLine("A total of " + nTotal.ToString() + " items did not meet the threshold of " + dfThreshold.Value.ToString() + ", (" + ((double)nTotal / nImgCount).ToString("P") + ")");
2375 }
2376
2377 return rgrgResults;
2378 }
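// Usage sketch (illustrative only): calling TestMany to score 1000 randomly selected test
// images. 'mycaffe' stands for an already loaded MyCaffeControl<float> instance and is an
// assumption of this sketch, not part of the listing above.
//
//   List<Tuple<SimpleDatum, ResultCollection>> rgTests = mycaffe.TestMany(1000, false);
//   foreach (Tuple<SimpleDatum, ResultCollection> t in rgTests)
//       Console.WriteLine("expected label " + t.Item1.Label.ToString() + ", detected label " + t.Item2.DetectedLabel.ToString());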
2379
2386 public ResultCollection Run(int nImageIdx, bool bPad = true)
2387 {
2389 m_dataTransformer.TransformLabel(sd);
2390 return Run(sd, true, bPad);
2391 }
2392
2399 public List<ResultCollection> Run(List<int> rgImageIdx, ref Blob<T> blob)
2400 {
2401 List<SimpleDatum> rgSd = new List<SimpleDatum>();
2402
2403 foreach (int nImageIdx in rgImageIdx)
2404 {
2406 m_dataTransformer.TransformLabel(sd);
2407 rgSd.Add(sd);
2408 }
2409
2410 return Run(rgSd, ref blob, false, int.MaxValue);
2411 }
2412
2418 public List<ResultCollection> Run(List<int> rgImageIdx)
2419 {
2420 List<SimpleDatum> rgSd = new List<SimpleDatum>();
2421
2422 if (m_dataSet == null)
2423 throw new Exception("Running on indexes requires a full Load that includes loading the dataset.");
2424
2425 foreach (int nImageIdx in rgImageIdx)
2426 {
2428 m_dataTransformer.TransformLabel(sd);
2429 rgSd.Add(sd);
2430 }
2431
2432 Blob<T> blob = null;
2433 List<ResultCollection> rgRes = Run(rgSd, ref blob);
2434
2435 if (blob != null)
2436 blob.Dispose();
2437
2438 return rgRes;
2439 }
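// Usage sketch (illustrative only): running the first three images of the loaded dataset by
// index. Requires a full Load() that includes the in-memory database; 'mycaffe' is an assumed,
// already loaded MyCaffeControl<float>.
//
//   List<ResultCollection> rgRes = mycaffe.Run(new List<int>() { 0, 1, 2 });
//   foreach (ResultCollection res in rgRes)
//       Console.WriteLine("detected label = " + res.DetectedLabel.ToString());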
2440
2441 private int getCount(List<int> rg)
2442 {
2443 int nCount = 1;
2444
2445 foreach (int nDim in rg)
2446 {
2447 nCount *= nDim;
2448 }
2449
2450 return nCount;
2451 }
2452
2460 public Blob<T> CreateDataBlob(SimpleDatum d, Blob<T> blob = null, bool bPad = true)
2461 {
2462 if (m_dataTransformer == null)
2463 {
2464 if (blob != null)
2465 blob.SetData(d, true);
2466 else
2467 blob = new Blob<T>(m_cuda, m_log, d, true);
2468 }
2469 else
2470 {
2471 if (blob == null)
2472 blob = new Blob<T>(m_cuda, m_log);
2473
2474 Datum datum = new Datum(d);
2475
2476 List<int> rgShape = m_dataTransformer.InferBlobShape(datum);
2477
2478 if (bPad)
2479 rgShape[0] = 2;
2480
2481 int nCount = getCount(rgShape);
2482 blob.Reshape(rgShape);
2483 blob.Padded = bPad;
2484
2485 if (m_rgRunData == null || m_rgRunData.Length != nCount)
2486 m_rgRunData = new T[nCount];
2487
2488 T[] rgData = m_dataTransformer.Transform(datum);
2489 Array.Copy(rgData, 0, m_rgRunData, 0, rgData.Length);
2490
2491 blob.mutable_cpu_data = m_rgRunData;
2492
2493 m_dataTransformer.SetRange(blob);
2494 }
2495
2496 return blob;
2497 }
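// Usage sketch (illustrative only): transforming a SimpleDatum into a padded data blob and
// releasing it when done. 'mycaffe' and 'sd' (a SimpleDatum supplied by the caller) are
// assumptions of this sketch.
//
//   Blob<float> blobData = mycaffe.CreateDataBlob(sd, null, true);
//   Console.WriteLine("input shape = " + blobData.shape_string);
//   blobData.Dispose();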
2498
2507 public ResultCollection Run(SimpleDatum d, bool bSort, bool bUseSolverNet, bool bPad = true)
2508 {
2509 if (m_net == null)
2510 throw new Exception("The Run net has not been created!");
2511
2512 ResultCollection result = null;
2513 Blob<T> blob = null;
2514
2515 try
2516 {
2517 blob = CreateDataBlob(d, null, bPad);
2518 BlobCollection<T> colBottom = new BlobCollection<T>() { blob };
2519 double dfLoss = 0;
2520
2521 BlobCollection<T> colResults;
2522 LayerParameter.LayerType lastLayerType;
2523
2524 if (bUseSolverNet)
2525 {
2526 lastLayerType = m_solver.TrainingNet.layers[m_solver.TrainingNet.layers.Count - 1].type;
2527 colResults = m_solver.TrainingNet.Forward(colBottom, out dfLoss, bPad);
2528 }
2529 else
2530 {
2531 lastLayerType = m_net.layers[m_net.layers.Count - 1].type;
2532 colResults = m_net.Forward(colBottom, out dfLoss, bPad);
2533 }
2534
2535 if (blob.Padded)
2536 {
2537 List<int> rgShape = Utility.Clone<int>(colResults[0].shape());
2538 rgShape[0]--;
2539 if (rgShape[0] <= 0)
2540 rgShape[0] = 1;
2541 colResults[0].Reshape(rgShape);
2542 }
2543
2544 List<Result> rgResults = new List<Result>();
2545 float[] rgData = Utility.ConvertVecF<T>(colResults[0].update_cpu_data());
2546
2547 if (colResults[0].type == BLOB_TYPE.MULTIBBOX)
2548 {
2549 int nNum = rgData.Length / 7;
2550
2551 for (int n = 0; n < nNum; n++)
2552 {
2553 int i = (int)rgData[(n * 7)];
2554 int nLabel = (int)rgData[(n * 7) + 1];
2555 double dfScore = rgData[(n * 7) + 2];
2556 double[] rgExtra = new double[4];
2557 rgExtra[0] = rgData[(n * 7) + 3]; // xmin
2558 rgExtra[1] = rgData[(n * 7) + 4]; // ymin
2559 rgExtra[2] = rgData[(n * 7) + 5]; // xmax
2560 rgExtra[3] = rgData[(n * 7) + 6]; // ymax
2561
2562 rgResults.Add(new Result(nLabel, dfScore, rgExtra));
2563 }
2564 }
2565 else
2566 {
2567 for (int i = 0; i < rgData.Length; i++)
2568 {
2569 double dfProb = rgData[i];
2570 rgResults.Add(new Result(i, dfProb));
2571 }
2572 }
2573
2574 result = new ResultCollection(rgResults, lastLayerType);
2575 if (m_db != null && m_db.GetVersion() != DB_VERSION.TEMPORAL)
2577 }
2578 catch (Exception excpt)
2579 {
2580 throw;
2581 }
2582 finally
2583 {
2584 if (blob != null)
2585 blob.Dispose();
2586 }
2587
2588 return result;
2589 }
2590
2599 public List<ResultCollection> Run(List<SimpleDatum> rgSd, ref Blob<T> blob, bool bUseSolverNet = false, int nMax = int.MaxValue)
2600 {
2601 m_log.CHECK(m_dataTransformer != null, "The data transformer is not initialized!");
2602
2603 if (m_net == null)
2604 throw new Exception("The Run net has not been created!");
2605
2606 List<ResultCollection> rgFinalResults = new List<ResultCollection>();
2607 int nBatchSize = rgSd.Count;
2608 int nChannels = rgSd[0].Channels;
2609 int nHeight = rgSd[0].Height;
2610 int nWidth = rgSd[0].Width;
2611 List<T> rgDataInput = new List<T>();
2612
2613 if (blob == null)
2614 blob = new common.Blob<T>(m_cuda, m_log, nBatchSize, nChannels, nHeight, nWidth, false);
2615
2616 int nCount = 0;
2617 for (int i=0; i<rgSd.Count && i < nMax; i++)
2618 {
2619 rgDataInput.AddRange(m_dataTransformer.Transform(rgSd[i]));
2620 nCount++;
2621 }
2622
2623 blob.Reshape(nCount, nChannels, nHeight, nWidth);
2624 blob.mutable_cpu_data = rgDataInput.ToArray();
2625 m_dataTransformer.SetRange(blob);
2626
2627 BlobCollection<T> colBottom = new BlobCollection<T>() { blob };
2628 double dfLoss = 0;
2629
2630 BlobCollection<T> colResults;
2631 LayerParameter.LayerType lastLayerType;
2632
2633 if (bUseSolverNet)
2634 {
2635 lastLayerType = m_solver.TrainingNet.layers[m_solver.TrainingNet.layers.Count - 1].type;
2636 m_solver.TrainingNet.SetEnablePassthrough(true);
2637 colResults = m_solver.TrainingNet.Forward(colBottom, out dfLoss);
2638 m_solver.TrainingNet.SetEnablePassthrough(false);
2639 }
2640 else
2641 {
2642 lastLayerType = m_net.layers[m_net.layers.Count - 1].type;
2643 colResults = m_net.Forward(colBottom, out dfLoss, true);
2644 }
2645
2646 T[] rgDataOutput = colResults[0].update_cpu_data();
2647 int nOutputCount = rgDataOutput.Length / nCount;
2648
2649 for (int i = 0; i < rgSd.Count && i < nMax; i++)
2650 {
2651 List<Result> rgResults = new List<Result>();
2652
2653 for (int j = 0; j < nOutputCount; j++)
2654 {
2655 int nIdx = i * nOutputCount + j;
2656 double dfProb = (double)Convert.ChangeType(rgDataOutput[nIdx], typeof(double));
2657 rgResults.Add(new Result(j, dfProb));
2658 }
2659
2660 ResultCollection result = new ResultCollection(rgResults, lastLayerType);
2661
2662 if (m_db != null && m_dataSet != null && m_db.GetVersion() != DB_VERSION.TEMPORAL)
2664
2665 rgFinalResults.Add(result);
2666 }
2667
2668 return rgFinalResults;
2669 }
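// Usage sketch (illustrative only): running a small batch of SimpleDatum items while reusing
// a single work blob across calls. 'mycaffe' and 'rgSd' (a List<SimpleDatum> built by the
// caller) are assumptions of this sketch.
//
//   Blob<float> blobWork = null;
//   List<ResultCollection> rgRes = mycaffe.Run(rgSd, ref blobWork);
//   // ... make additional Run calls reusing 'blobWork' ...
//   if (blobWork != null)
//       blobWork.Dispose();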
2670
2680 public List<ResultCollection> Run(Blob<T> blob, bool bSort = true, bool bUseSolverNet = false, int nMax = int.MaxValue, List<int> rgIgnoreLabels = null)
2681 {
2682 m_log.CHECK(m_dataTransformer != null, "The data transformer is not initialized!");
2683
2684 if (m_net == null)
2685 throw new Exception("The Run net has not been created!");
2686
2687 if (m_dataSet == null && (m_loadToRunShape == null || m_loadToRunShape.dim.Count < 4))
2688 throw new Exception("Cannot determine the blob shape, you must either load with a database, or use LoadToRun before calling Run with a Blob. When using LoadToRun, the shape must have at least 4 dimensions.");
2689
2690 List<ResultCollection> rgFinalResults = new List<ResultCollection>();
2691 int nBatchSize = blob.num;
2692 int nChannels = (m_dataSet != null) ? m_dataSet.TestingSource.Channels : m_loadToRunShape.dim[1];
2693 if (blob.channels != nChannels)
2694 throw new Exception("The blob channels must match those of the testing dataset which has channels = " + nChannels.ToString());
2695
2696 int nHeight = (m_dataSet != null) ? m_dataSet.TestingSource.Height : m_loadToRunShape.dim[2];
2697 int nWidth = (m_dataSet != null) ? m_dataSet.TestingSource.Width : m_loadToRunShape.dim[3];
2698
2699 if (m_dataTransformer.param.resize_param != null && m_dataTransformer.param.resize_param.Active)
2700 {
2701 List<int> rgShape = m_dataTransformer.InferBlobShape(nChannels, nWidth, nHeight);
2702 nHeight = rgShape[2];
2703 nWidth = rgShape[3];
2704 }
2705
2706 if (blob.height != nHeight)
2707 throw new Exception("The blob height must match those of the testing dataset which has height = " + nHeight.ToString());
2708
2709 if (blob.width != nWidth)
2710 throw new Exception("The blob width must match those of the testing dataset which has width = " + nWidth.ToString());
2711
2712 m_dataTransformer.SetRange(blob);
2713
2714 BlobCollection<T> colBottom = new BlobCollection<T>() { blob };
2715 double dfLoss = 0;
2716
2717 BlobCollection<T> colResults;
2718 LayerParameter.LayerType lastLayerType;
2719
2720 if (bUseSolverNet)
2721 {
2722 lastLayerType = m_solver.TrainingNet.layers[m_solver.TrainingNet.layers.Count - 1].type;
2723 m_solver.TrainingNet.SetEnablePassthrough(true);
2724 colResults = m_solver.TrainingNet.Forward(colBottom, out dfLoss);
2725 m_solver.TrainingNet.SetEnablePassthrough(false);
2726 }
2727 else
2728 {
2729 lastLayerType = m_net.layers[m_net.layers.Count - 1].type;
2730 colResults = m_net.Forward(colBottom, out dfLoss, true);
2731 }
2732
2733 float[] rgData = Utility.ConvertVecF<T>(colResults[0].update_cpu_data());
2734 int nOutputCount = rgData.Length / blob.num;
2735
2736 int nNum = blob.num;
2737 if (blob.Padded)
2738 nNum--;
2739
2740 ResultCollection.RESULT_TYPE resType = ResultCollection.GetResultType(lastLayerType);
2741
2742 for (int n = 0; n < nNum && n < nMax; n++)
2743 {
2744 List<Result> rgResults = new List<Result>();
2745
2746 if (colResults[0].type == BLOB_TYPE.MULTIBBOX)
2747 {
2748 int i = (int)rgData[(n * 7)];
2749 int nLabel = (int)rgData[(n * 7) + 1];
2750 double dfScore = rgData[(n * 7) + 2];
2751 double[] rgExtra = new double[4];
2752 rgExtra[0] = rgData[(n * 7) + 3]; // xmin
2753 rgExtra[1] = rgData[(n * 7) + 4]; // ymin
2754 rgExtra[2] = rgData[(n * 7) + 5]; // xmax
2755 rgExtra[3] = rgData[(n * 7) + 6]; // ymax
2756
2757 rgResults.Add(new Result(nLabel, dfScore, rgExtra));
2758 }
2759 else
2760 {
2761 for (int j = 0; j < nOutputCount; j++)
2762 {
2763 int nIdx = n * nOutputCount + j;
2764 double dfProb = rgData[nIdx];
2765
2766 if (rgIgnoreLabels != null && rgIgnoreLabels.Contains(j))
2767 {
2768 if (resType == ResultCollection.RESULT_TYPE.DISTANCES)
2769 dfProb = double.MaxValue;
2770 else
2771 dfProb = 0;
2772 }
2773
2774 rgResults.Add(new Result(j, dfProb));
2775 }
2776 }
2777
2778 ResultCollection result = new ResultCollection(rgResults, lastLayerType);
2779 if (m_db != null && m_db.GetVersion() != DB_VERSION.TEMPORAL)
2781
2782 rgFinalResults.Add(result);
2783 }
2784
2785 return rgFinalResults;
2786 }
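// Usage sketch (illustrative only): running a pre-filled batch blob whose shape matches the
// expected input (num, channels, height, width), ignoring label 0 in the results. 'mycaffe'
// and 'blobBatch' are assumptions of this sketch.
//
//   List<ResultCollection> rgRes = mycaffe.Run(blobBatch, true, false, int.MaxValue, new List<int>() { 0 });
//   foreach (ResultCollection res in rgRes)
//       Console.WriteLine("detected label = " + res.DetectedLabel.ToString());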
2787
2798 public ResultCollection Run(Bitmap img, bool bSort = true, bool bPad = true)
2799 {
2800 if (m_net == null)
2801 throw new Exception("The Run net has not been created!");
2802
2803 int nChannels = m_inputShape.dim[1];
2804
2805 if (typeof(T) == typeof(double))
2806 return Run(ImageData.GetImageDataD(img, nChannels, false, -1), bSort, bPad);
2807 else
2808 return Run(ImageData.GetImageDataF(img, nChannels, false, -1), bSort, bPad);
2809 }
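// Usage sketch (illustrative only): classifying a single bitmap. The file path is a
// placeholder and 'mycaffe' is an assumed, already loaded MyCaffeControl<float>.
//
//   using (Bitmap bmp = new Bitmap("c:\\temp\\test.png"))
//   {
//       ResultCollection res = mycaffe.Run(bmp);
//       Console.WriteLine("detected label = " + res.DetectedLabel.ToString());
//   }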
2810
2818 public ResultCollection Run(SimpleDatum d, bool bSort = true, bool bPad = true)
2819 {
2820 return Run(d, bSort, false, bPad);
2821 }
2822
2829 {
2830 Net<T> net = m_net;
2831 Phase phase = Phase.TRAIN;
2832
2833 if (customInput != null)
2834 {
2835 string strPhase = customInput.GetProperty("Phase", false);
2836 if (!string.IsNullOrEmpty(strPhase))
2837 {
2838 if (strPhase == Phase.TRAIN.ToString())
2839 phase = Phase.TRAIN;
2840 else if (strPhase == Phase.TEST.ToString())
2841 phase = Phase.TEST;
2842 else if (strPhase == Phase.RUN.ToString())
2843 phase = Phase.RUN;
2844 }
2845 }
2846
2847 m_log.WriteLine("INFO: Running TestMany with the " + phase.ToString() + " phase.");
2848 net = GetInternalNet(phase);
2849
2850 BlobCollection<T> colTop = net.Forward();
2851
2852 PropertySet res = new PropertySet();
2853
2854 foreach (Blob<T> blob in colTop)
2855 {
2856 string strName = blob.Name;
2857 float[] rgData = Utility.ConvertVecF<T>(blob.mutable_cpu_data);
2858 byte[] rgBytes = blob.ToByteArray();
2859
2860 res.SetPropertyBlob(strName, rgBytes);
2861 res.SetPropertyInt(strName, (int)blob.type);
2862 }
2863
2864 return res;
2865 }
2866
2873 {
2874 Net<T> net = m_net;
2875 Phase phase = Phase.TRAIN;
2876
2877 if (customInput != null)
2878 {
2879 string strPhase = customInput.GetProperty("Phase", false);
2880 if (!string.IsNullOrEmpty(strPhase))
2881 {
2882 if (strPhase == Phase.TRAIN.ToString())
2883 phase = Phase.TRAIN;
2884 else if (strPhase == Phase.TEST.ToString())
2885 phase = Phase.TEST;
2886 else if (strPhase == Phase.RUN.ToString())
2887 phase = Phase.RUN;
2888 }
2889 }
2890
2891 m_log.WriteLine("INFO: Running TestMany with the " + phase.ToString() + " phase.");
2892 net = GetInternalNet(phase);
2893
2894 return net.Forward();
2895 }
2896
2906 public PropertySet Run(PropertySet customInput, int nK = 1, double dfThreshold = 0.01, int nMax = 80, bool bBeamSearch = false)
2907 {
2908 m_log.CHECK_GE(nK, 1, "The K must be >= 1!");
2909
2910 BlobCollection<T> colBottom = null;
2911 Layer<T> layerInput = null;
2912 string strInput = customInput.GetProperty("InputData");
2913 string[] rgstrInput = strInput.Split('|');
2914 List<string> rgstrOutput = new List<string>();
2915 int nSeqLen = nMax;
2916
2917 foreach (string strInput1 in rgstrInput)
2918 {
2919 PropertySet input = new PropertySet("InputData=" + strInput1);
2920 string strOut = "\n";
2921
2922 if (!bBeamSearch)
2923 {
2924 foreach (Layer<T> layer in m_net.layers)
2925 {
2926 int nSeqLen1;
2927 colBottom = layer.PreProcessInput(input, out nSeqLen1, colBottom);
2928 if (colBottom != null)
2929 {
2930 layerInput = layer;
2931 nSeqLen = nSeqLen1;
2932 break;
2933 }
2934 }
2935
2936 if (colBottom == null)
2937 throw new Exception("At least one layer must support the 'PreprocessInput' method!");
2938
2939 double dfLoss;
2940 int nAxis = 1;
2941 BlobCollection<T> colTop = m_net.Forward(colBottom, out dfLoss, layerInput.SupportsPostProcessingLogits);
2942 Blob<T> blobTop = colTop[0];
2943 Layer<T> softmax = null;
2944
2945 if (m_net.layers[m_net.layers.Count - 1].layer_param.type == LayerParameter.LayerType.SOFTMAX)
2946 {
2947 softmax = m_net.layers[m_net.layers.Count - 1];
2948 nAxis = softmax.layer_param.softmax_param.axis;
2949 }
2950 else if (m_net.layers[m_net.layers.Count - 2].layer_param.type == LayerParameter.LayerType.SOFTMAX)
2951 {
2952 softmax = m_net.layers[m_net.layers.Count - 2];
2953 nAxis = softmax.layer_param.softmax_param.axis;
2954 }
2955
2956 List<string> rgOutput = new List<string>();
2957 List<Tuple<string, int, double>> res;
2958 int nCount = 0;
2959 Stopwatch sw = new Stopwatch();
2960 sw.Start();
2961
2962 for (int i = 0; i < nMax; i++)
2963 {
2964 if (layerInput.SupportsPostProcessingLogits)
2965 {
2966 nK = 10;
2967 blobTop = m_net.FindBlob("logits");
2968 if (blobTop == null)
2969 throw new Exception("Could not find the 'logits' blob!");
2970 res = layerInput.PostProcessLogitsOutput(i, blobTop, softmax, nAxis, nK);
2971 }
2972 else
2973 res = layerInput.PostProcessOutput(blobTop);
2974
2975 if (!layerInput.PreProcessInput(null, res[0].Item2, colBottom))
2976 break;
2977
2978 rgOutput.Add(res[0].Item1);
2979
2980 colTop = m_net.Forward(colBottom, out dfLoss, layerInput.SupportsPostProcessingLogits);
2981 blobTop = colTop[0];
2982 nCount++;
2983
2984 if (sw.Elapsed.TotalMilliseconds > 1000)
2985 {
2986 double dfPct = (double)nCount / nMax;
2987 m_log.WriteLine("Generating response at " + dfPct.ToString("P") + "...");
2988 }
2989 }
2990
2991 strOut = "";
2992 foreach (string str in rgOutput)
2993 {
2994 strOut += str;
2995 }
2996 }
2997 else
2998 {
2999 BeamSearch<T> search = new BeamSearch<T>(m_net);
3000
3001 List<Tuple<double, bool, List<Tuple<string, int, double>>>> res = search.Search(input, nK, dfThreshold, nMax);
3002
3003 if (res.Count > 0)
3004 {
3005 for (int i = 0; i < res[0].Item3.Count; i++)
3006 {
3007 strOut += res[0].Item3[i].Item1.ToString() + " ";
3008 }
3009 }
3010
3011 strOut = strOut.Trim();
3012 }
3013
3014 rgstrOutput.Add(strOut);
3015 }
3016
3017 string strFinal = "";
3018 foreach (string str in rgstrOutput)
3019 {
3020 strFinal += str + "|";
3021 }
3022
3023 strFinal = clean(strFinal);
3024 return new PropertySet("Results=" + strFinal);
3025 }
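// Usage sketch (illustrative only): running a sequence model on custom text input, with and
// without beam search. 'mycaffe' and the input string are assumptions of this sketch; the
// generated text is returned in the 'Results' property as shown above.
//
//   PropertySet input = new PropertySet("InputData=hello world");
//   PropertySet output = mycaffe.Run(input, 1, 0.01, 80, false);       // step-by-step (top-1) decoding
//   PropertySet outputBeam = mycaffe.Run(input, 3, 0.01, 80, true);    // beam search with K = 3
//   Console.WriteLine(output.GetProperty("Results"));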
3026
3032 public BlobCollection<T> Run(BlobCollection<T> colBottom)
3033 {
3034 double dfLoss;
3035 return m_net.Forward(colBottom, out dfLoss, true);
3036 }
3037
3038 private string clean(string strFinal)
3039 {
3040 string str = "";
3041
3042 foreach (char ch in strFinal)
3043 {
3044 if (ch == ';')
3045 str += ' ';
3046 else
3047 str += ch;
3048 }
3049
3050 return str;
3051 }
3052
3060 public Bitmap GetTestImage(Phase phase, out int nLabel, out string strLabel)
3061 {
3062 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3063 throw new Exception("The GetTestImage only works with non-temporal databases.");
3064
3065 int nSrcId = (phase == Phase.TRAIN) ? m_dataSet.TrainingSource.ID : m_dataSet.TestingSource.ID;
3067 m_dataTransformer.TransformLabel(sd);
3068
3069 nLabel = sd.Label;
3070 strLabel = ((IXImageDatabaseBase)m_db).GetLabelName(nSrcId, nLabel);
3071
3072 if (strLabel == null || strLabel.Length == 0)
3073 strLabel = nLabel.ToString();
3074
3075 return new Bitmap(ImageData.GetImage(new Datum(sd), null));
3076 }
3077
3084 public Bitmap GetTestImage(Phase phase, int nLabel)
3085 {
3086 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3087 throw new Exception("The GetTestImage only works with non-temporal databases.");
3088
3089 int nSrcId = (phase == Phase.TRAIN) ? m_dataSet.TrainingSource.ID : m_dataSet.TestingSource.ID;
3091 m_dataTransformer.TransformLabel(sd);
3092
3093 return new Bitmap(ImageData.GetImage(new Datum(sd), null));
3094 }
3095
3106 public Bitmap GetTargetImage(int nSrcId, int nIdx, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
3107 {
3108 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3109 throw new Exception("The GetTargetImage only works with non-temporal databases.");
3110
3112 m_dataTransformer.TransformLabel(sd);
3113
3114 nLabel = sd.Label;
3115 strLabel = ((IXImageDatabaseBase)m_db).GetLabelName(nSrcId, nLabel);
3116
3117 if (strLabel == null || strLabel.Length == 0)
3118 strLabel = nLabel.ToString();
3119
3120 rgCriteria = sd.DataCriteria;
3121 fmtCriteria = sd.DataCriteriaFormat;
3122
3123 return new Bitmap(ImageData.GetImage(new Datum(sd), null));
3124 }
3125
3135 public Bitmap GetTargetImage(int nImageID, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
3136 {
3137 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3138 throw new Exception("The GetTargetImage only works with non-temporal databases.");
3139
3141
3142 nLabel = d.Label;
3143 strLabel = ((IXImageDatabaseBase)m_db).GetLabelName(m_dataSet.TestingSource.ID, nLabel);
3144
3145 if (strLabel == null || strLabel.Length == 0)
3146 strLabel = nLabel.ToString();
3147
3148 rgCriteria = d.DataCriteria;
3149 fmtCriteria = d.DataCriteriaFormat;
3150
3151 return new Bitmap(ImageData.GetImage(new Datum(d), null));
3152 }
3153
3158 public SimpleDatum GetItemMean()
3159 {
3160 if (m_db == null)
3161 throw new Exception("The image database is null!");
3162
3163 if (m_solver == null)
3164 throw new Exception("The solver is null - make sure that you are loaded for training.");
3165
3166 if (m_solver.net == null)
3167 throw new Exception("The solver net is null - make sure that you are loaded for training.");
3168
3169 string strSrc = m_solver.net.GetDataSource();
3170 int nSrcId = m_db.GetSourceID(strSrc);
3171
3172 return m_db.GetItemMean(nSrcId);
3173 }
3174
3179 public DatasetDescriptor GetDataset()
3180 {
3181 return m_dataSet;
3182 }
3183
3188 public byte[] GetWeights()
3189 {
3190 if (m_net != null)
3191 {
3192 m_net.ShareTrainedLayersWith(m_solver.net);
3193 return m_net.SaveWeights(m_persist);
3194 }
3195 else
3196 {
3197 return m_solver.net.SaveWeights(m_persist);
3198 }
3199 }
3200
3206 public void UpdateRunWeights(bool bOutputStatus = false, bool bVerifyWeights = true)
3207 {
3208 bool? bLogEnabled = null;
3209
3210 try
3211 {
3212 if (!bOutputStatus)
3213 {
3214 bLogEnabled = m_log.IsEnabled;
3215 m_log.Enable = false;
3216 }
3217
3218 if (m_net != null && m_bOwnRunNet)
3219 {
3220 try
3221 {
3222 int nCopyCount = 0;
3223
3224 if (m_solver.net.learnable_parameters.Count == m_net.learnable_parameters.Count)
3225 {
3226 for (int i = 0; i < m_solver.net.learnable_parameters.Count; i++)
3227 {
3228 Blob<T> b = m_solver.net.learnable_parameters[i];
3229 Blob<T> bRun = m_net.learnable_parameters[i];
3230
3231 m_log.CHECK(b.Name == bRun.Name, "The learnable parameter names do not match!");
3232 if (bRun.CopyFrom(b, false, true) != 0)
3233 nCopyCount++;
3234 }
3235 }
3236 else
3237 {
3238 for (int i = 0; i < m_net.learnable_parameters.Count; i++)
3239 {
3240 Blob<T> bRun = m_net.learnable_parameters[i];
3241 Blob<T> b = m_solver.net.FindBlob(bRun.Name);
3242
3243 if (b == null)
3244 m_log.FAIL("Could not find the run blob '" + bRun.Name + "' in the solver net!");
3245
3246 bRun.CopyFrom(b, false, true);
3247 }
3248 }
3249
3250 if (nCopyCount == 0)
3251 loadWeights(m_net, m_solver.net.SaveWeights(m_persist));
3252 }
3253 catch (Exception excpt)
3254 {
3255 m_log.WriteLine("WARNING: " + excpt.Message + ", attempting to load with legacy (slower method)...");
3256 loadWeights(m_net, m_solver.net.SaveWeights(m_persist));
3257 }
3258 }
3259
3260 if (bVerifyWeights)
3261 {
3262 if (!CompareWeights(m_net, m_solver.net))
3263 m_log.WriteLine("WARNING: The run weights differ from the training weights!");
3264 }
3265 }
3266 finally
3267 {
3268 if (bLogEnabled.HasValue)
3269 m_log.Enable = bLogEnabled.Value;
3270 }
3271 }
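// Usage sketch (illustrative only): training for a number of iterations and then copying the
// newly trained weights into the run net before running inference. 'mycaffe' and 'sd' are
// assumptions of this sketch.
//
//   mycaffe.Train(1000);
//   mycaffe.UpdateRunWeights();
//   ResultCollection res = mycaffe.Run(sd);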
3272
3277 public void UpdateWeights(byte[] rgWeights)
3278 {
3279 if (m_net != null)
3280 loadWeights(m_net, rgWeights);
3281
3282 m_log.WriteLine("Updating weights in solver.");
3283
3284 List<string> rgExpectedShapes = new List<string>();
3285
3286 foreach (Blob<T> b in m_solver.TrainingNet.learnable_parameters)
3287 {
3288 rgExpectedShapes.Add(b.shape_string);
3289 }
3290
3291 bool bLoadDiffs;
3292 m_persist.LoadWeights(rgWeights, rgExpectedShapes, m_solver.TrainingNet.learnable_parameters, false, out bLoadDiffs);
3293
3294 m_solver.WeightsUpdated = true;
3295 m_log.WriteLine("Solver weights updated.");
3296 }
3297
3304 public Net<T> CreateNet(byte[] rgWeights, CudaDnn<T> cudaOverride = null)
3305 {
3306 if (cudaOverride == null)
3307 cudaOverride = m_cuda;
3308
3309 NetParameter p = (m_net != null) ? m_net.ToProto(false) : m_solver.net.ToProto(false);
3310 Net<T> net = new Net<T>(cudaOverride, m_log, p, m_evtCancel, m_db);
3311 loadWeights(net, rgWeights);
3312 return net;
3313 }
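// Usage sketch (illustrative only): saving the current training weights and creating an
// independent Net copy loaded with them. 'mycaffe' is an assumed, already loaded
// MyCaffeControl<float>.
//
//   byte[] rgWeights = mycaffe.GetWeights();
//   Net<float> netCopy = mycaffe.CreateNet(rgWeights);
//   // ... use netCopy ...
//   netCopy.Dispose();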
3314
3328 public Net<T> GetInternalNet(Phase phase = Phase.RUN)
3329 {
3330 if (phase == Phase.ALL)
3331 phase = m_lastPhaseRun;
3332
3333 if (phase == Phase.NONE)
3334 phase = Phase.RUN;
3335
3336 if (phase == Phase.TEST)
3337 return (m_solver != null) ? m_solver.TestingNet : null;
3338
3339 else if (phase == Phase.TRAIN)
3340 return (m_solver != null) ? m_solver.TrainingNet : null;
3341
3342 return m_net;
3343 }
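// Usage sketch (illustrative only): inspecting a blob of the internal training net. 'mycaffe'
// is an assumed, already loaded MyCaffeControl<float>, and "loss" is a hypothetical blob name
// that depends on the model description.
//
//   Net<float> trainNet = mycaffe.GetInternalNet(Phase.TRAIN);
//   Blob<float> blobLoss = trainNet.FindBlob("loss");
//   if (blobLoss != null)
//       Console.WriteLine("loss blob shape = " + blobLoss.shape_string);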
3344
3349 public Solver<T> GetInternalSolver()
3350 {
3351 return m_solver;
3352 }
3353
3358 public void Snapshot(bool bUpdateDatabase = true)
3359 {
3360 m_solver.Snapshot(true, false, bUpdateDatabase);
3361 }
3362
3371 public static void ResetDevice(int nDeviceID)
3372 {
3373 }
3374
3380 public static string GetLicenseTextEx(string strOtherLicenses)
3381 {
3382 string str = Properties.Resources.LICENSE;
3383 int nYear = DateTime.Now.Year;
3384
3385 if (nYear > 2016)
3386 str = replaceMacro(str, "$$YEAR$$", "-" + nYear.ToString());
3387 else
3388 str = replaceMacro(str, "$$YEAR$$", "");
3389
3390 if (strOtherLicenses != null && strOtherLicenses.Length > 0)
3391 str = replaceMacro(str, "$$OTHERLICENSES$$", strOtherLicenses);
3392
3393 return fixupReturns(str);
3394 }
3395
3401 public string GetLicenseText(string strOtherLicenses)
3402 {
3403 return GetLicenseTextEx(strOtherLicenses);
3404 }
3405
3413 public bool VerifyCompute(string strExtra = null, int nDeviceID = -1, bool bThrowException = true)
3414 {
3415 if (m_cuda == null)
3416 throw new Exception("You must initialize the MyCaffeControl with an instance of CudaDnn<T>, or Load a new project.");
3417
3418 int nMinMajor;
3419 int nMinMinor;
3420 string strDll = m_cuda.GetRequiredCompute(out nMinMajor, out nMinMinor);
3421
3422 if (nDeviceID == -1)
3423 nDeviceID = m_cuda.GetDeviceID();
3424
3425 string strDevName = m_cuda.GetDeviceName(nDeviceID);
3426 string strCompute = parse(strDevName, "compute ", ")");
3427 string[] rgstr = (strCompute == null) ? null : strCompute.Split('.');
3428 string strMajor = (rgstr != null && rgstr.Length > 0) ? rgstr[0] : null;
3429 string strMinor = (rgstr != null && rgstr.Length > 1) ? rgstr[1] : null;
3430 if (strMajor == null || strMinor == null)
3431 throw new Exception("Could not find the current device's major and minor version information!");
3432
3433 int nMajor = int.Parse(strMajor);
3434 int nMinor = int.Parse(strMinor);
3435
3436 if (nMajor < nMinMajor || (nMajor == nMinMajor && nMinor < nMinMinor))
3437 {
3438 string strErr = "The device " + nDeviceID.ToString() + " - '" + strDevName + "' does not meet the minimum compute of '" + nMinMajor.ToString() + "." + nMinMinor.ToString() + "' required by the CudaDnnDll used ('" + strDll + "')!";
3439 if (!string.IsNullOrEmpty(strExtra))
3440 strErr += strExtra;
3441 if (bThrowException) throw new Exception(strErr); return false;
3442 }
3443
3444 return true;
3445 }
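// Usage sketch (illustrative only): verifying that the current GPU meets the compute
// capability required by the CudaDnn DLL before loading a project. 'mycaffe' is an assumed
// MyCaffeControl<float> created with bCreateCudaDnn = true.
//
//   try
//   {
//       mycaffe.VerifyCompute();
//   }
//   catch (Exception excpt)
//   {
//       Console.WriteLine("Compute check failed: " + excpt.Message);
//   }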
3446
3447 private string parse(string str, string strT1, string strT2)
3448 {
3449 int nPos = str.IndexOf(strT1);
3450 if (nPos < 0)
3451 return null;
3452
3453 str = str.Substring(nPos + strT1.Length);
3454 nPos = str.IndexOf(strT2);
3455 if (nPos < 0)
3456 return null;
3457
3458 return str.Substring(0, nPos).Trim();
3459 }
3460
3461 private static string replaceMacro(string str, string strMacro, string strReplacement)
3462 {
3463 int nPos = str.IndexOf(strMacro);
3464
3465 if (nPos < 0)
3466 return str;
3467
3468 string strA = str.Substring(0, nPos);
3469
3470 strA += strReplacement;
3471 strA += str.Substring(nPos + strMacro.Length);
3472
3473 return strA;
3474 }
3475
3476 private static string fixupReturns(string str)
3477 {
3478 string strOut = "";
3479
3480 foreach (char ch in str)
3481 {
3482 if (ch == '\n')
3483 strOut += "\r\n";
3484 else
3485 strOut += ch;
3486 }
3487
3488 return strOut;
3489 }
3490
3496 public Blob<T> CreateBlob(string strName)
3497 {
3498 Blob<T> b = new Blob<T>(m_cuda, m_log);
3499 b.Name = strName;
3500 return b;
3501 }
3502
3508 public long CreateExtension(string strExtensionDLLPath)
3509 {
3510 return m_cuda.CreateExtension(strExtensionDLLPath);
3511 }
3512
3517 public void FreeExtension(long hExtension)
3518 {
3519 m_cuda.FreeExtension(hExtension);
3520 }
3528 public T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
3529 {
3530 return m_cuda.RunExtension(hExtension, lfnIdx, rgParam);
3531 }
3539 public double[] RunExtensionD(long hExtension, long lfnIdx, double[] rgParam)
3540 {
3541 T[] rgP = (rgParam == null) ? null : Utility.ConvertVec<T>(rgParam);
3542 T[] rg = m_cuda.RunExtension(hExtension, lfnIdx, rgP);
3543
3544 if (rg == null)
3545 return null;
3546
3547 return Utility.ConvertVec<T>(rg);
3548 }
3556 public float[] RunExtensionF(long hExtension, long lfnIdx, float[] rgParam)
3557 {
3558 T[] rgP = (rgParam == null) ? null : Utility.ConvertVec<T>(rgParam);
3559 T[] rg = m_cuda.RunExtension(hExtension, lfnIdx, rgP);
3560
3561 if (rg == null)
3562 return null;
3563
3564 return Utility.ConvertVecF<T>(rg);
3565 }
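// Usage sketch (illustrative only): loading an extension DLL, running function index 1 on it
// with the float base type, and freeing it. The DLL path and function index are placeholders.
//
//   long hExt = mycaffe.CreateExtension("c:\\path\\to\\MyExtension.dll");
//   float[] rgOut = mycaffe.RunExtensionF(hExt, 1, new float[] { 1.0f, 2.0f, 3.0f });
//   mycaffe.FreeExtension(hExt);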
3566 }
3567}
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
DatasetDescriptor m_dataSet
The dataset descriptor of the dataset used in the image database.
string GetDeviceName(int nDeviceID)
Returns the device name of a given device ID.
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
static FileVersionInfo Version
Get the file version of the MyCaffe assembly running.
long CreateExtension(string strExtensionDLLPath)
Create and load a new extension DLL.
ResultCollection Run(SimpleDatum d, bool bSort=true, bool bPad=true)
Run on a given Datum.
void Snapshot(bool bUpdateDatabase=true)
The Snapshot function forces a snapshot to occur.
void RemoveCancelOverrideByName(string strEvtCancel)
Remove a cancel override.
bool Load(Phase phase, string strSolver, string strModel, byte[] rgWeights, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride=null, bool bResetFirst=false, IXDatabaseBase db=null, bool bUseDb=true, bool bCreateRunNet=true, string strStage=null, bool bEnableMemTrace=false)
Load a project and optionally the MyCaffeImageDatabase.
bool EnableTesting
Enable/disable testing. For example reinforcement learning does not use testing.
bool? EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed Nan (and Infinity) detection is pero...
void LoadToRun(string strModel, byte[] rgWeights, BlobShape shape, SimpleDatum sdMean=null, TransformationParameter transParam=null, bool bForceBackward=false, bool bConvertToRunNet=true)
The LoadToRun method loads the MyCaffeControl for running only (e.g. deployment).
static void ResetDevice(int nDeviceID)
Reset the device at the given device ID.
void dispose()
Releases all GPU and Host resources used by the CaffeControl.
List< int > ActiveGpus
Returns a list of Active GPU's used by the control.
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
AutoResetEvent m_evtForceSnapshot
An auto-reset event used to force a snapshot.
MyCaffeControl(SettingsCaffe settings, Log log, CancelEvent evtCancel, AutoResetEvent evtSnapshot=null, AutoResetEvent evtForceTest=null, ManualResetEvent evtPause=null, List< int > rgGpuId=null, string strCudaPath="", bool bCreateCudaDnn=false, ConnectInfo ci=null)
The MyCaffeControl constructor.
SettingsCaffe Settings
Returns the settings used to create the control.
PropertySet TestMany(PropertySet customInput)
Test on custom input data.
Bitmap GetTargetImage(int nSrcId, int nIdx, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
Retrives the image at a given index within the Testing data set.
bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
Re-initializes each of the specified layers by re-running the filler (if any) specified by the layer....
ProjectEx m_project
The active project (if any).
ResultCollection Run(Bitmap img, bool bSort=true, bool bPad=true)
Run on a given bitmap image.
void SetOnTestingStartOverride(EventHandler onTestingStart)
Sets the root solver's onTestingStart event function triggered on the start of each testing pass.
static NetParameter CreateNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE, bool bSkipLossLayer=false, bool bMaintainBatchSize=false)
Creates a net parameter for the RUN phase.
int GetDeviceCount()
Returns the total number of devices installed on this computer.
int CurrentIteration
Returns the current iteration.
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires each time a snap-shot is taken.
void SetOnTestOverride(EventHandler< TestArgs > onTest)
Sets the root solver's onTest event function.
ResultCollection Run(SimpleDatum d, bool bSort, bool bUseSolverNet, bool bPad=true)
Run on a given Datum.
void Train(int nIterationOverride=-1, int nTrainingTimeLimitInMinutes=0, TRAIN_STEP step=TRAIN_STEP.NONE, double dfLearningRateOverride=0, bool bReset=false)
Train the network a set number of iterations.
bool? EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
string LabelQueryEpochs
Returns a string describing the label query epochs observed during training.
void AddCancelOverride(CancelEvent evtCancel)
Adds a cancel override.
SimpleDatum GetItemMean()
Returns the item (e.g., image or temporal item) mean used by the solver network used during training.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
SettingsCaffe m_settings
The settings used to configure the control.
void UpdateRunWeights(bool bOutputStatus=false, bool bVerifyWeights=true)
Loads the weights from the training net into the Net used for running.
void CopyWeightsFrom(MyCaffeControl< T > src)
Copy the learnable parameter data from the source MyCaffeControl into this one.
Blob< T > CreateBlob(string strName)
Create an unsized blob and set its name.
void FreeExtension(long hExtension)
Free an existing extension and unload it.
byte[] GetWeights()
Retrieves the weights of the training network.
ManualResetEvent m_evtPause
An auto-reset event used to pause training.
NetParameter createNetParameterForRunning(ProjectEx p, out TransformationParameter transform_param)
Creates a net parameter for the RUN phase.
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
bool CompareWeights(Net< T > net1, Net< T > net2)
The CompareWeights method compares the weights held in two different Net objects.
string CurrentDevice
Returns the name of the current device used.
Solver< T > GetInternalSolver()
Get the internal solver.
bool VerifyCompute(string strExtra=null, int nDeviceID=-1, bool bThrowException=true)
VerifyCompute compares the current compute of the current device (or device specified) against the re...
IXDatabaseBase m_db
The image database.
void CopyGradientsFrom(MyCaffeControl< T > src)
Copy the learnable parameter diffs from the source MyCaffeControl into this one.
void UpdateWeights(byte[] rgWeights)
Loads the training Net with new weights.
List< ResultCollection > Run(Blob< T > blob, bool bSort=true, bool bUseSolverNet=false, int nMax=int.MaxValue, List< int > rgIgnoreLabels=null)
Run on a Blob of data.
string GetLicenseText(string strOtherLicenses)
Returns the license text for MyCaffe.
string ActiveLabelCounts
Returns a string describing the active label counts observed during training.
float[] RunExtensionF(long hExtension, long lfnIdx, float[] rgParam)
Run a function on an existing extension using the float base type.
BlobCollection< T > Run(BlobCollection< T > colBottom)
Run the network forward on the bottom blobs.
void PrepareImageMeans(ProjectEx prj)
Prepare the testing image mean by copying the training image mean if the testing image mean is missin...
Bitmap GetTestImage(Phase phase, int nLabel)
Retrieves a random image from either the training or test set depending on the Phase specified.
ResultCollection Run(int nImageIdx, bool bPad=true)
Run on a given image in the MyCaffeImageDatabase based on its image index.
Bitmap GetTargetImage(int nImageID, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
Retrives the image with a given ID.
double ApplyUpdate(int nIteration)
Directs the solver to apply the leanred blob diffs to the weights using the solver's learning rate an...
AutoResetEvent m_evtForceTest
An auto-reset event used to force a test cycle.
MyCaffeControl< T > Clone(int nGpuID)
Clone the current instance of the MyCaffeControl creating a second instance.
Net< T > CreateNet(byte[] rgWeights, CudaDnn< T > cudaOverride=null)
Creates a new Net, loads the weights specified into it and returns it.
IXPersist< T > Persist
Returns the persist used to load and save weights.
BlobCollection< T > TestManyEx(PropertySet customInput)
Test on custom input data.
bool Load(Phase phase, ProjectEx p, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride=null, bool bResetFirst=false, IXDatabaseBase db=null, bool bUseDb=true, bool bCreateRunNet=true, string strStage=null, bool bEnableMemTrace=false)
Load a project and optionally the MyCaffeImageDatabase.
DataTransformer< T > m_dataTransformer
The data transformer used to transform data.
CancelEvent m_evtCancel
The CancelEvent used to cancel training and testing operations.
void AddCancelOverrideByName(string strEvtCancel)
Adds a cancel override.
string LabelQueryHitPercents
Returns a string describing the label query hit percentages observed during training.
CudaDnn< T > Cuda
Returns the CudaDnn connection used.
ProjectEx CurrentProject
Returns the name of the currently loaded project.
void SetOnTrainingStartOverride(EventHandler onTrainingStart)
Sets the root solver's onStart event function triggered on the start of each training pass.
Log m_log
The log used for output.
NetParameter createNetParameterForRunning(DatasetDescriptor ds, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
Blob< T > CreateDataBlob(SimpleDatum d, Blob< T > blob=null, bool bPad=true)
Create a data blob from a SimpleDatum by transforming the data and placing the results in the blob re...
Bitmap GetTestImage(Phase phase, out int nLabel, out string strLabel)
Retrieves a random image from either the training or test set depending on the Phase specified.
void Unload(bool bUnloadImageDb=true, bool bIgnoreExceptions=false)
Unload the currently loaded project, if any.
static string GetLicenseTextEx(string strOtherLicenses)
Returns the license text for MyCaffe.
string CurrentStage
Returns the stage under which the project was loaded, if any.
NetParameter createNetParameterForRunning(SimpleDatum sdMean, string strModel, out TransformationParameter transform_param, out int nC, out int nH, out int nW, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
void RemoveCancelOverride(CancelEvent evtCancel)
Remove a cancel override.
bool EnableVerboseStatus
Get/set whether or not to use verbose status. When enabled, the full status is output when loading a ...
double[] RunExtensionD(long hExtension, long lfnIdx, double[] rgParam)
Run a function on an existing extension using the double base type.
BlobCollection< T > RunModelEx(PropertySet customInput)
Run the model using data from the model itself - requires a Data layer with the RUN phase.
bool m_bDbOwner
Whether or not the control owns the image database.
List< ResultCollection > Run(List< SimpleDatum > rgSd, ref Blob< T > blob, bool bUseSolverNet=false, int nMax=int.MaxValue)
Run on a given list of Datum.
PropertySet RunModel(PropertySet customInput)
Run the model using data from the model itself - requires a Data layer with the RUN phase.
List< Tuple< SimpleDatum, ResultCollection > > TestMany(int nCount, bool bOnTrainingSet, bool bOnTargetSet=false, DB_ITEM_SELECTION_METHOD imgSelMethod=DB_ITEM_SELECTION_METHOD.RANDOM, int nImageStartIdx=0, DateTime? dtImageStartTime=null, double? dfThreshold=null)
Test on a number of images by selecting random images from the database, running them through the Run...
double Test(int nIterationOverride=-1)
Test the network a given number of iterations.
bool? EnableBreakOnFirstNaN
Enable/disable break training after first detecting a NaN.
NetParameter createNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
string m_strCudaPath
The low-level path of the underlying CudaDnn DLL.
Phase LastPhase
Returns the last phase run (TRAIN, TEST or RUN).
bool? EnableBlobDebugging
Enable/disable blob debugging.
List< ResultCollection > Run(List< int > rgImageIdx, ref Blob< T > blob)
Run on a set of images in the MyCaffeImageDatabase based on their image indexes.
T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
Run a function on an existing extension.
bool? EnableSingleStep
Enable/disable single step training.
PropertySet Run(PropertySet customInput, int nK=1, double dfThreshold=0.01, int nMax=80, bool bBeamSearch=false)
Run the model on custom input data.
List< ResultCollection > Run(List< int > rgImageIdx)
Run on a set of images in the MyCaffeImageDatabase based on their image indexes.
bool LoadLite(Phase phase, string strSolver, string strModel, byte[] rgWeights=null, bool bResetFirst=false, bool bCreateRunNet=true, SimpleDatum sdMean=null, string strStage=null, bool bEnableMemTrace=false)
Load a solver and model without using the MyCaffeImageDatabase.
int MaximumIteration
Returns the maximum iteration.
List< int > m_rgGpu
A list of the Device ID's used for training.
DatasetDescriptor GetDataset()
Returns the current dataset used when training and testing.
void Add(AnnotationGroupCollection col)
Add another AnnotationGroupCollection to this one.
Definition: Annotation.cs:369
int Count
Specifies the number of items in the collection.
Definition: Annotation.cs:350
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
bool RemoveCancelOverride(string strName)
Remove a new cancel override.
Definition: CancelEvent.cs:167
void AddCancelOverride(string strName)
Add a new cancel override.
Definition: CancelEvent.cs:83
string Name
Return the name of the cancel event.
Definition: CancelEvent.cs:263
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
void Set()
Sets the event to the signaled state.
Definition: CancelEvent.cs:270
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
Definition: ConnectInfo.cs:14
The Datum class is a simple wrapper to the SimpleDatum class to ensure compatibility with the origina...
Definition: Datum.cs:12
The ImageData class is a helper class used to convert between Datum, other raw data,...
Definition: ImageData.cs:14
static Bitmap GetImage(SimpleDatum d, ColorMapper clrMap=null, List< int > rgClrOrder=null)
Converts a SimplDatum (or Datum) into an image, optionally using a ColorMapper.
Definition: ImageData.cs:506
static Datum GetImageDataF(Bitmap bmp, int nChannels, bool bDataIsReal, int nLabel, bool bUseLockBitmap=true, int[] rgFocusMap=null)
The GetImageDataF function converts a Bitmap into a Datum using the float type for real data.
Definition: ImageData.cs:181
static Datum GetImageDataD(Bitmap bmp, int nChannels, bool bDataIsReal, int nLabel, bool bUseLockBitmap=true, int[] rgFocusMap=null)
The GetImageDataD function converts a Bitmap into a Datum using the double type for real data.
Definition: ImageData.cs:44
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK(bool b, string str)
Test a flag for true.
Definition: Log.cs:227
bool IsEnabled
Returns whether or not the Log is enabled.
Definition: Log.cs:50
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
Definition: Log.cs:42
void FAIL(string str)
Causes a failure which throws an exception with the desciptive text.
Definition: Log.cs:394
double Progress
Get/set the progress associated with the Log.
Definition: Log.cs:147
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Definition: Log.cs:239
void WriteHeader(string str)
Write a header as output.
Definition: Log.cs:109
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
Definition: Log.cs:299
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
Definition: Log.cs:287
The ProjectEx class manages a project containing the solver description, model description,...
Definition: ProjectEx.cs:15
int TargetDatasetID
Get/set the dataset ID of the target dataset (if exists), otherwise return 0.
Definition: ProjectEx.cs:915
RawProto CreateModelForRunning(string strName, int nNum, int nChannels, int nHeight, int nWidth, out RawProto protoTransform, out bool bSkipTransformParam, Stage stage=Stage.NONE, bool bSkipLossLayer=false)
Create a model description as a RawProto for running the Project.
Definition: ProjectEx.cs:1041
int ID
Returns the ID of the Project in the database.
Definition: ProjectEx.cs:533
ParameterDescriptorCollection Parameters
Returns any project parameters that may exist (if any).
Definition: ProjectEx.cs:812
DatasetDescriptor Dataset
Return the descriptor of the dataset used.
Definition: ProjectEx.cs:896
string? ModelDescription
Get/set the model description script used by the Project.
Definition: ProjectEx.cs:757
byte[] WeightsState
Get/set the weight state.
Definition: ProjectEx.cs:873
DatasetDescriptor DatasetTarget
Returns the target dataset (if exists) or null if it does not.
Definition: ProjectEx.cs:907
byte[] SolverState
Get/set the solver state.
Definition: ProjectEx.cs:864
Stage Stage
Return the stage under which the project was opened.
Definition: ProjectEx.cs:603
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Definition: PropertySet.cs:267
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as an double value.
Definition: PropertySet.cs:307
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
Definition: PropertySet.cs:211
The RawProtoCollection class is a list of RawProto objects.
int Count
Returns the number of items in the collection.
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
string Value
Get/set the value of the node.
Definition: RawProto.cs:79
RawProto FindChild(string strName)
Searches for a given node.
Definition: RawProto.cs:231
override string ToString()
Returns the RawProto as its full prototxt string.
Definition: RawProto.cs:681
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
RawProtoCollection FindChildren(params string[] rgstrName)
Searches for all children with a given name in this node's children.
Definition: RawProto.cs:263
The Result class contains a single result.
Definition: Result.cs:14
int Label
Returns the label.
Definition: Result.cs:36
The SettingsCaffe defines the settings used by the MyCaffe CaffeControl.
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot update method.
SettingsCaffe Clone()
Returns a copy of the SettingsCaffe object.
DB_LOAD_METHOD DbLoadMethod
Get/set the image database loading method.
bool ItemDbLoadDebugData
Specifies whether or not to load the debug data from file (default = false).
string GpuIds
Get/set the default GPU ID's to use when training.
int MaximumIterationOverride
Get/set the maximum iteration override. When set, this overrides the training iterations specified in...
DB_VERSION DbVersion
Get/set the version of the MyCaffeImageDatabase to use.
bool ItemDbLoadDataCriteria
Specifies whether or not to load the image criteria data from file (default = false).
int TestingIterationOverride
Get/set the testing iteration override. When set, this overrides the testing iterations specified in ...
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
bool GetDataValid(bool bByType=true)
Returns true if the ByteData or RealDataD or RealDataF are not null, false otherwise.
int Channels
Return the number of channels of the data.
AnnotationGroupCollection annotation_group
When using annoations, each annotation group contains an annotation for a particular class used with ...
byte[] DataCriteria
Get/set data criteria associated with the data.
DATA_FORMAT
Defines the data format of the DebugData and DataCriteria when specified.
Definition: SimpleDatum.cs:223
int Width
Return the width of the data.
int ImageID
Returns the ID of the image in the database.
int Height
Return the height of the data.
DATA_FORMAT DataCriteriaFormat
Get/set the data format of the data criteria.
int Label
Return the known label of the data.
The Utility class provides general utility funtions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
int ID
Get/set the database ID of the item.
string Name
Get/set the name of the item.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
bool? IsGym
Returns whether or not this dataset is from a Gym.
SourceDescriptor TrainingSource
Get/set the training data source.
string? TrainingSourceName
Returns the training source name, or null if not specifies.
bool? IsModelData
Returns whether or not this dataset is from the model itself.
SourceDescriptor TestingSource
Get/set the testing data source.
string? TestingSourceName
Returns the testing source name or null if not specified.
ParameterDescriptor Find(string strName)
Searches for a parameter by name in the collection.
The ParameterDescriptor class describes a parameter in the database.
override string ToString()
Creates the string representation of the descriptor.
string Value
Get/set the value of the item.
The SourceDescriptor class contains all information describing a data source.
override string ToString()
Return a string representation of thet SourceDescriptor.
int Height
Returns the height of each data item in the data source.
int Width
Returns the width of each data item in the data source.
int Channels
Returns the item colors - 1 channel = black/white, 3 channels = RGB color.
The BeamSearch uses the softmax output from the network and continually runs the net on each output (...
Definition: BeamSearch.cs:19
List< Tuple< double, bool, List< Tuple< string, int, double > > > > Search(PropertySet input, int nK, double dfThreshold=0.01, int nMax=80)
Perform the beam-search.
Definition: BeamSearch.cs:56
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
int channels
DEPRECIATED; legacy shape accessor channels: use shape(1) instead.
Definition: Blob.cs:800
int height
DEPRECIATED; legacy shape accessor height: use shape(2) instead.
Definition: Blob.cs:808
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1487
string shape_string
Returns a string describing the Blob's shape.
Definition: Blob.cs:657
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
Definition: Blob.cs:1461
BLOB_TYPE type
Returns the BLOB_TYPE of the Blob.
Definition: Blob.cs:2761
byte[] ToByteArray()
Saves this Blob to a byte array.
Definition: Blob.cs:2436
void CopyFrom(Blob< T > src, int nSrcOffset, int nDstOffset, int nCount, bool bCopyData, bool bCopyDiff)
Copy from a source Blob.
Definition: Blob.cs:903
int width
DEPRECIATED; legacy shape accessor width: use shape(3) instead.
Definition: Blob.cs:816
T asum_data()
Compute the sum of absolute values (L1 norm) of the data.
Definition: Blob.cs:1706
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
Definition: Blob.cs:648
string Name
Get/set the name of the Blob.
Definition: Blob.cs:2184
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
Definition: Blob.cs:402
bool Padded
Get/set the padding state of the blob.
Definition: Blob.cs:284
int num
DEPRECIATED; legacy shape accessor num: use shape(0) instead.
Definition: Blob.cs:792
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1479
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:969
int GetDeviceID()
Returns the current device id set within Cuda.
Definition: CudaDnn.cs:2013
string GetRequiredCompute(out int nMinMajor, out int nMinMinor)
The GetRequiredCompute function returns the Major and Minor compute values required by the current Cu...
Definition: CudaDnn.cs:2216
void sub(int n, long hA, long hB, long hY, int nAOff=0, int nBOff=0, int nYOff=0, int nB=0)
Subtracts B from A and places the result in Y.
Definition: CudaDnn.cs:7312
string Path
Specifies the file path used to load the Low-Level Cuda DNN Dll file.
Definition: CudaDnn.cs:1924
void FreeExtension(long hExtension)
Free an instance of an Extension.
Definition: CudaDnn.cs:3474
void FreeHostBuffer(long hMem)
Free previously allocated host memory.
Definition: CudaDnn.cs:2602
int GetDeviceCount()
Query the number of devices (gpu's) installed.
Definition: CudaDnn.cs:2127
string GetDeviceName(int nDeviceID)
Query the name of a device.
Definition: CudaDnn.cs:2035
T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
Run a function on the extension specified.
Definition: CudaDnn.cs:3489
long CreateExtension(string strExtensionDllPath)
Create an instance of an Extension DLL.
Definition: CudaDnn.cs:3456
virtual void Dispose(bool bDisposing)
Disposes this instance freeing up all of its host and GPU memory.
Definition: CudaDnn.cs:1612
The NCCL class manages the multi-GPU operations using the low-level NCCL functionality provided by th...
Definition: Parallel.cs:267
void Run(List< int > rgGpus, int nIterationOverride=-1)
Run the root Solver and coordinate with all other Solver's participating in the multi-GPU training.
Definition: Parallel.cs:351
Connects Layer's together into a direct acrylic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
List< Layer< T > > layers
Returns the layers.
Definition: Net.cs:2003
BlobCollection< T > parameters
Returns the parameters.
Definition: Net.cs:2085
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1445
BlobCollection< T > input_blobs
Returns the collection of input Blobs.
Definition: Net.cs:2201
virtual void Dispose(bool bDisposing)
Releases all resources (GPU and Host) used by the Net.
Definition: Net.cs:184
void LoadWeights(byte[] rgWeights, IXPersist< T > persist, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into the Net.
Definition: Net.cs:2510
NetParameter ToProto(bool bIncludeBlobs)
Writes the net to a proto.
Definition: Net.cs:1865
Blob< T > FindBlob(string strName)
Finds a Blob in the Net by name.
Definition: Net.cs:2592
byte[] SaveWeights(IXPersist< T > persist, bool bSaveDiff=false)
Save the weights to a byte array.
Definition: Net.cs:2541
BlobCollection< T > learnable_parameters
Returns the learnable parameters.
Definition: Net.cs:2117
void ShareTrainedLayersWith(Net< T > srcNet, bool bEnableLog=false)
For an already initialized net, implicitly compies (i.e., using no additional memory) the pre-trained...
Definition: Net.cs:1653
bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
Re-initializes the blobs and each of the specified layers by re-running the filler (if any) specified...
Definition: Net.cs:2729
The PersistCaffe class is used to load and save weight files in the .caffemodel format.
Definition: PersistCaffe.cs:20
PersistCaffe(Log log, bool bFailOnFirstTry)
The PersistCaffe constructor.
Definition: PersistCaffe.cs:30
BlobCollection< T > LoadWeights(byte[] rgWeights, List< string > rgExpectedShapes, BlobCollection< T > colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into a BlobCollection
The ResultCollection contains the result of a given CaffeControl::Run.
RESULT_TYPE ResultType
Returns the result type of the result data: PROBABILITIES (Sigmoid), DISTANCES (Decode),...
List< Result > ResultsSorted
Returns the original results in sorted order.
double DetectedLabelOutput
Returns the detected label output depending on the result type (distance or probability) with a defau...
static RESULT_TYPE GetResultType(LayerParameter.LayerType type)
Get the result type based on the layer-type used.
void SetLabels(List< LabelDescriptor > rgLabels)
Sets the label names in the label dictionary lookup.
int DetectedLabel
Returns the detected label depending on the result type (distance or probability) with a default type...
RESULT_TYPE
Defines the type of result.
List< Result > ResultsOriginal
Returns the original results.
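As a small sketch of inspecting a ResultCollection (assumes a ResultCollection named res produced elsewhere, e.g., by a run of the control):
// Sketch: report the detected label and walk the sorted results.
Console.WriteLine("Detected label: " + res.DetectedLabel.ToString());
Console.WriteLine("Detected output: " + res.DetectedLabelOutput.ToString());
foreach (Result r in res.ResultsSorted)
{
    Console.WriteLine(r.ToString());
}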
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
Definition: EventArgs.cs:416
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
Definition: EventArgs.cs:216
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
Definition: EventArgs.cs:264
Applies common transformations to the input data, such as scaling, mirroring, subtracting the image m...
The Database class manages the actual connection to the physical database using Entity Framework from...
Definition: Database.cs:23
The DatasetFactory manages the connection to the Database object.
SimpleDatum QueryImageMean(int nSrcId=0)
Return the SimpleDatum for the image mean from the open data source.
int GetRawImageMeanID(int nSrcId=0)
Returns the raw image ID for the image mean associated with a data source.
DatasetDescriptor LoadDataset(string strDataset, ConnectInfo ci=null)
Load a dataset descriptor from a dataset name.
bool CopyImageMean(string strSrcSrc, string strDstSrc)
Copy the raw image mean from one source to another.
SourceDescriptor LoadSource(string strSource)
Load the source descriptor from a data source name.
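A hedged sketch of using the DatasetFactory entries above (the parameterless constructor, the 'MNIST' dataset name, and the TrainingSource.ID property are assumptions made for illustration):
// Sketch: load a dataset descriptor by name and query its training image mean.
DatasetFactory factory = new DatasetFactory();                      // parameterless ctor assumed
DatasetDescriptor ds = factory.LoadDataset("MNIST");                // dataset name is illustrative
SimpleDatum sdMean = factory.QueryImageMean(ds.TrainingSource.ID);  // TrainingSource.ID assumed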
[V2 Image Database] The MyCaffeImageDatabase2 provides an enhanced in-memory image database used for ...
The MyCaffeImageDatabase provides an enhanced in-memory image database used for quick image retrieval...
static Tuple< DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD > GetSelectionMethod(SettingsCaffe s)
Returns the label/image selection methods based on the SettingsCaffe settings.
An interface for the units of computation which can be composed into a Net.
Definition: Layer.cs:31
LayerParameter.LayerType type
Returns the LayerType of this Layer.
Definition: Layer.cs:927
LayerParameter layer_param
Returns the LayerParameter for this Layer.
Definition: Layer.cs:899
virtual BlobCollection< T > PreProcessInput(PropertySet customInput, out int nSeqLen, BlobCollection< T > colBottom=null)
The PreProcessInput method allows derivative data layers to convert a property set of input data into the bo...
Definition: Layer.cs:294
Specifies the parameters for the AccuracyLayer.
Specifies the shape of a Blob.
Definition: BlobShape.cs:15
List< int > dim
The blob shape dimensions.
Definition: BlobShape.cs:93
string source
When used with the DATA parameter, specifies the data 'source' within the database....
DEPRECATED (use DataLayer DataLabelMappingParameter instead) Specifies the parameters for the Lab...
Specifies the base parameter for all layers.
void PrepareRunModel()
Prepare the layer settings for a run model.
LayerType type
Specifies the type of this LayerParameter.
SoftmaxParameter softmax_param
Returns the parameter set when initialized with LayerType.SOFTMAX
List< NetStateRule > include
Specifies the NetStateRule's for which this LayerParameter should be included.
AccuracyParameter accuracy_param
Returns the parameter set when initialized with LayerType.ACCURACY
string PrepareRunModelInputs()
Prepare model inputs for the run-net (if any are needed for the layer).
TransformationParameter transform_param
Returns the parameter set when initialized with LayerType.TRANSFORM
DataParameter data_param
Returns the parameter set when initialized with LayerType.DATA
LayerType
Specifies the layer type.
LabelMappingParameter labelmapping_param
Returns the parameter set when initialized with LayerType.LABELMAPPING
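A minimal sketch of building a LayerParameter from the members listed above (assumes the LayerParameter(LayerType) constructor initializes softmax_param, and that NetStateRule has a parameterless constructor with a settable phase):
// Sketch: create a SOFTMAX layer parameter and restrict it to the TEST phase.
LayerParameter p = new LayerParameter(LayerParameter.LayerType.SOFTMAX);
p.softmax_param.axis = 1;            // soft-max over the channel axis
NetStateRule rule = new NetStateRule();
rule.phase = Phase.TEST;
p.include.Add(rule);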
Specifies the parameters used to create a Net
Definition: NetParameter.cs:18
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
bool force_backward
Whether the network will force every layer to carry out backward operation. If set False,...
override RawProto ToProto(string strName)
Converts the parameter into a RawProto.
List< LayerParameter > layer
The layers that make up the net. Each of their configurations, including connectivity and behavior,...
static Dictionary< string, BlobShape > InputFromProto(RawProto rp)
Collect the inputs from the RawProto.
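As a short sketch of parsing a model description with the static helpers above (assumes RawProto.Parse is available for prototxt-style text and that strModelDesc holds the model description):
// Sketch: parse a model description and collect its declared inputs.
RawProto proto = RawProto.Parse(strModelDesc);
NetParameter netParam = NetParameter.FromProto(proto);
Dictionary<string, BlobShape> rgInputs = NetParameter.InputFromProto(proto);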
Phase phase
Specifies the Phase of the NetState.
Definition: NetState.cs:63
List< string > stage
Specifies the stages of the NetState.
Definition: NetState.cs:83
Specifies a NetStateRule used to determine whether a Net falls within a given include or exclude patt...
Definition: NetStateRule.cs:20
Phase phase
Set phase to require the NetState to have a particular phase (TRAIN or TEST) to meet this rule.
Definition: NetStateRule.cs:99
int axis
The axis along which to perform the softmax – may be negative to index from the end (e....
The SolverParameter is a parameter for the solver, specifying the train and test networks.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
Stores parameters used to apply transformation to the data layer's data.
bool use_imagedb_mean
Specifies whether to subtract the mean image from the image database, subtract the mean values,...
static TransformationParameter FromProto(RawProto rp)
Parses the parameter from a RawProto.
Specifies the parameters for the DecodeLayer and the AccuracyEncodingLayer.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
Definition: Solver.cs:218
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
Definition: Solver.cs:134
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1889
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
Definition: Solver.cs:1097
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
Definition: Solver.cs:1115
int MaximumIteration
Returns the maximum training iterations.
Definition: Solver.cs:700
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
Definition: Solver.cs:130
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is sent extra debugging information describing the state o...
Definition: Solver.cs:361
Net< T > net
Returns the main training Net.
Definition: Solver.cs:1229
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
Definition: Solver.cs:236
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
Definition: Solver.cs:227
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
Definition: Solver.cs:352
Net< T > TrainingNet
Returns the training Net used by the solver.
Definition: Solver.cs:445
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
Definition: Solver.cs:138
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stops training upon detecti...
Definition: Solver.cs:382
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot weight update method.
Definition: Solver.cs:301
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
Definition: Solver.cs:292
void Reset()
Reset the iterations of the net.
Definition: Solver.cs:478
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
Definition: Solver.cs:1322
string LabelQueryEpochs
Return the label query epochs for the active datasource.
Definition: Solver.cs:684
int TrainingIterationOverride
Get/set the training iteration override.
Definition: Solver.cs:1179
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Definition: Solver.cs:150
bool WeightsUpdated
Get/set when the weights have been updated.
Definition: Solver.cs:413
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed NaN (and Infinity) detection is pero...
Definition: Solver.cs:395
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
Definition: Solver.cs:146
int TestingIterationOverride
Get/set the testing iteration override.
Definition: Solver.cs:1188
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry point of the solver. By default, iter will be zero. Pass in a non-zero iter number ...
Definition: Solver.cs:744
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
Definition: Solver.cs:118
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
Definition: Solver.cs:668
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
Definition: Solver.cs:404
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Definition: Solver.cs:676
Net< T > TestingNet
Returns the testing Net used by the solver.
Definition: Solver.cs:431
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
Definition: Solver.cs:373
int CurrentIteration
Returns the current training iteration.
Definition: Solver.cs:692
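A hedged sketch of wiring the Solver members above together (all objects passed to Create, namely cuda, log, project, evtCancel, evtForceSnapshot, evtForceTest, db and persist, are assumed to have been constructed elsewhere):
// Sketch: create a solver for a project, hook its events, train, and clean up.
SGDSolver<float> solver = SGDSolver<float>.Create(cuda, log, project, evtCancel,
    evtForceSnapshot, evtForceTest, db, persist);
solver.OnTrainingIteration += (sender, args) => { /* e.g., report progress */ };
solver.OnSnapshot += (sender, args) => { /* e.g., persist the snapshot */ };
solver.Solve();
solver.Dispose();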
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
Definition: Component.cs:18
The IXDatabaseBase interface defines the general interface to the in-memory database.
Definition: Interfaces.cs:444
SimpleDatum GetItem(int nItemID, params int[] rgSrcId)
Get the item (e.g., image or temporal item) with a given Raw Item ID.
SimpleDatum GetItemMean(int nSrcId)
Returns the item (e.g., image or temporal item) mean for a data source.
int GetSourceID(string strSrc)
Returns a data source ID given its name.
SimpleDatum QueryItem(int nSrcId, int nIdx, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? imageSelectionOverride=null, int? nLabel=null, bool bLoadDataCriteria=false, bool bLoadDebugData=false)
Query an image in a given data source.
Tuple< DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD > GetSelectionMethod()
Returns the label and image selection method used.
SimpleDatum QueryItemMean(int nSrcId)
Queries the item (e.g., image or temporal item) mean for a data source from the database on disk.
void SetSelectionMethod(DB_LABEL_SELECTION_METHOD? lbl, DB_ITEM_SELECTION_METHOD? img)
Sets the label and image selection methods.
void CleanUp(int nDsId=0, bool bForce=false)
Releases the image database, and if this is the last instance using the in-memory database,...
DB_VERSION GetVersion()
Returns the version of the MyCaffe Image Database being used.
List< SimpleDatum > GetItemsFromTime(int nSrcId, DateTime dtStart, int nQueryCount=int.MaxValue, string strFilterVal=null, int? nBoostVal=null, bool bBoostValIsExact=false)
Returns the array of items (e.g., images or temporal items) in the item set, possibly filtered with t...
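A minimal sketch of the IXDatabaseBase calls above (assumes an IXDatabaseBase instance named db with the dataset already loaded; the source name is illustrative):
// Sketch: resolve a data source by name, then query one item and the source mean.
int nSrcId = db.GetSourceID("MNIST.training");   // source name is illustrative
SimpleDatum sd = db.QueryItem(nSrcId, 0);
SimpleDatum sdMean = db.QueryItemMean(nSrcId);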
The IXImageDatabase interface defines the general interface to the in-memory image database.
Definition: Interfaces.cs:1004
The IXImageDatabase2 interface defines the general interface to the in-memory image database (v2).
Definition: Interfaces.cs:1092
The IXImageDatabaseBase interface defines the general interface to the in-memory image database.
Definition: Interfaces.cs:878
The IXMyCaffeExtension interface allows for easy extension management of the low-level software that ...
Definition: Interfaces.cs:614
The IXMyCaffe interface contains functions used to perform MyCaffe operations that work with the MyCa...
Definition: Interfaces.cs:410
The IXMyCaffeNoDb interface contains functions used to perform MyCaffe operations that run in a light...
Definition: Interfaces.cs:564
The IXMyCaffeState interface contains functions related to the MyCaffeComponent state.
Definition: Interfaces.cs:258
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:187
The descriptors namespace contains all descriptors used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
DB_ITEM_SELECTION_METHOD
Defines the item (e.g., image or temporal item) selection method.
Definition: Interfaces.cs:278
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
DB_VERSION
Defines the image database version to use.
Definition: Interfaces.cs:397
DB_LABEL_SELECTION_METHOD
Defines the label selection method.
Definition: Interfaces.cs:333
Stage
Specifies the stage under which to run a custom trainer.
Definition: Interfaces.cs:88
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
DEVINIT
Specifies the initialization flags used when initializing CUDA.
Definition: CudaDnn.cs:207
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
WEIGHT_TARGET
Defines the type of weight to target in re-initializations.
Definition: Interfaces.cs:38
The MyCaffe.data namespace contains dataset creators used to create common testing datasets such as M...
Definition: BinaryFile.cs:16
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param.beta parameters are used by the MyCaffe.layer.beta layers.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12