Deep learning software for Windows C# programmers.
1using System;
2using System.Collections.Generic;
4using System.Diagnostics;
5using System.Linq;
6using System.Text;
7using System.Threading;
8using System.Drawing;
9using System.Threading.Tasks;
10using System.IO;
11using MyCaffe.basecode;
13using MyCaffe.db.image;
14using MyCaffe.solvers;
15using MyCaffe.common;
16using MyCaffe.param;
18using MyCaffe.layers;
19using System.Globalization;
20using System.Reflection;
21using System.Security.Cryptography;
22using System.Net;
28namespace MyCaffe
34 public partial class MyCaffeControl<T> : Component, IXMyCaffeState<T>, IXMyCaffe<T>, IXMyCaffeNoDb<T>, IXMyCaffeExtension<T>, IDisposable
35 {
/// <summary>Output log used for all status messages.</summary>
43 protected Log m_log;
/// <summary>Database used by the Load methods; Unload sets it to null when the database is unloaded.</summary>
47 protected IXDatabaseBase m_db = null;
/// <summary>When true, Unload disposes m_db; Load sets it to true only in the branch where the database is initialized by this control rather than passed in.</summary>
51 protected bool m_bDbOwner = true;
/// <summary>Event passed to the solver to force a snapshot (created by the constructor when not supplied).</summary>
59 protected AutoResetEvent m_evtForceSnapshot;
/// <summary>Event passed to the solver to force a test (created by the constructor when not supplied).</summary>
63 protected AutoResetEvent m_evtForceTest;
/// <summary>Pause event (created by the constructor when not supplied); not used in the code visible here.</summary>
67 protected ManualResetEvent m_evtPause;
/// <summary>Project set by Load(Phase, ProjectEx, ...); set to null by Unload, LoadLite and the string-based Load.</summary>
75 protected ProjectEx m_project = null;
/// <summary>Dataset descriptor used to query the image mean; set by the string-based Load and cleared by LoadLite.</summary>
79 protected DatasetDescriptor m_dataSet = null;
/// <summary>Path passed to the CudaDnn constructor.</summary>
83 protected string m_strCudaPath = null;
/// <summary>GPU IDs; the first entry is used to create the CudaDnn connection and the count is passed to Solver.Create.</summary>
87 protected List<int> m_rgGpu;
// CUDA connection; disposed and re-created by each Load/LoadLite call (null until then unless the constructor created it).
88 CudaDnn<T> m_cuda;
// Solver created for the TRAIN and TEST phases.
89 Solver<T> m_solver;
// Run net.
90 Net<T> m_net;
// False when m_net is borrowed from the solver's TestingNet; Unload only disposes m_net when this is true.
91 bool m_bOwnRunNet = true;
// Disposed by dispose(); other uses are not visible here.
92 MemoryStream m_msWeights = new MemoryStream();
// Unique ID generated by the constructor.
93 Guid m_guidUser;
// Weight persistence helper passed to the solver and used when loading weights.
94 PersistCaffe<T> m_persist;
// Input shape recorded by createNetParameterForRunning.
95 BlobShape m_inputShape = null;
// Returned by a phase getter further below; not assigned in the code visible here.
96 Phase m_lastPhaseRun = Phase.NONE;
// Host buffer handle reused across Blob.CopyFrom calls; freed by dispose().
97 long m_hCopyBuffer = 0;
// Stage string passed to the most recent Load/LoadLite call.
98 string m_strStage = null;
// Set by LoadLite; Clone uses it (with m_strSolver/m_strModel) to repeat a lite load.
99 bool m_bLoadLite = false;
100 string m_strSolver = null; // Used with LoadLite.
101 string m_strModel = null; // Used with LoadLite.
// Re-entrance guard for Unload: while set, a further Unload call returns immediately.
102 ManualResetEvent m_evtSyncUnload = new ManualResetEvent(false);
// Re-entrance guard for dispose(): while set, a further dispose() call returns immediately.
103 ManualResetEvent m_evtSyncMain = new ManualResetEvent(false);
// Dataset connection info supplied to the constructor.
104 ConnectInfo m_dsCi = null;
// Value applied to m_log.Enable while the Load methods run.
105 bool m_bEnableVerboseStatus = false;
// Not used in the code visible here.
106 T[] m_rgRunData = null;
// Not used in the code visible here.
107 BlobShape m_loadToRunShape = null;
/// <summary>Snapshot event exposed to clients; presumably raised from the solver's OnSnapshot handler (not visible here).</summary>
112 public event EventHandler<SnapshotArgs> OnSnapshot;
/// <summary>Training-iteration event exposed to clients; presumably raised from the solver's handler (not visible here).</summary>
116 public event EventHandler<TrainingIterationArgs<T>> OnTrainingIteration;
/// <summary>Testing-iteration event exposed to clients; presumably raised from the solver's handler (not visible here).</summary>
120 public event EventHandler<TestingIterationArgs<T>> OnTestingIteration;
/// <summary>
/// Creates the MyCaffeControl.
/// </summary>
/// <param name="settings">Specifies the settings; its comma separated GpuIds are used when <paramref name="rgGpuId"/> is null.</param>
/// <param name="log">Specifies the output log.</param>
/// <param name="evtCancel">Specifies the cancel event (required).</param>
/// <param name="evtSnapshot">Optional event used to force a snapshot; created when null.</param>
/// <param name="evtForceTest">Optional event used to force a test; created when null.</param>
/// <param name="evtPause">Optional pause event; created when null.</param>
/// <param name="rgGpuId">Optional GPU IDs (copied); when null the IDs are parsed from settings.GpuIds. GPU 0 is used when no IDs are found.</param>
/// <param name="strCudaPath">Optional path to the CUDA DLL.</param>
/// <param name="bCreateCudaDnn">When true, a CudaDnn connection is created immediately on the first GPU.</param>
/// <param name="ci">Optional dataset connection info.</param>
/// <exception cref="ArgumentNullException">Thrown when evtCancel is null.</exception>
/// <exception cref="FormatException">Thrown when a non-empty entry of settings.GpuIds is not an integer.</exception>
public MyCaffeControl(SettingsCaffe settings, Log log, CancelEvent evtCancel, AutoResetEvent evtSnapshot = null, AutoResetEvent evtForceTest = null, ManualResetEvent evtPause = null, List<int> rgGpuId = null, string strCudaPath = "", bool bCreateCudaDnn = false, ConnectInfo ci = null)
{
    // Validate before any component initialization takes place. The two-argument
    // constructor is used so the message is not mistaken for the parameter name.
    if (evtCancel == null)
        throw new ArgumentNullException("evtCancel", "The cancel event must be specified!");

    m_dsCi = ci;
    m_guidUser = Guid.NewGuid();

    InitializeComponent();

    if (evtSnapshot == null)
        evtSnapshot = new AutoResetEvent(false);

    if (evtForceTest == null)
        evtForceTest = new AutoResetEvent(false);

    if (evtPause == null)
        evtPause = new ManualResetEvent(false);

    m_log = log;
    m_settings = settings;
    m_evtCancel = evtCancel;
    m_evtForceSnapshot = evtSnapshot;
    m_evtForceTest = evtForceTest;
    m_evtPause = evtPause;

    if (rgGpuId == null)
    {
        m_rgGpu = new List<int>();

        foreach (string str in settings.GpuIds.Split(','))
        {
            string strGpuId = str.Trim(' ', '\t', '\n', '\r');

            // Tolerate empty entries such as a trailing comma ("0,1,") or an empty list.
            if (strGpuId.Length == 0)
                continue;

            // GPU IDs are machine-readable data, so parse them culture-invariantly.
            m_rgGpu.Add(int.Parse(strGpuId, CultureInfo.InvariantCulture));
        }
    }
    else
    {
        m_rgGpu = Utility.Clone<int>(rgGpuId);
    }

    // Default to the first GPU when none were specified.
    if (m_rgGpu.Count == 0)
        m_rgGpu.Add(0);

    m_strCudaPath = strCudaPath;
    m_persist = new common.PersistCaffe<T>(m_log, false);

    if (bCreateCudaDnn)
        m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, false);
}
/// <summary>
/// Releases the resources held by this control: the host copy buffer, everything released by
/// Unload(true, true) (solver, run net, database), the CudaDnn connection, the weight stream and the data transformer.
/// </summary>
/// <remarks>
/// m_evtSyncMain guards against re-entrance: if it is already set, the call returns immediately.
/// Failures while freeing the copy buffer or disposing the CudaDnn connection are ignored.
/// </remarks>
191 public void dispose()
192 {
193 if (m_evtSyncMain.WaitOne(0))
194 return;
196 m_evtSyncMain.Set();
198 try
199 {
// NOTE(review): the statement controlled by this 'if' (original line 201, presumably signalling
// m_evtCancel) is missing from this listing. As shown, the 'if' instead guards the copy-buffer
// cleanup below, which would then be skipped whenever m_evtCancel is null. Confirm against the real source.
200 if (m_evtCancel != null)
203 if (m_hCopyBuffer != 0)
204 {
205 try
206 {
207 m_cuda.FreeHostBuffer(m_hCopyBuffer);
208 }
209 catch
210 {
// Best-effort cleanup: failures (e.g. m_cuda being null) are ignored so disposal continues.
211 }
213 m_hCopyBuffer = 0;
214 }
// Unload the solver/nets and the database, ignoring exceptions.
216 Unload(true, true);
218 if (m_cuda != null)
219 {
220 try
221 {
222 m_cuda.Dispose();
223 }
224 catch
225 {
// Best-effort cleanup: failures are ignored.
226 }
228 m_cuda = null;
229 }
231 if (m_msWeights != null)
232 {
233 m_msWeights.Dispose();
234 m_msWeights = null;
235 }
237 if (m_dataTransformer != null)
238 {
239 m_dataTransformer.Dispose();
240 m_dataTransformer = null;
241 }
242 }
243 finally
244 {
// Release the re-entrance guard.
245 m_evtSyncMain.Reset();
246 }
247 }
/// <summary>
/// Returns the file version information of the assembly containing this control.
/// </summary>
public static FileVersionInfo Version
{
    get
    {
        Assembly asmExecuting = Assembly.GetExecutingAssembly();
        string strPath = asmExecuting.Location;
        return FileVersionInfo.GetVersionInfo(strPath);
    }
}
// NOTE(review): the declaration line of this property is missing from this listing.
// The getter returns m_dsCi, the ConnectInfo passed to the constructor (may be null).
265 {
266 get { return m_dsCi; }
267 }
/// <summary>
/// Returns the stage string passed to the most recent Load/LoadLite call (null when none was given).
/// </summary>
public string CurrentStage
{
    get
    {
        string strStage = m_strStage;
        return strStage;
    }
}
/// <summary>
/// Creates a new MyCaffeControl on the given GPU, loads it for the TRAIN phase the same way this
/// control was loaded (LoadLite or Load), and copies this control's TRAIN-net learnable parameters into it.
/// </summary>
/// <param name="nGpuID">Specifies the GPU ID used by the clone.</param>
/// <returns>The new, loaded MyCaffeControl.</returns>
285 public MyCaffeControl<T> Clone(int nGpuID)
286 {
// NOTE(review): the declaration of 's' (original line 287, presumably a copy of m_settings) is missing from this listing.
288 s.GpuIds = nGpuID.ToString();
290 MyCaffeControl<T> mycaffe = new MyCaffeControl<T>(s, m_log, m_evtCancel, null, null, null, null, m_strCudaPath);
// NOTE(review): the LoadLite path does not pass m_strStage, while the Load path does. Confirm this is intended.
292 if (m_bLoadLite)
293 mycaffe.LoadLite(Phase.TRAIN, m_strSolver, m_strModel, null);
294 else
295 mycaffe.Load(Phase.TRAIN, m_project, null, null, false, m_db, (m_db == null) ? false : true, true, m_strStage);
297 Net<T> netSrc = GetInternalNet(Phase.TRAIN);
298 Net<T> netDst = mycaffe.GetInternalNet(Phase.TRAIN);
300 m_log.CHECK_EQ(netSrc.learnable_parameters.Count, netDst.learnable_parameters.Count, "The src and dst networks do not have the same number of learnable parameters!");
302 for (int i = 0; i < netSrc.learnable_parameters.Count; i++)
303 {
304 Blob<T> bSrc = netSrc.learnable_parameters[i];
305 Blob<T> bDst = netDst.learnable_parameters[i];
// The clone's own host copy buffer is reused across all copies.
307 mycaffe.m_hCopyBuffer = bDst.CopyFrom(bSrc, false, false, mycaffe.m_hCopyBuffer);
308 }
310 return mycaffe;
311 }
// NOTE(review): the declaration line of this method is missing from this listing; it takes a source control 'src'.
// Copies each TRAIN-net learnable parameter of 'src' into this control's TRAIN net, passing 'true'
// as the second CopyFrom argument (presumably copying the diff/gradients; confirm).
318 {
319 Net<T> netSrc = src.GetInternalNet(Phase.TRAIN);
320 Net<T> netDst = GetInternalNet(Phase.TRAIN);
322 m_log.CHECK_EQ(netSrc.learnable_parameters.Count, netDst.learnable_parameters.Count, "The src and dst networks do not have the same number of learnable parameters!");
324 for (int i = 0; i < netSrc.learnable_parameters.Count; i++)
325 {
326 Blob<T> bSrc = netSrc.learnable_parameters[i];
327 Blob<T> bDst = netDst.learnable_parameters[i];
// The host copy buffer handle returned by CopyFrom is kept for reuse.
329 m_hCopyBuffer = bDst.CopyFrom(bSrc, true, false, m_hCopyBuffer);
330 }
331 }
// NOTE(review): the declaration line of this method is missing from this listing; it takes a source control 'src'.
// Same as the block above but passes 'false' as the second CopyFrom argument (presumably copying the data/weights; confirm).
338 {
339 Net<T> netSrc = src.GetInternalNet(Phase.TRAIN);
340 Net<T> netDst = GetInternalNet(Phase.TRAIN);
342 m_log.CHECK_EQ(netSrc.learnable_parameters.Count, netDst.learnable_parameters.Count, "The src and dst networks do not have the same number of learnable parameters!");
344 for (int i = 0; i < netSrc.learnable_parameters.Count; i++)
345 {
346 Blob<T> bSrc = netSrc.learnable_parameters[i];
347 Blob<T> bDst = netDst.learnable_parameters[i];
349 m_hCopyBuffer = bDst.CopyFrom(bSrc, false, false, m_hCopyBuffer);
350 }
351 }
/// <summary>
/// Delegates to the solver's ApplyUpdate for the given iteration and returns its result.
/// A solver must be loaded (TRAIN or TEST phase).
/// </summary>
/// <param name="nIteration">Specifies the iteration.</param>
/// <returns>The value returned by the solver.</returns>
public double ApplyUpdate(int nIteration)
{
    Solver<T> solver = m_solver;
    double dfResult = solver.ApplyUpdate(nIteration);
    return dfResult;
}
/// <summary>
/// Gets or sets the solver's EnableTesting flag. A solver must be loaded.
/// </summary>
public bool EnableTesting
{
    get
    {
        Solver<T> solver = m_solver;
        return solver.EnableTesting;
    }
    set
    {
        Solver<T> solver = m_solver;
        solver.EnableTesting = value;
    }
}
// NOTE(review): the declaration line of this property is missing from this listing.
// Gets or sets m_bEnableVerboseStatus, which the Load methods apply to m_log.Enable while loading.
377 {
378 get { return m_bEnableVerboseStatus; }
379 set { m_bEnableVerboseStatus = value; }
380 }
/// <summary>
/// Disposes the solver and (when owned) the run net and, optionally, the database.
/// </summary>
/// <param name="bUnloadImageDb">When true, m_db is released (disposed only when m_bDbOwner is true) and set to null.</param>
/// <param name="bIgnoreExceptions">When true, exceptions raised while unloading are swallowed.</param>
/// <remarks>
/// Returns immediately when neither a solver nor a net is loaded. Note that in that case the database is
/// not released either. m_evtSyncUnload guards against re-entrance.
/// </remarks>
387 public void Unload(bool bUnloadImageDb = true, bool bIgnoreExceptions = false)
388 {
389 if (m_solver == null && m_net == null)
390 return;
392 if (m_evtSyncUnload.WaitOne(0))
393 return;
395 m_evtSyncUnload.Set();
397 try
398 {
399 if (m_solver != null)
400 {
401 m_solver.Dispose();
402 m_solver = null;
403 }
405 if (m_net != null)
406 {
// A run net borrowed from the solver's TestingNet (m_bOwnRunNet == false) is not disposed here.
// NOTE(review): m_bOwnRunNet is not reset to true after m_net is released.
407 if (m_bOwnRunNet)
408 m_net.Dispose();
409 m_net = null;
410 }
412 if (m_db != null && bUnloadImageDb)
413 {
414 if (m_bDbOwner)
415 {
// NOTE(review): the statement controlled by this 'if' (original line 417) is missing from this listing.
// As shown, the following declaration would be the if's embedded statement, which C# does not allow.
416 if (m_dataSet != null)
419 IDisposable idisp = m_db as IDisposable;
420 if (idisp != null)
421 idisp.Dispose();
422 }
424 m_db = null;
425 }
427 m_project = null;
428 }
429 catch (Exception excpt)
430 {
// NOTE(review): 'throw excpt;' resets the stack trace; 'throw;' would preserve it.
431 if (!bIgnoreExceptions)
432 throw excpt;
433 }
434 finally
435 {
436 m_evtSyncUnload.Reset();
437 }
438 }
/// <summary>
/// Asks the TRAIN-phase internal net to re-initialize the parameters for the given target and layers,
/// then forces the solver to fire its training-iteration event.
/// </summary>
/// <param name="target">Specifies which weights to re-initialize.</param>
/// <param name="rgstrLayers">Specifies the layers to re-initialize.</param>
/// <returns>The result of the solver's ForceOnTrainingIterationEvent.</returns>
public bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
{
    GetInternalNet(Phase.TRAIN).ReInitializeParameters(target, rgstrLayers);

    Solver<T> solver = m_solver;
    return solver.ForceOnTrainingIterationEvent();
}
/// <summary>
/// Subscribes the handler to the solver's OnTest event (the handler is added, not substituted).
/// </summary>
/// <param name="onTest">Specifies the handler.</param>
public void SetOnTestOverride(EventHandler<TestArgs> onTest)
{
    Solver<T> solver = m_solver;
    solver.OnTest += onTest;
}
/// <summary>
/// Subscribes the handler to the solver's OnStart event (the handler is added, not substituted).
/// </summary>
/// <param name="onTrainingStart">Specifies the handler.</param>
public void SetOnTrainingStartOverride(EventHandler onTrainingStart)
{
    Solver<T> solver = m_solver;
    solver.OnStart += onTrainingStart;
}
/// <summary>
/// Subscribes the handler to the solver's OnTestStart event (the handler is added, not substituted).
/// </summary>
/// <param name="onTestingStart">Specifies the handler.</param>
public void SetOnTestingStartOverride(EventHandler onTestingStart)
{
    Solver<T> solver = m_solver;
    solver.OnTestStart += onTestingStart;
}
/// <summary>
/// Registers a named cancel override with this control's cancel event.
/// </summary>
/// <param name="strEvtCancel">Specifies the name of the cancel event override.</param>
public void AddCancelOverrideByName(string strEvtCancel)
{
    // Forward directly to the cancel event supplied to the constructor.
    m_evtCancel.AddCancelOverride(strEvtCancel);
}
/// <summary>
/// Intended to register a cancel event override with this control's cancel event.
/// </summary>
/// <param name="evtCancel">Specifies the cancel event override.</param>
496 public void AddCancelOverride(CancelEvent evtCancel)
497 {
// NOTE(review): the body line (original line 498) is missing from this listing, so as shown this
// method does nothing. AddCancelOverrideByName forwards to m_evtCancel.AddCancelOverride; confirm the intended body.
499 }
/// <summary>
/// Removes a named cancel override from this control's cancel event.
/// </summary>
/// <param name="strEvtCancel">Specifies the name of the cancel event override.</param>
public void RemoveCancelOverrideByName(string strEvtCancel)
{
    // Forward directly to the cancel event supplied to the constructor.
    m_evtCancel.RemoveCancelOverride(strEvtCancel);
}
/// <summary>
/// Intended to remove a cancel event override from this control's cancel event.
/// </summary>
/// <param name="evtCancel">Specifies the cancel event override.</param>
514 public void RemoveCancelOverride(CancelEvent evtCancel)
515 {
// NOTE(review): the body line (original line 516) is missing from this listing, so as shown this
// method does nothing. RemoveCancelOverrideByName forwards to m_evtCancel.RemoveCancelOverride; confirm the intended body.
517 }
// NOTE(review): the declaration lines of the five properties in this run are missing from this listing.
// Each one forwards to the matching solver flag. The getter returns false and the setter does nothing while no solver is loaded.
// Solver flag: EnableBlobDebugging.
526 {
527 get { return (m_solver == null) ? false : m_solver.EnableBlobDebugging; }
528 set
529 {
530 if (m_solver != null)
531 m_solver.EnableBlobDebugging = value;
532 }
533 }
// Solver flag: EnableBreakOnFirstNaN.
542 {
543 get { return (m_solver == null) ? false : m_solver.EnableBreakOnFirstNaN; }
544 set
545 {
546 if (m_solver != null)
547 m_solver.EnableBreakOnFirstNaN = value;
548 }
549 }
// Solver flag: EnableDetailedNanDetection.
555 {
556 get { return (m_solver == null) ? false : m_solver.EnableDetailedNanDetection; }
557 set
558 {
559 if (m_solver != null)
560 m_solver.EnableDetailedNanDetection = value;
561 }
562 }
// Solver flag: EnableLayerDebugging.
571 {
572 get { return (m_solver == null) ? false : m_solver.EnableLayerDebugging; }
573 set
574 {
575 if (m_solver != null)
576 m_solver.EnableLayerDebugging = value;
577 }
578 }
// Solver flag: EnableSingleStep.
587 {
588 get { return (m_solver == null) ? false : m_solver.EnableSingleStep; }
589 set
590 {
591 if (m_solver != null)
592 m_solver.EnableSingleStep = value;
593 }
594 }
// NOTE(review): the declaration line preceding each headless '{' getter in this run is missing from this listing.
// Getter returning m_dataTransformer (created by the Load methods when a transformation parameter exists; may be null).
600 {
601 get { return m_dataTransformer; }
602 }
// Getter returning m_settings (the settings passed to the constructor).
608 {
609 get { return m_settings; }
610 }
// Getter returning m_cuda (null until a Load call unless the constructor was asked to create it).
616 {
617 get { return m_cuda; }
618 }
/// <summary>
/// Returns the output log passed to the constructor.
/// </summary>
623 public Log Log
624 {
625 get { return m_log; }
626 }
// Getter returning m_persist (created by the constructor).
632 {
633 get { return m_persist; }
634 }
// Getter returning m_db (null when no database is in use).
640 {
641 get { return m_db; }
642 }
// Getter returning m_evtCancel (the cancel event passed to the constructor).
648 {
649 get { return m_evtCancel; }
650 }
/// <summary>
/// Returns the GPU IDs used by this control. The first entry is the GPU on which the CudaDnn connection is created.
/// The internal list itself is returned, not a copy.
/// </summary>
public List<int> ActiveGpus
{
    get
    {
        List<int> rgGpu = m_rgGpu;
        return rgGpu;
    }
}
/// <summary>
/// Returns the solver's ActiveLabelCounts string. A solver must be loaded.
/// </summary>
public string ActiveLabelCounts
{
    get
    {
        Solver<T> solver = m_solver;
        return solver.ActiveLabelCounts;
    }
}
// NOTE(review): the declaration line of this property is missing from this listing.
// Returns the solver's LabelQueryHitPercents; m_solver is dereferenced without a null check, so a solver must be loaded.
678 {
679 get { return m_solver.LabelQueryHitPercents; }
680 }
/// <summary>
/// Returns the solver's LabelQueryEpochs string. A solver must be loaded.
/// </summary>
public string LabelQueryEpochs
{
    get
    {
        Solver<T> solver = m_solver;
        return solver.LabelQueryEpochs;
    }
}
/// <summary>
/// Returns a description of the current CUDA device in the form "GPU #&lt;id&gt; &lt;name&gt;".
/// Requires a CudaDnn connection.
/// </summary>
public string CurrentDevice
{
    get
    {
        int nDeviceID = m_cuda.GetDeviceID();
        string strDeviceName = m_cuda.GetDeviceName(nDeviceID);
        return "GPU #" + nDeviceID.ToString() + " " + strDeviceName;
    }
}
// NOTE(review): the declaration lines of the three getters in this run are missing from this listing.
// Getter returning m_project (set only by Load(Phase, ProjectEx, ...); may be null).
709 {
710 get { return m_project; }
711 }
// Getter returning the solver's CurrentIteration (requires a loaded solver).
717 {
718 get { return m_solver.CurrentIteration; }
719 }
// Getter returning the solver's MaximumIteration (requires a loaded solver).
725 {
726 get { return m_solver.MaximumIteration; }
727 }
/// <summary>
/// Returns the number of devices reported by the CudaDnn connection (which must exist).
/// </summary>
/// <returns>The device count.</returns>
public int GetDeviceCount()
{
    int nCount = m_cuda.GetDeviceCount();
    return nCount;
}
/// <summary>
/// Returns the name of a device as reported by the CudaDnn connection (which must exist).
/// </summary>
/// <param name="nDeviceID">Specifies the device ID.</param>
/// <returns>The device name.</returns>
public string GetDeviceName(int nDeviceID)
{
    string strName = m_cuda.GetDeviceName(nDeviceID);
    return strName;
}
// NOTE(review): the declaration line of this getter is missing from this listing.
// Returns m_lastPhaseRun (not assigned in the code visible here).
752 {
753 get { return m_lastPhaseRun; }
754 }
// NOTE(review): the declaration line of this method is missing from this listing; it takes a project 'p' and an 'out transform_param'.
// Builds the RUN net parameter from the project's dataset, model description and stage via the DatasetDescriptor overload.
766 {
767 return createNetParameterForRunning(p.Dataset, p.ModelDescription, out transform_param, p.Stage);
768 }
/// <summary>
/// Builds the RUN net parameter using an input shape derived from the dataset
/// (batch 1; channels/height/width from the testing source, or 1 for model-data datasets).
/// </summary>
/// <param name="ds">Specifies the dataset.</param>
/// <param name="strModel">Specifies the model description.</param>
/// <param name="transform_param">Returns the transformation parameter (may be null).</param>
/// <param name="stage">Optionally specifies the stage.</param>
/// <returns>The RUN net parameter.</returns>
protected NetParameter createNetParameterForRunning(DatasetDescriptor ds, string strModel, out TransformationParameter transform_param, Stage stage = Stage.NONE)
{
    return createNetParameterForRunning(datasetToShape(ds), strModel, out transform_param, stage);
}
/// <summary>
/// Builds the RUN net parameter for the given input shape and records that shape in m_inputShape.
/// </summary>
/// <param name="shape">Specifies the input shape. CreateNetParameterForRunning may update its height/width when resizing is active.</param>
/// <param name="strModel">Specifies the model description.</param>
/// <param name="transform_param">Returns the transformation parameter (may be null).</param>
/// <param name="stage">Optionally specifies the stage.</param>
/// <returns>The RUN net parameter.</returns>
protected NetParameter createNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage = Stage.NONE)
{
    NetParameter netParam = CreateNetParameterForRunning(shape, strModel, out transform_param, stage);

    // The shape is recorded only once the parameter has been created without error.
    m_inputShape = shape;

    return netParam;
}
/// <summary>
/// Builds the RUN net parameter. The input is sized from the mean image when one is given; otherwise
/// it is sized from the first 'Input' layer in the model that has an input_param.shape.
/// The recorded input shape always uses a batch size of 1.
/// </summary>
/// <param name="sdMean">Optional mean image whose channels/height/width size the input.</param>
/// <param name="strModel">Specifies the model description.</param>
/// <param name="transform_param">Returns the transformation parameter (may be null).</param>
/// <param name="nC">Returns the channel count used (missing dims default to 1).</param>
/// <param name="nH">Returns the height used.</param>
/// <param name="nW">Returns the width used.</param>
/// <param name="stage">Optionally specifies the stage.</param>
/// <returns>The RUN net parameter.</returns>
/// <exception cref="Exception">Thrown when no mean is given and the model has no sized 'Input' layer.</exception>
protected NetParameter createNetParameterForRunning(SimpleDatum sdMean, string strModel, out TransformationParameter transform_param, out int nC, out int nH, out int nW, Stage stage = Stage.NONE)
{
    nC = 0;
    nH = 0;
    nW = 0;

    if (sdMean != null)
    {
        nC = sdMean.Channels;
        nH = sdMean.Height;
        nW = sdMean.Width;
    }
    else
    {
        RawProto protoModel = RawProto.Parse(strModel);
        RawProtoCollection layers = protoModel.FindChildren("layer");

        foreach (RawProto layer in layers)
        {
            RawProto type = layer.FindChild("type");
            if (type == null || type.Value != "Input")
                continue;

            RawProto input_param = layer.FindChild("input_param");
            if (input_param == null)
                continue;

            RawProto shape1 = input_param.FindChild("shape");
            if (shape1 == null)
                continue;

            // dim[0] (the batch size) is not needed here because the run shape below always uses a batch of 1.
            // The model text is machine-readable, so dims are parsed culture-invariantly.
            RawProtoCollection rgDim = shape1.FindChildren("dim");
            nC = (rgDim.Count > 1) ? int.Parse(rgDim[1].Value, CultureInfo.InvariantCulture) : 1;
            nH = (rgDim.Count > 2) ? int.Parse(rgDim[2].Value, CultureInfo.InvariantCulture) : 1;
            nW = (rgDim.Count > 3) ? int.Parse(rgDim[3].Value, CultureInfo.InvariantCulture) : 1;
            break;
        }
    }

    if (nC == 0 && nH == 0 && nW == 0)
        throw new Exception("Could not dicern the shape to use for no 'sdMean' parameter was supplied and the model does not contain an 'Input' layer!");

    BlobShape shape = new BlobShape(1, nC, nH, nW);
    NetParameter param = CreateNetParameterForRunning(shape, strModel, out transform_param, stage);
    m_inputShape = shape;

    return param;
}
/// <summary>
/// Converts a model description into a RUN-phase net parameter whose data input has the given shape.
/// </summary>
/// <param name="shape">Specifies the input shape. When an active resize parameter is found, dim[2] and dim[3]
/// of this object are overwritten with the resize height and width (the caller's object is mutated).</param>
/// <param name="strModel">Specifies the model description.</param>
/// <param name="transform_param">Returns the transformation parameter: null when the model says to skip it,
/// otherwise the parsed one or a default one.</param>
/// <param name="stage">Optionally specifies the stage.</param>
/// <param name="bSkipLossLayer">Passed through to ProjectEx.CreateModelForRunning.</param>
/// <param name="bMaintainBatchSize">When true, the batch dimension of 'shape' is kept; otherwise a batch of 1 is used.</param>
/// <returns>The RUN net parameter (ProjectID 0, phase RUN).</returns>
public static NetParameter CreateNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage = Stage.NONE, bool bSkipLossLayer = false, bool bMaintainBatchSize = false)
{
    int nBatch = 1;
    if (bMaintainBatchSize)
        nBatch = shape.dim[0];

    int nChannels = shape.dim[1];
    int nHeight = 1;
    int nWidth = 1;

    if (shape.dim.Count > 2)
        nHeight = shape.dim[2];

    if (shape.dim.Count > 3)
        nWidth = shape.dim[3];

    RawProto protoTransform;
    bool bSkipTransformParam;
    RawProto protoModel = ProjectEx.CreateModelForRunning(strModel, "data", nBatch, nChannels, nHeight, nWidth, out protoTransform, out bSkipTransformParam, stage, bSkipLossLayer);

    if (bSkipTransformParam)
    {
        transform_param = null;
    }
    else
    {
        if (protoTransform == null)
            transform_param = new param.TransformationParameter();
        else
            transform_param = TransformationParameter.FromProto(protoTransform);

        // NOTE: this writes into the caller's shape object.
        if (transform_param.resize_param != null && transform_param.resize_param.Active)
        {
            shape.dim[2] = (int)transform_param.resize_param.height;
            shape.dim[3] = (int)transform_param.resize_param.width;
        }
    }

    NetParameter np = NetParameter.FromProto(protoModel);

    // Let every layer adapt itself to the run model and gather any input definitions it produces.
    StringBuilder sbInputs = new StringBuilder();
    foreach (LayerParameter layer in np.layer)
    {
        layer.PrepareRunModel();

        string strLayerInput = layer.PrepareRunModelInputs();
        if (!string.IsNullOrEmpty(strLayerInput))
            sbInputs.Append(strLayerInput);
    }

    if (sbInputs.Length > 0)
    {
        RawProto protoInputs = RawProto.Parse(sbInputs.ToString());
        Dictionary<string, BlobShape> rgInput = NetParameter.InputFromProto(protoInputs);

        if (rgInput.Count > 0)
        {
            np.input = new List<string>();
            np.input_dim = new List<int>();
            np.input_shape = new List<BlobShape>();

            foreach (KeyValuePair<string, BlobShape> kvInput in rgInput)
            {
                np.input.Add(kvInput.Key);
                np.input_shape.Add(kvInput.Value);
            }
        }
    }

    np.ProjectID = 0;
    np.state.phase = Phase.RUN;

    return np;
}
/// <summary>
/// Maps a dataset to a 4D input shape. The batch size is always 1. Model-data datasets use 1 for
/// channels/height/width; other datasets use the testing source's channels, height and width.
/// </summary>
/// <param name="ds">Specifies the dataset.</param>
/// <returns>The input shape.</returns>
private BlobShape datasetToShape(DatasetDescriptor ds)
{
    List<int> rgShape;

    if (ds.IsModelData)
        rgShape = new List<int>() { 1, 1, 1, 1 };
    else
        rgShape = new List<int>() { 1, ds.TestingSource.Channels, ds.TestingSource.Height, ds.TestingSource.Width };

    return new BlobShape(rgShape);
}
/// <summary>
/// Converts a stage name to a Stage value. Only RNN and RL are recognized; anything else,
/// including null, maps to Stage.NONE.
/// </summary>
/// <param name="strStage">Specifies the stage name.</param>
/// <returns>The matching Stage.</returns>
private Stage getStage(string strStage)
{
    Stage[] rgKnownStages = new Stage[] { Stage.RNN, Stage.RL };

    foreach (Stage known in rgKnownStages)
    {
        if (known.ToString() == strStage)
            return known;
    }

    return Stage.NONE;
}
/// <summary>
/// Rewrites the model so that its net state has the given phase and a single stage.
/// The model is returned unchanged when no stage is given.
/// </summary>
/// <param name="strModel">Specifies the model description.</param>
/// <param name="phase">Specifies the phase to set in the net state.</param>
/// <param name="strStage">Specifies the stage (ignored when null or empty).</param>
/// <returns>The (possibly rewritten) model description.</returns>
private string addStage(string strModel, Phase phase, string strStage)
{
    if (!string.IsNullOrEmpty(strStage))
    {
        NetParameter netParam = NetParameter.FromProto(RawProto.Parse(strModel));

        netParam.state.stage.Clear();
        netParam.state.phase = phase;
        netParam.state.stage.Add(strStage);

        strModel = netParam.ToProto("root", true).ToString();
    }

    return strModel;
}
// NOTE(review): the declaration line of this method is missing from this listing; it takes a project 'prj'.
// For the project's dataset (and its target dataset, when set), checks whether the testing source lacks a raw image mean
// and whether the training source has one. Model-data and gym datasets are skipped.
995 {
996 if (prj.Dataset.IsModelData || prj.Dataset.IsGym)
997 return;
999 DatasetFactory factory = new DatasetFactory();
1001 // Copy the training image mean to the testing source if it does not have a mean.
1002 // NOTE: This will not impact a service based image database that is already loaded,
1003 // - it must be reloaded.
1004 int nDstID = factory.GetRawImageMeanID(prj.Dataset.TestingSource.ID);
1005 if (nDstID == 0)
1006 {
1007 int nSrcID = factory.GetRawImageMeanID(prj.Dataset.TrainingSource.ID);
// NOTE(review): the statement controlled by this 'if' (original line 1009, presumably the mean copy) is missing from this listing.
1008 if (nSrcID != 0)
1010 }
1012 if (prj.DatasetTarget != null)
1013 {
1014 // Copy the training image mean to the testing source if it does not have a mean.
1015 // NOTE: This will not impact a service based image database that is already loaded,
1016 // - it must be reloaded.
1017 nDstID = factory.GetRawImageMeanID(prj.DatasetTarget.TestingSource.ID);
1018 if (nDstID == 0)
1019 {
1020 int nSrcID = factory.GetRawImageMeanID(prj.DatasetTarget.TrainingSource.ID);
// NOTE(review): the statement controlled by this 'if' (original line 1022) is missing from this listing.
1021 if (nSrcID != 0)
1023 }
1024 }
1025 }
/// <summary>
/// Checks that the solver's training and testing nets share their parameter memory (same gpu_data handles).
/// A warning is logged for the first mismatch.
/// </summary>
/// <returns>True when there is no testing net or all parameters are shared; otherwise false.</returns>
private bool verifySharedWeights()
{
    Net<T> netTest = m_solver.TestingNet;

    // Without a testing net there is nothing to compare.
    if (netTest == null)
        return true;

    Net<T> netTrain = m_solver.TrainingNet;

    if (netTrain.parameters.Count != netTest.parameters.Count)
    {
        m_log.WriteLine("WARNING: Training net has a different number of parameters than the testing net!");
        return false;
    }

    int nIdx = 0;
    while (nIdx < netTrain.parameters.Count)
    {
        if (netTrain.parameters[nIdx].gpu_data != netTest.parameters[nIdx].gpu_data)
        {
            m_log.WriteLine("WARNING: Training net parameter " + nIdx.ToString() + " is not shared with the testing net!");
            return false;
        }

        nIdx++;
    }

    return true;
}
/// <summary>
/// Loads a project: optionally initializes a database, creates a new CudaDnn connection, creates the solver
/// (TRAIN/TEST phases), optionally creates the data transformer, and creates the run net.
/// </summary>
/// <param name="phase">Specifies the phase: TRAIN/TEST create a solver; RUN only creates the run net.</param>
/// <param name="p">Specifies the project.</param>
/// <param name="labelSelectionOverride">Optional override of the project's label selection method.</param>
/// <param name="itemSelectionOverride">Optional override of the project's item selection method.</param>
/// <param name="bResetFirst">Passed to the CudaDnn constructor.</param>
/// <param name="db">Optional external database; its version must match the settings when bUseDb is true.</param>
/// <param name="bUseDb">When true, a database is used (and initialized here when 'db' is null).</param>
/// <param name="bCreateRunNet">When false, the run net is not created (not allowed for RUN).</param>
/// <param name="strStage">Optional stage added to the model's net state.</param>
/// <param name="bEnableMemTrace">Passed to the CudaDnn constructor.</param>
/// <returns>False when cancelled while loading the database; otherwise true.</returns>
1070 public bool Load(Phase phase, ProjectEx p, DB_LABEL_SELECTION_METHOD? labelSelectionOverride = null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride = null, bool bResetFirst = false, IXDatabaseBase db = null, bool bUseDb = true, bool bCreateRunNet = true, string strStage = null, bool bEnableMemTrace = false)
1071 {
1072 try
1073 {
1074 m_log.Enable = m_bEnableVerboseStatus;
1076 DatasetFactory factory = new DatasetFactory();
1077 m_strStage = strStage;
1078 m_db = db;
1079 m_bDbOwner = false;
1081 if (db != null && bUseDb)
1082 {
1083 if (m_settings.DbVersion != db.GetVersion())
1084 throw new Exception("The database version in the settings (" + m_settings.DbVersion.ToString() + ") must match the database version (" + db.GetVersion().ToString() + ") of the 'db' parameter!");
1085 }
// No external database was supplied: initialize one owned by this control.
1087 if (m_db == null && bUseDb)
1088 {
1089 switch (m_settings.DbVersion)
1090 {
// NOTE(review): original line 1092 is missing from this listing. As shown, m_db is still null here, so the cast
// below would throw; confirm the missing line constructs the database.
1091 case DB_VERSION.IMG_V1:
1093 m_log.WriteLine("Loading primary images...", true);
1094 m_bDbOwner = true;
1095 m_log.Enable = true;
1096 ((IXImageDatabase1)m_db).InitializeWithDs1(m_settings, p.Dataset, m_evtCancel.Name);
1097 break;
// NOTE(review): the case label for the line below (original lines 1098-1099, presumably the temporal version) is missing from this listing.
1100 throw new NotImplementedException("The temporal database is not yet supported.");
// NOTE(review): original line 1103 is missing from this listing; as above, m_db appears unassigned before the cast.
1102 default:
1104 m_log.WriteLine("Loading primary images...", true);
1105 m_bDbOwner = true;
1106 m_log.Enable = true;
1107 ((IXImageDatabase2)m_db).InitializeWithDs(m_settings, p.Dataset, m_evtCancel.Name);
1108 break;
1109 }
1111 if (m_evtCancel.WaitOne(0))
1112 return false;
1114 // m_db.UpdateLabelBoosts(p.ID, p.Dataset.TrainingSource.ID);
// Start from the project's selection methods and apply any caller overrides.
1116 Tuple<DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD> selMethod = MyCaffeImageDatabase.GetSelectionMethod(p);
1117 DB_LABEL_SELECTION_METHOD lblSel = selMethod.Item1;
1118 DB_ITEM_SELECTION_METHOD imgSel = selMethod.Item2;
1120 if (labelSelectionOverride.HasValue)
1121 lblSel = labelSelectionOverride.Value;
1123 if (itemSelectionOverride.HasValue)
1124 imgSel = itemSelectionOverride.Value;
1126 m_db.SetSelectionMethod(lblSel, imgSel);
1128 m_log.WriteLine("Images loaded.");
1130 if (p.TargetDatasetID > 0)
1131 {
1132 DatasetDescriptor dsTarget = factory.LoadDataset(p.TargetDatasetID);
1134 m_log.WriteLine("Loading target dataset '" + dsTarget.Name + "' images using " + m_settings.DbLoadMethod.ToString() + " loading method.");
1135 string strType = "images";
1137 switch (m_settings.DbVersion)
1138 {
1139 case DB_VERSION.IMG_V1:
1140 ((IXImageDatabase1)m_db).LoadDatasetByID1(dsTarget.ID);
1141 break;
// NOTE(review): the case label for the lines below (original lines 1142-1143) is missing from this listing.
1144 strType = "items";
1145 throw new NotImplementedException("The temporal database is not yet supported.");
1147 default:
1148 ((IXImageDatabase2)m_db).LoadDatasetByID(dsTarget.ID);
1149 break;
1150 }
1153 m_log.WriteLine("Target dataset " + strType + " loaded.");
1154 }
1156 m_log.Enable = m_bEnableVerboseStatus;
1157 }
1159 p.ModelDescription = addStage(p.ModelDescription, phase, strStage);
1160 m_project = p;
1161 m_project.Stage = getStage(m_strStage);
// NOTE(review): this null check runs after 'p' has already been dereferenced (line 1159), so a null project
// raises a NullReferenceException before this message can be thrown.
1163 if (m_project == null)
1164 throw new Exception("You must specify a project.");
// NOTE(review): original lines 1165-1167 are missing from this listing. m_dataSet, used below, is not assigned in
// the visible part of this overload; confirm whether it is set there.
1168 if (m_cuda != null)
1169 m_cuda.Dispose();
1171 m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, bResetFirst, bEnableMemTrace);
1173 m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);
1175 if (phase == Phase.TEST || phase == Phase.TRAIN)
1176 {
1177 m_log.WriteLine("Creating solver...", true);
1179 m_solver = Solver<T>.Create(m_cuda, m_log, p, m_evtCancel, m_evtForceSnapshot, m_evtForceTest, m_db, m_persist, m_rgGpu.Count, 0);
1181 if (p.WeightsState != null || p.SolverState != null)
1182 {
1183 string strSkipBlobType = null;
// When the project parameter 'ModelResized' is "True", IP_WEIGHT blobs are passed as the blob type to skip on restore.
1185 ParameterDescriptor param = p.Parameters.Find("ModelResized");
1186 if (param != null && param.Value == "True")
1187 strSkipBlobType = BLOB_TYPE.IP_WEIGHT.ToString();
1189 if (p.SolverState != null)
1190 m_solver.Restore(p.WeightsState, p.SolverState, strSkipBlobType);
1191 else
1192 m_solver.TrainingNet.LoadWeights(p.WeightsState, m_persist);
1193 }
1195 m_solver.OnSnapshot += new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1196 m_solver.OnTrainingIteration += new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1197 m_solver.OnTestingIteration += new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1198 m_log.WriteLine("Solver created.", true);
// Logs a warning when the training and testing nets do not share parameters; the result is not used.
1200 verifySharedWeights();
1201 }
1203 if (m_db is IXImageDatabase1)
1204 {
1205#warning ImageDatabase V1 only
1206 if (phase == Phase.TRAIN && m_db != null)
1207 ((IXImageDatabase1)m_db).UpdateLabelBoosts(p.ID, m_dataSet.TrainingSource.ID);
1209 if (phase == Phase.TEST && m_db != null)
1210 ((IXImageDatabase1)m_db).UpdateLabelBoosts(p.ID, m_dataSet.TestingSource.ID);
1211 }
1213 if (phase == Phase.RUN && !bCreateRunNet)
1214 throw new Exception("You cannot opt out of creating the Run net when using the RUN phase.");
1216 if (p == null || !bCreateRunNet)
1217 return true;
1219 TransformationParameter tp = null;
1220 NetParameter netParam = createNetParameterForRunning(p, out tp);
1222 m_dataTransformer = null;
1224 if (tp != null && !p.Dataset.IsModelData && !p.Dataset.IsGym && m_settings.DbVersion != DB_VERSION.TEMPORAL)
1225 {
1226 SimpleDatum sdMean = (m_db == null) ? null : m_db.QueryItemMean(m_dataSet.TrainingSource.ID);
// NOTE(review): the declarations of nC, nH and nW (original lines 1227-1230) are missing from this listing.
1231 if (sdMean != null)
1232 {
1233 m_log.CHECK_EQ(nC, sdMean.Channels, "The mean channel count does not match the datasets channel count.");
1234 m_log.CHECK_EQ(nH, sdMean.Height, "The mean height count does not match the datasets height count.");
1235 m_log.CHECK_EQ(nW, sdMean.Width, "The mean width count does not match the datasets width count.");
1236 }
1238 m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, nC, nH, nW, sdMean);
1239 }
1241 m_log.WriteLine("Creating run net...", true);
1243 if (phase == Phase.RUN)
1244 {
1245 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db);
1247 if (p.WeightsState != null)
1248 {
1249 m_log.WriteLine("Loading run weights...", true);
1250 loadWeights(m_net, p.WeightsState);
1251 }
1252 }
1253 else if (phase == Phase.TEST || phase == Phase.TRAIN)
1254 {
// The run net is created to share with the training net. If that fails, the solver's testing net is borrowed,
// and m_bOwnRunNet = false stops Unload from disposing it.
// NOTE(review): m_bOwnRunNet is never set back to true on success, so a previous fallback carries over.
1255 try
1256 {
1257 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db, Phase.RUN, null, m_solver.TrainingNet);
1258 }
1259 catch (Exception excpt)
1260 {
1261 m_log.WriteLine("WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1262 m_net = m_solver.TestingNet;
1263 m_bOwnRunNet = false;
1264 }
1265 }
1266 }
1267 catch (Exception excpt)
1268 {
// NOTE(review): 'throw excpt;' resets the stack trace; 'throw;' (or no catch at all) would preserve it.
1269 throw excpt;
1270 }
1271 finally
1272 {
1273 m_log.Enable = true;
1274 }
1276 return true;
1277 }
/// <summary>
/// Loads from a solver and model description: optionally initializes a database for the dataset found in the model,
/// creates a new CudaDnn connection, creates the solver (TRAIN/TEST), optionally creates the data transformer,
/// and creates the run net.
/// </summary>
/// <param name="phase">Specifies the phase: TRAIN/TEST create a solver; RUN only creates the run net.</param>
/// <param name="strSolver">Specifies the solver description.</param>
/// <param name="strModel">Specifies the model description.</param>
/// <param name="rgWeights">Optional weights restored into the solver (TRAIN/TEST) or loaded into the run net (RUN).</param>
/// <param name="labelSelectionOverride">Optional override of the settings' label selection method.</param>
/// <param name="itemSelectionOverride">Optional override of the settings' item selection method.</param>
/// <param name="bResetFirst">Passed to the CudaDnn constructor.</param>
/// <param name="db">Optional external database; its version must match the settings when bUseDb is true.</param>
/// <param name="bUseDb">When true, a database is used (and initialized here when 'db' is null).</param>
/// <param name="bCreateRunNet">When false, the run net is not created (not allowed for RUN).</param>
/// <param name="strStage">Optional stage added to the model's net state.</param>
/// <param name="bEnableMemTrace">Passed to the CudaDnn constructor.</param>
/// <returns>False when cancelled while loading the database; otherwise true.</returns>
1298 public bool Load(Phase phase, string strSolver, string strModel, byte[] rgWeights, DB_LABEL_SELECTION_METHOD? labelSelectionOverride = null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride = null, bool bResetFirst = false, IXDatabaseBase db = null, bool bUseDb = true, bool bCreateRunNet = true, string strStage = null, bool bEnableMemTrace = false)
1299 {
1300 try
1301 {
1302 m_log.Enable = m_bEnableVerboseStatus;
1304 m_strStage = strStage;
1305 m_db = db;
1306 m_bDbOwner = false;
1308 RawProto protoSolver = RawProto.Parse(strSolver);
1309 SolverParameter solverParam = SolverParameter.FromProto(protoSolver);
1311 strModel = addStage(strModel, phase, strStage);
1313 RawProto protoModel = RawProto.Parse(strModel);
1314 solverParam.net_param = NetParameter.FromProto(protoModel);
1316 m_dataSet = findDataset(solverParam.net_param);
1318 if (db != null && bUseDb)
1319 {
1320 if (m_settings.DbVersion != db.GetVersion())
1321 throw new Exception("The database version in the settings (" + m_settings.DbVersion.ToString() + ") must match the database version (" + db.GetVersion().ToString() + ") of the 'db' parameter!");
1322 }
// No external database was supplied: initialize one owned by this control.
1324 if (m_db == null && bUseDb)
1325 {
1326 string strType = "images";
1327 switch (m_settings.DbVersion)
1328 {
// NOTE(review): original line 1330 is missing from this listing. As shown, m_db is still null at the cast below.
// Also, unlike the other branches, "Loading primary images..." is logged after initialization here.
1329 case DB_VERSION.IMG_V1:
1331 ((IXImageDatabase1)m_db).InitializeWithDs1(m_settings, m_dataSet, m_evtCancel.Name);
1332 m_log.WriteLine("Loading primary images...", true);
1333 m_log.Enable = true;
1334 break;
// NOTE(review): original lines 1337-1338 are missing from this listing; no database initialization is visible for IMG_V2.
1336 case DB_VERSION.IMG_V2:
1339 m_log.WriteLine("Loading primary images...", true);
1340 m_log.Enable = true;
1341 break;
// NOTE(review): the case label for the line below (original lines 1342-1343) is missing from this listing.
1344 throw new NotImplementedException("The temporal database is not yet implemented!");
1345 }
1347 m_bDbOwner = true;
1348 if (m_evtCancel.WaitOne(0))
1349 return false;
// Start from the settings' selection methods and apply any caller overrides.
1351 Tuple<DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD> selMethod = MyCaffeImageDatabase.GetSelectionMethod(m_settings);
1352 DB_LABEL_SELECTION_METHOD lblSel = selMethod.Item1;
1353 DB_ITEM_SELECTION_METHOD imgSel = selMethod.Item2;
1355 if (labelSelectionOverride.HasValue)
1356 lblSel = labelSelectionOverride.Value;
1358 if (itemSelectionOverride.HasValue)
1359 imgSel = itemSelectionOverride.Value;
1361 m_db.SetSelectionMethod(lblSel, imgSel);
1363 m_log.WriteLine("Images loaded.", true);
1365 DatasetDescriptor dsTarget = findDataset(solverParam.net_param, m_dataSet);
1366 if (dsTarget != null)
1367 {
1368 m_log.WriteLine("Loading target dataset '" + dsTarget.Name + "' " + strType + " using " + m_settings.DbLoadMethod.ToString() + " loading method.", true);
1370 switch (m_settings.DbVersion)
1371 {
1372 case DB_VERSION.IMG_V1:
1373 ((IXImageDatabase1)m_db).LoadDatasetByID1(dsTarget.ID);
1374 break;
1376 case DB_VERSION.IMG_V2:
1377 ((IXImageDatabase2)m_db).LoadDatasetByID(dsTarget.ID);
1378 break;
// NOTE(review): the case label for the line below (original lines 1379-1380) is missing from this listing.
1381 throw new NotImplementedException("The temporal database is not yet implemented!");
1382 }
1385 m_log.WriteLine("Target dataset " + strType + " loaded.", true);
1386 }
1388 m_log.Enable = m_bEnableVerboseStatus;
1389 }
1391 m_project = null;
1393 if (m_cuda != null)
1394 m_cuda.Dispose();
1396 m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, bResetFirst, bEnableMemTrace);
1397 m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);
1399 if (phase == Phase.TEST || phase == Phase.TRAIN)
1400 {
1401 m_log.WriteLine("Creating solver...", true);
1403 m_solver = Solver<T>.Create(m_cuda, m_log, solverParam, m_evtCancel, m_evtForceSnapshot, m_evtForceTest, m_db, m_persist, m_rgGpu.Count, 0);
1405 if (rgWeights != null)
1406 {
1407 m_log.WriteLine("Restoring weights...", true);
1408 m_solver.Restore(rgWeights, null);
1409 }
1411 m_solver.OnSnapshot += new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
1412 m_solver.OnTrainingIteration += new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
1413 m_solver.OnTestingIteration += new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);
1414 m_log.WriteLine("Solver created.", true);
1415 }
1417 if (!bCreateRunNet)
1418 {
1419 if (phase == Phase.RUN)
1420 throw new Exception("You cannot opt out of creating the Run net when using the RUN phase.");
1422 return true;
1423 }
1425 TransformationParameter tp = null;
1426 NetParameter netParam = createNetParameterForRunning(m_dataSet, strModel, out tp);
1428 m_dataTransformer = null;
1430 if (tp != null)
1431 {
// NOTE(review): m_dataSet comes from findDataset and may be null (not checked here); confirm before dereferencing.
1432 SimpleDatum sdMean = (m_db == null) ? null : m_db.QueryItemMean(m_dataSet.TrainingSource.ID);
1433 int nC = 0;
1434 int nH = 0;
1435 int nW = 0;
1437 if (sdMean != null)
1438 {
1439 nC = sdMean.Channels;
1440 nH = sdMean.Height;
1441 nW = sdMean.Width;
1442 }
// NOTE(review): m_project was set to null at line 1391, so this branch can never run as shown.
// Its body (original lines 1445-1447) is also missing from this listing.
1443 else if (m_project != null)
1444 {
1448 }
1450 if (nC == 0 || nH == 0 || nW == 0)
1451 throw new Exception("Unable to size the Data Transformer for there is no Mean or Project to gather the sizing information from.");
1453 m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, nC, nH, nW, sdMean);
1454 }
1456 m_log.WriteLine("Creating run net...", true);
1458 if (phase == Phase.RUN)
1459 {
1460 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db);
1462 if (rgWeights != null)
1463 {
1464 m_log.WriteLine("Loading run weights...", true);
1465 loadWeights(m_net, rgWeights);
1466 }
1467 }
1468 else if (phase == Phase.TEST || phase == Phase.TRAIN)
1469 {
1470 netParam.force_backward = true;
// The run net is created to share with the training net. If that fails, the solver's testing net is borrowed,
// and m_bOwnRunNet = false stops Unload from disposing it.
// NOTE(review): m_bOwnRunNet is never set back to true on success, so a previous fallback carries over.
1472 try
1473 {
1474 m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, m_db, Phase.RUN, null, m_solver.TrainingNet);
1475 }
1476 catch (Exception excpt)
1477 {
1478 m_log.WriteLine("WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
1479 m_net = m_solver.TestingNet;
1480 m_bOwnRunNet = false;
1481 }
1482 }
1483 }
1484 catch (Exception excpt)
1485 {
// NOTE(review): 'throw excpt;' resets the stack trace; 'throw;' (or no catch at all) would preserve it.
1486 throw excpt;
1487 }
1488 finally
1489 {
1490 m_log.Enable = true;
1491 }
1493 return true;
1494 }
/// <summary>
/// Loads a solver and model without a database (a "lite" load).
/// </summary>
/// <remarks>
/// The solver/model text is remembered so that Clone can repeat the load on another GPU. Any existing
/// CudaDnn connection is disposed and a new one is created on the first active GPU.
/// </remarks>
/// <param name="phase">Specifies the phase: TRAIN/TEST create a solver; RUN only creates the run net.</param>
/// <param name="strSolver">Specifies the solver description.</param>
/// <param name="strModel">Specifies the model description.</param>
/// <param name="rgWeights">Optional weights restored into the solver (TRAIN/TEST) or loaded into the run net (RUN).</param>
/// <param name="bResetFirst">Passed to the CudaDnn constructor.</param>
/// <param name="bCreateRunNet">When false, the run net is not created (not allowed for RUN).</param>
/// <param name="sdMean">Optional mean image used to size the input and the data transformer.</param>
/// <param name="strStage">Optional stage added to the model's net state.</param>
/// <param name="bEnableMemTrace">Passed to the CudaDnn constructor.</param>
/// <returns>True on success.</returns>
/// <exception cref="Exception">Thrown when RUN is used with bCreateRunNet = false, or when the input or data transformer cannot be sized.</exception>
public bool LoadLite(Phase phase, string strSolver, string strModel, byte[] rgWeights = null, bool bResetFirst = false, bool bCreateRunNet = true, SimpleDatum sdMean = null, string strStage = null, bool bEnableMemTrace = false)
{
    // try/finally only: exceptions propagate with their original stack trace
    // (the previous 'catch { throw excpt; }' reset it).
    try
    {
        m_log.Enable = m_bEnableVerboseStatus;

        // Remember the load arguments so Clone() can repeat this load.
        m_bLoadLite = true;
        m_strSolver = strSolver;
        m_strModel = strModel;

        m_strStage = strStage;
        m_db = null;
        m_bDbOwner = false;

        RawProto protoSolver = RawProto.Parse(strSolver);
        SolverParameter solverParam = SolverParameter.FromProto(protoSolver);

        strModel = addStage(strModel, phase, strStage);

        RawProto protoModel = RawProto.Parse(strModel);
        solverParam.net_param = NetParameter.FromProto(protoModel);

        m_dataSet = null;
        m_project = null;

        if (m_cuda != null)
            m_cuda.Dispose();

        m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath, bResetFirst, bEnableMemTrace);
        m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);

        if (phase == Phase.TEST || phase == Phase.TRAIN)
        {
            m_log.WriteLine("Creating solver...", true);

            m_solver = Solver<T>.Create(m_cuda, m_log, solverParam, m_evtCancel, m_evtForceSnapshot, m_evtForceTest, m_db, m_persist, m_rgGpu.Count, 0);

            if (rgWeights != null)
                m_solver.Restore(rgWeights, null);

            m_solver.OnSnapshot += new EventHandler<SnapshotArgs>(m_solver_OnSnapshot);
            m_solver.OnTrainingIteration += new EventHandler<TrainingIterationArgs<T>>(m_solver_OnTrainingIteration);
            m_solver.OnTestingIteration += new EventHandler<TestingIterationArgs<T>>(m_solver_OnTestingIteration);

            // Logged the same way as the other Load overloads.
            m_log.WriteLine("Solver created.", true);
        }

        if (!bCreateRunNet)
        {
            if (phase == Phase.RUN)
                throw new Exception("You cannot opt out of creating the Run net when using the RUN phase.");

            return true;
        }

        TransformationParameter tp;
        int nC;
        int nH;
        int nW;
        NetParameter netParam = createNetParameterForRunning(sdMean, strModel, out tp, out nC, out nH, out nW);

        m_dataTransformer = null;

        if (tp != null)
        {
            if (nC == 0 || nH == 0 || nW == 0)
                throw new Exception("Unable to size the Data Transformer for no Mean image was provided as the 'sdMean' parameter which is used to gather the sizing information.");

            m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, nC, nH, nW, sdMean);
        }

        m_log.WriteLine("Creating run net...", true);

        if (phase == Phase.RUN)
        {
            m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, null);
            // This control created the run net, so Unload must dispose it, even if an
            // earlier load fell back to a borrowed testing net.
            m_bOwnRunNet = true;

            if (rgWeights != null)
            {
                m_log.WriteLine("Loading run weights...", true);
                loadWeights(m_net, rgWeights);
            }
        }
        else if (phase == Phase.TEST || phase == Phase.TRAIN)
        {
            netParam.force_backward = true;

            try
            {
                m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, null, Phase.RUN, null, m_solver.TrainingNet);
                m_bOwnRunNet = true;
            }
            catch (Exception excpt)
            {
                // Fall back to the solver's testing net. It is borrowed, so Unload must not dispose it.
                m_log.WriteLine("WARNING: Failed to create run net - using Test net instead for running. Error = " + excpt.Message);
                m_net = m_solver.TestingNet;
                m_bOwnRunNet = false;
            }
        }
    }
    finally
    {
        m_log.Enable = true;
    }

    return true;
}
/// <summary>
/// Creates a stand-alone run net (no solver, no dataset) from a model description and its weights.
/// </summary>
/// <param name="strModel">Model description text.</param>
/// <param name="rgWeights">Weights loaded into the new run net.</param>
/// <param name="shape">Input shape. It is stored for later Run(Blob) calls, and dim[1], dim[2], dim[3] size the data transformer.</param>
/// <param name="sdMean">Optional mean image; required when the transformation parameter sets use_imagedb_mean.</param>
/// <param name="transParam">Optional transformation parameter that overrides any found in the model.</param>
/// <param name="bForceBackward">Value assigned to the net parameter's force_backward.</param>
/// <param name="bConvertToRunNet">
/// When true the model is converted with createNetParameterForRunning.
/// Otherwise it is parsed as-is and the first layer transform_param found is used.
/// </param>
public void LoadToRun(string strModel, byte[] rgWeights, BlobShape shape, SimpleDatum sdMean = null, TransformationParameter transParam = null, bool bForceBackward = false, bool bConvertToRunNet = true)
{
    try
    {
        m_log.Enable = m_bEnableVerboseStatus;
        m_dataSet = null;
        m_project = null;
        m_loadToRunShape = shape;

        if (m_cuda != null)
            m_cuda.Dispose();

        m_cuda = new CudaDnn<T>(m_rgGpu[0], DEVINIT.CUBLAS | DEVINIT.CURAND, null, m_strCudaPath);
        m_log.WriteLine("Cuda Connection created using '" + m_cuda.Path + "'.", true);

        TransformationParameter tp = null;
        NetParameter netParam = null;

        if (bConvertToRunNet)
        {
            netParam = createNetParameterForRunning(shape, strModel, out tp);
        }
        else
        {
            netParam = NetParameter.FromProto(RawProto.Parse(strModel));

            // Use the first transformation parameter found in the model.
            foreach (LayerParameter layer in netParam.layer)
            {
                if (layer.transform_param != null)
                {
                    tp = layer.transform_param;
                    break;
                }
            }
        }

        netParam.force_backward = bForceBackward;

        // An explicitly supplied transformation parameter wins over the one found in the model.
        if (transParam != null)
            tp = transParam;

        if (tp != null)
        {
            if (tp.use_imagedb_mean && sdMean == null)
                throw new Exception("The transformer expects an image mean, yet the sdMean parameter is null!");

            m_dataTransformer = new DataTransformer<T>(m_cuda, m_log, tp, Phase.RUN, shape.dim[1], shape.dim[2], shape.dim[3], sdMean);
        }
        else
        {
            m_dataTransformer = null;
        }

        m_log.WriteLine("Creating run net...", true);
        m_net = new Net<T>(m_cuda, m_log, netParam, m_evtCancel, null, Phase.RUN);
        // The run net was created here, so this control owns it.
        m_bOwnRunNet = true;

        m_log.WriteLine("Loading weights...", true);
        loadWeights(m_net, rgWeights);
    }
    finally
    {
        // Exceptions propagate unchanged (original stack trace preserved); logging is always re-enabled.
        m_log.Enable = true;
    }
}
/// <summary>
/// Looks up the mean image of the data source used by the model's DATA layer.
/// </summary>
/// <param name="p">Model whose layers are scanned.</param>
/// <returns>
/// <c>null</c> as soon as a TRANSFORM layer is encountered.
/// Otherwise, the image mean queried for the source of the last DATA layer seen.
/// </returns>
/// <exception cref="Exception">No DATA layer source was found, or the source is not in the database.</exception>
private SimpleDatum getMeanImage(NetParameter p)
{
    string strSource = null;

    foreach (LayerParameter layer in p.layer)
    {
        if (layer.type == LayerParameter.LayerType.TRANSFORM)
            return null;

        if (layer.type == LayerParameter.LayerType.DATA)
            strSource = layer.data_param.source;
    }

    if (strSource == null)
        throw new Exception("Could not find the data source in the model!");

    DatasetFactory factory = new DatasetFactory();
    SourceDescriptor source = factory.LoadSource(strSource);

    if (source != null)
        return factory.QueryImageMean(source.ID);

    throw new Exception("Could not find the data source '" + strSource + "' in the database.");
}
/// <summary>
/// Finds the dataset referenced by the model's DATA layers, using the layer include rules to tell
/// the TRAIN source from the TEST source.
/// </summary>
/// <param name="p">Model whose layers are scanned.</param>
/// <param name="dsPrimary">Optional primary dataset; when both sources match it, <c>null</c> is returned.</param>
/// <returns>
/// The dataset loaded for the TEST/TRAIN source pair.
/// <c>null</c> when either source is missing or when both match <paramref name="dsPrimary"/>.
/// </returns>
/// <exception cref="Exception">The source pair is not present in the database.</exception>
private DatasetDescriptor findDataset(NetParameter p, DatasetDescriptor dsPrimary = null)
{
    string strTestSrc = null;
    string strTrainSrc = null;

    foreach (LayerParameter layer in p.layer)
    {
        if (layer.type == LayerParameter.LayerType.DATA)
        {
            string strLayerSrc = layer.data_param.source;

            foreach (NetStateRule rule in layer.include)
            {
                if (rule.phase == Phase.TRAIN)
                    strTrainSrc = strLayerSrc;
                else if (rule.phase == Phase.TEST)
                    strTestSrc = strLayerSrc;
            }
        }

        // Stop once both sources are known and they do not both differ from... rather, once they differ from the primary.
        if (strTrainSrc != null && strTestSrc != null &&
            (dsPrimary == null || (strTrainSrc != dsPrimary.TrainingSourceName && strTestSrc != dsPrimary.TestingSourceName)))
            break;
    }

    if (strTrainSrc == null || strTestSrc == null)
        return null;

    bool bSameAsPrimary = dsPrimary != null && strTrainSrc == dsPrimary.TrainingSourceName && strTestSrc == dsPrimary.TestingSourceName;
    if (bSameAsPrimary)
        return null;

    DatasetFactory factory = new DatasetFactory();
    DatasetDescriptor ds = factory.LoadDataset(strTestSrc, strTrainSrc);

    if (ds != null)
        return ds;

    throw new Exception("The datset sources '" + strTestSrc + "' and '" + strTrainSrc + "' do not exist in the database - do you need to load them?");
}
/// <summary>
/// Loads serialized weights into a net using this control's persistence helper (m_persist).
/// </summary>
/// <param name="net">Net that receives the weights.</param>
/// <param name="rgWeights">Serialized weights.</param>
private void loadWeights(Net<T> net, byte[] rgWeights) => net.LoadWeights(rgWeights, m_persist);
/// <summary>
/// Checks whether two nets have identical learnable parameters.
/// The parameter counts must match, and at each index the blob name, shape and data must match.
/// Data is compared through the absolute sum of the element-wise difference computed on the GPU.
/// </summary>
/// <param name="net1">First net.</param>
/// <param name="net2">Second net.</param>
/// <returns><c>true</c> when everything matches; otherwise <c>false</c>, after logging a warning about the first mismatch.</returns>
public bool CompareWeights(Net<T> net1, Net<T> net2)
{
    int nParamCount = net1.learnable_parameters.Count;

    if (nParamCount != net2.learnable_parameters.Count)
    {
        m_log.WriteLine("WARNING: The number of learnable parameters differs between the two nets!");
        return false;
    }

    // Scratch blob that receives (blobA - blobB) for each parameter pair.
    Blob<T> blobDiff = new Blob<T>(m_cuda, m_log, false);

    try
    {
        for (int nIdx = 0; nIdx < nParamCount; nIdx++)
        {
            Blob<T> blobA = net1.learnable_parameters[nIdx];
            Blob<T> blobB = net2.learnable_parameters[nIdx];
            string strIdx = nIdx.ToString();

            if (blobA.Name != blobB.Name)
            {
                m_log.WriteLine("WARNING: The name of the blobs at index " + strIdx + " differ: net1 - " + blobA.Name + " vs. net2 - " + blobB.Name);
                return false;
            }

            // Shared tail of the shape and data mismatch warnings.
            string strDetail = " differ: net1 - " + blobA.Name + " " + blobA.shape_string + " vs. net2 - " + blobB.Name + " " + blobB.shape_string;

            if (blobA.shape_string != blobB.shape_string)
            {
                m_log.WriteLine("WARNING: The shape of the blobs at index " + strIdx + strDetail);
                return false;
            }

            blobDiff.ReshapeLike(blobA);
            m_cuda.sub(blobA.count(), blobA.gpu_data, blobB.gpu_data, blobDiff.mutable_gpu_data);

            double dfAbsSum = Utility.ConvertVal<T>(blobDiff.asum_data());
            if (dfAbsSum != 0)
            {
                m_log.WriteLine("WARNING: The data of the blobs at index " + strIdx + strDetail);
                return false;
            }
        }
    }
    finally
    {
        blobDiff.Dispose();
    }

    return true;
}
/// <summary>
/// Forwards the solver's testing-iteration event to subscribers of OnTestingIteration.
/// </summary>
void m_solver_OnTestingIteration(object sender, TestingIterationArgs<T> e)
{
    // Copy the delegate first so an unsubscribe between the null check and the call cannot cause a NullReferenceException.
    EventHandler<TestingIterationArgs<T>> handler = OnTestingIteration;
    if (handler != null)
        handler(sender, e);
}
/// <summary>
/// Forwards the solver's training-iteration event to subscribers of OnTrainingIteration.
/// </summary>
void m_solver_OnTrainingIteration(object sender, TrainingIterationArgs<T> e)
{
    // Copy the delegate first so an unsubscribe between the null check and the call cannot cause a NullReferenceException.
    EventHandler<TrainingIterationArgs<T>> handler = OnTrainingIteration;
    if (handler != null)
        handler(sender, e);
}
/// <summary>
/// Forwards the solver's snapshot event to subscribers of OnSnapshot.
/// </summary>
void m_solver_OnSnapshot(object sender, SnapshotArgs e)
{
    // Copy the delegate first so an unsubscribe between the null check and the call cannot cause a NullReferenceException.
    EventHandler<SnapshotArgs> handler = OnSnapshot;
    if (handler != null)
        handler(sender, e);
}
/// <summary>
/// Trains the loaded solver: on a single GPU via Solve, or on several GPUs via NCCL.
/// </summary>
/// <param name="nIterationOverride">Training iterations; -1 uses m_settings.MaximumIterationOverride.</param>
/// <param name="nTrainingTimeLimitInMinutes">
/// Time limit passed to the solver. With more than one GPU, a positive value is not supported:
/// a message is logged and the method returns without training.
/// </param>
/// <param name="step">Step mode passed to Solve (single-GPU path only).</param>
/// <param name="dfLearningRateOverride">When positive, overrides the learning rate for this call only; it is reset to 0 afterwards.</param>
/// <param name="bReset">When true the solver is reset before training.</param>
public void Train(int nIterationOverride = -1, int nTrainingTimeLimitInMinutes = 0, TRAIN_STEP step = TRAIN_STEP.NONE, double dfLearningRateOverride = 0, bool bReset = false)
{
    m_lastPhaseRun = Phase.TRAIN;

    if (nIterationOverride == -1)
        nIterationOverride = m_settings.MaximumIterationOverride;

    if (bReset)
        m_solver.Reset();

    m_solver.TrainingTimeLimitInMinutes = nTrainingTimeLimitInMinutes;
    m_solver.TrainingIterationOverride = nIterationOverride;

    if (dfLearningRateOverride > 0)
        m_solver.LearningRateOverride = dfLearningRateOverride;

    try
    {
        if (m_rgGpu.Count > 1)
        {
            if (nTrainingTimeLimitInMinutes > 0)
            {
                m_log.WriteLine("You have a training time-limit of " + nTrainingTimeLimitInMinutes.ToString("N0") + " minutes. Multi-GPU training is not supported when a training time-limit is imposed.");
                return;
            }

            m_log.WriteLine("Starting multi-GPU training on GPUs: " + listToString(m_rgGpu));
            NCCL<T> nccl = new NCCL<T>(m_cuda, m_log, m_solver, m_rgGpu[0], 0, null);
            nccl.Run(m_rgGpu, m_solver.TrainingIterationOverride);
        }
        else
        {
            m_solver.Solve(-1, null, null, step);
        }
    }
    finally
    {
        // The learning-rate override only applies to this call. Exceptions propagate with their original stack trace.
        m_solver.LearningRateOverride = 0;
    }
}
/// <summary>
/// Formats a list of integers (e.g. GPU ids) as a comma-separated string such as "0, 1, 2".
/// </summary>
/// <param name="rg">Values to format.</param>
/// <returns>The values joined with ", "; an empty string for an empty list.</returns>
private string listToString(List<int> rg)
{
    // string.Join avoids the quadratic cost of repeated string concatenation in a loop.
    return string.Join(", ", rg);
}
/// <summary>
/// Runs the solver's full test pass.
/// </summary>
/// <param name="nIterationOverride">Testing iterations to run; -1 uses m_settings.TestingIterationOverride.</param>
/// <returns>The value returned by the solver's TestAll.</returns>
public double Test(int nIterationOverride = -1)
{
    m_lastPhaseRun = Phase.TEST;
    m_solver.TestingIterationOverride = (nIterationOverride == -1) ? m_settings.TestingIterationOverride : nIterationOverride;
    return m_solver.TestAll();
}
1962 {
1964 throw new Exception("Custom input is only supported by MODEL based datasets!");
1966 m_lastPhaseRun = Phase.RUN;
1968 UpdateRunWeights(false, false);
1970 double dfThreshold = customInput.GetPropertyAsDouble("Threshold", 0.2);
1971 int nMax = customInput.GetPropertyAsInt("Max", 80);
1972 int nK = customInput.GetPropertyAsInt("K", 1);
1973 PropertySet res;
1975 if (customInput.GetPropertyAsBool("Temporal", false))
1976 {
1977 res = RunModel(customInput);
1978 res.SetProperty("Temporal", "True");
1979 }
1980 else
1981 {
1982 res = Run(customInput, nK, dfThreshold, nMax);
1983 }
1985 return res;
1986 }
1996 {
1997 m_lastPhaseRun = Phase.RUN;
1999 UpdateRunWeights(false, false);
2001 double dfThreshold = customInput.GetPropertyAsDouble("Threshold", 0.2);
2002 int nMax = customInput.GetPropertyAsInt("Max", 80);
2003 int nK = customInput.GetPropertyAsInt("K", 1);
2005 if (customInput.GetProperty("Temporal") == "True")
2006 return RunModelEx(customInput);
2008 throw new Exception("TestManyEx currently only supports temporal testing.");
2009 }
/// <summary>
/// Runs the run net over nCount items from the training or testing source and logs accuracy statistics.
/// The log covers overall accuracy, per-label accuracy, accuracy around the output midpoint,
/// and items that missed the optional threshold.
/// </summary>
/// <param name="nCount">Number of items to test; must be greater than 0.</param>
/// <param name="bOnTrainingSet">When true items come from the training source, otherwise from the testing source.</param>
/// <param name="bOnTargetSet">
/// When true and the project has a target dataset, the target set is labeled in the log.
/// NOTE(review): presumably the target set's source is also used — confirm.
/// </param>
/// <param name="imgSelMethod">Item selection method passed to QueryItem; when NONE, m_db.GetSelectionMethod() is consulted.</param>
/// <param name="nImageStartIdx">Index of the first item queried (negative values are clamped to 0).</param>
/// <param name="dtImageStartTime">
/// When set and later than DateTime.MinValue, items are taken from GetItemsFromTime starting at this time
/// instead of being queried by index.
/// </param>
/// <param name="dfThreshold">
/// Optional minimum DetectedLabelOutput for a non-MULTIBOX result to be counted.
/// Lower results are tallied as missed.
/// </param>
/// <returns>Each tested item paired with its results, or <c>null</c> if the cancel event was signaled.</returns>
2023 public List<Tuple<SimpleDatum, ResultCollection>> TestMany(int nCount, bool bOnTrainingSet, bool bOnTargetSet = false, DB_ITEM_SELECTION_METHOD imgSelMethod = DB_ITEM_SELECTION_METHOD.RANDOM, int nImageStartIdx = 0, DateTime? dtImageStartTime = null, double? dfThreshold = null)
2024 {
2025 Dictionary<int, int> rgMissedThreshold = new Dictionary<int, int>();
2027 m_lastPhaseRun = Phase.RUN;
2029 UpdateRunWeights(false);
2031 m_log.CHECK_GT(nCount, 0, "You must select at least 1 image to train on!");
2033 Stopwatch sw = new Stopwatch();
2034 DB_LABEL_SELECTION_METHOD? lblSelMethod = null;
2036 if (imgSelMethod == DB_ITEM_SELECTION_METHOD.NONE)
2039 Tuple<DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD> sel = m_db.GetSelectionMethod();
2043 int nSrcId = (bOnTrainingSet) ? m_dataSet.TrainingSource.ID : m_dataSet.TestingSource.ID;
2044 string strSrc = (bOnTrainingSet) ? m_dataSet.TrainingSourceName : m_dataSet.TestingSourceName;
2045 string strSet = (bOnTrainingSet) ? "training" : "test";
2046 int nCorrectCount = 0;
2047 Dictionary<int, int> rgCorrectCounts = new Dictionary<int, int>();
2048 Dictionary<int, int> rgLabelTotals = new Dictionary<int, int>();
2049 Dictionary<int, Dictionary<int, int>> rgDetectedCounts = new Dictionary<int, Dictionary<int, int>>();
2051 if (bOnTargetSet && m_project.DatasetTarget != null)
2052 {
2055 strSet = (bOnTrainingSet) ? "target training" : "target test";
2056 }
2058 sw.Start();
2060 m_log.WriteHeader("Test Many (" + nCount.ToString() + ") - on " + strSet + " '" + strSrc + "'");
// Look for a label mapping layer and the accuracy layer's parameters (used for ignore_labels) in the solver's net.
2062 LabelMappingParameter labelMapping = null;
2063 Net<T> net = m_solver.TestingNet;
2064 if (net == null)
2065 net = m_solver.TrainingNet;
2067 AccuracyParameter accuracyParam = null;
2068 foreach (Layer<T> layer in net.layers)
2069 {
2070 if (layer.type == LayerParameter.LayerType.LABELMAPPING)
2071 {
2072 labelMapping = layer.layer_param.labelmapping_param;
2073 break;
2074 }
2075 else if (layer.type == LayerParameter.LayerType.ACCURACY ||
2076 layer.type == LayerParameter.LayerType.ACCURACY_DECODE ||
2077 layer.type == LayerParameter.LayerType.ACCURACY_ENCODING)
2078 {
2079 accuracyParam = layer.layer_param.accuracy_param;
2080 }
2081 }
2083 if (nImageStartIdx < 0)
2084 nImageStartIdx = 0;
// When a start time is given, the items are loaded up-front and nCount is clipped to the number found.
2086 List<SimpleDatum> rgImg = null;
2087 if (dtImageStartTime.HasValue && dtImageStartTime.Value > DateTime.MinValue)
2088 {
2089 m_log.WriteLine("INFO: Starting test many at images with time " + dtImageStartTime.Value.ToString() + " or later...");
2090 rgImg = m_db.GetItemsFromTime(nSrcId, dtImageStartTime.Value, nCount);
2091 if (nCount > rgImg.Count)
2092 nCount = rgImg.Count;
2094 if (nCount == 0)
2095 throw new Exception("No images found after time '" + dtImageStartTime.Value.ToString() + "'. Make sure to use the LOAD_ALL image loading method when running TestMany after a specified time.");
2096 }
2098 List<Tuple<SimpleDatum, ResultCollection>> rgrgResults = new List<Tuple<SimpleDatum, ResultCollection>>();
2099 int nTotalCount = 0;
2100 int nMidPoint = 0;
2101 bool bPad = true;
2102 int nImgCount = nCount;
2104 Blob<T> blobData = null;
2106 try
2107 {
2108 SimpleDatum sd = null;
2109 List<int> rgOriginalRunNetInputShape = null;
// Remember the run net's input shape so it can be restored after padded runs.
2111 if (m_net.input_blobs != null && m_net.input_blobs.Count > 0)
2112 rgOriginalRunNetInputShape = Utility.Clone<int>(m_net.input_blobs[0].shape());
2114 for (int i = 0; i < nCount; i++)
2115 {
2116 if (m_evtCancel.WaitOne(0))
2117 {
2118 m_log.WriteLine("Test Many aborted!");
2119 return null;
2120 }
2123 sd = (rgImg != null) ? rgImg[i] : m_db.QueryItem(nSrcId, nImageStartIdx + i, lblSelMethod, imgSelMethod, null, m_settings.ItemDbLoadDataCriteria, m_settings.ItemDbLoadDebugData);
2126 {
2127 m_log.WriteLine("WARNING: Image size mismatch! Current image size " + sd.Width.ToString() + " x " + sd.Height.ToString() + " does not match the dataset image size " + m_dataSet.TrainingSource.Width.ToString() + " x " + m_dataSet.TrainingSource.Height.ToString() + "!");
2128 continue;
2129 }
// NOTE(review): m_dataTransformer is used without a null check; LoadLite/LoadToRun can leave it null — confirm TestMany is only used after a full Load.
2131 m_dataTransformer.TransformLabel(sd);
2133 if (!sd.GetDataValid(false))
2134 {
2135 Trace.WriteLine("You should not be here.");
2136 throw new Exception("NO DATA!");
2137 }
2139 // Create blob mask images (when enabled) during the data transform.
2140 blobData = CreateDataBlob(sd, blobData, bPad);
2142 List<int> rgIgnoreLabels = null;
2143 if (accuracyParam != null && accuracyParam.ignore_labels != null)
2144 rgIgnoreLabels = accuracyParam.ignore_labels;
2146 List<ResultCollection> rgrgResults1 = Run(blobData, false, false, int.MaxValue, rgIgnoreLabels);
2147 ResultCollection rgResults = rgrgResults1[0];
2149 // If masked, mask the sd to match the actual input in the blobData.
2150 m_dataTransformer.MaskImage(sd);
2151 rgrgResults.Add(new Tuple<SimpleDatum, ResultCollection>(sd, rgResults));
// MULTIBOX: an annotation counts as correct when its group label is among the first N distinct detected
// labels (in ResultsSorted order), where N is the number of annotation groups.
2153 if (rgResults.ResultType == ResultCollection.RESULT_TYPE.MULTIBOX)
2154 {
2155 Dictionary<int, List<Result>> rgLabeledResults = new Dictionary<int, List<Result>>();
2156 Dictionary<int, int> rgLabeledOrder = new Dictionary<int, int>();
2158 int nIdx = 0;
2159 foreach (Result result in rgResults.ResultsSorted)
2160 {
2161 if (!rgLabeledResults.ContainsKey(result.Label))
2162 {
2163 rgLabeledResults.Add(result.Label, new List<Result>());
2164 rgLabeledOrder.Add(result.Label, nIdx);
2165 nIdx++;
2166 }
2168 rgLabeledResults[result.Label].Add(result);
2169 }
2171 List<Tuple<int, List<Result>>> rgBestResults = new List<Tuple<int, List<Result>>>();
2172 List<int> rgDetectedLabels = rgLabeledOrder.OrderBy(p => p.Value).Select(p => p.Key).ToList();
2174 if (sd.annotation_group != null)
2175 {
2176 rgDetectedLabels = rgDetectedLabels.Take(sd.annotation_group.Count).ToList();
2178 for (int j = 0; j < sd.annotation_group.Count; j++)
2179 {
2180 int nExpectedLabel = sd.annotation_group[j].group_label;
2182 if (!rgCorrectCounts.ContainsKey(nExpectedLabel))
2183 rgCorrectCounts.Add(nExpectedLabel, 0);
2185 if (!rgLabelTotals.ContainsKey(nExpectedLabel))
2186 rgLabelTotals.Add(nExpectedLabel, 1);
2187 else
2188 rgLabelTotals[nExpectedLabel]++;
2190 if (rgDetectedLabels.Contains(nExpectedLabel))
2191 {
2192 rgCorrectCounts[nExpectedLabel]++;
2193 nCorrectCount++;
2194 }
2196 nTotalCount++;
2197 }
2198 }
2199 else
2200 {
2201 m_log.WriteLine("WARNING: No annotation data found in image with ID = " + sd.ImageID.ToString());
2202 }
2203 }
2204 else
2205 {
2206 int nDetectedLabel = rgResults.DetectedLabel;
2207 int nExpectedLabel = sd.Label;
// Results below the optional threshold are not scored; they are tallied in rgMissedThreshold instead.
2209 if (!dfThreshold.HasValue || rgResults.DetectedLabelOutput >= dfThreshold.Value)
2210 {
// nMidPoint is only set when the number of outputs is odd (it is then the middle output index).
2211 if (rgResults.ResultsOriginal.Count % 2 != 0)
2212 nMidPoint = (int)Math.Floor(rgResults.ResultsOriginal.Count / 2.0);
2215 if (labelMapping != null)
2216 {
2217 if (m_dataTransformer.param.label_mapping.Active)
2218 m_log.FAIL("You can use either the LabelMappingLayer or the DataTransformer label_mapping, but not both!");
2220 nExpectedLabel = labelMapping.MapLabel(nExpectedLabel);
2221 }
2223 if (!rgCorrectCounts.ContainsKey(nExpectedLabel))
2224 rgCorrectCounts.Add(nExpectedLabel, 0);
2226 if (!rgLabelTotals.ContainsKey(nExpectedLabel))
2227 rgLabelTotals.Add(nExpectedLabel, 1);
2228 else
2229 rgLabelTotals[nExpectedLabel]++;
2231 if (nExpectedLabel == nDetectedLabel)
2232 {
2233 nCorrectCount++;
2234 rgCorrectCounts[nExpectedLabel]++;
2235 }
// Track which labels were detected for each expected label (used for the 'secondary detections' report).
2237 if (!rgDetectedCounts.ContainsKey(nExpectedLabel))
2238 rgDetectedCounts.Add(nExpectedLabel, new Dictionary<int, int>());
2240 if (!rgDetectedCounts[nExpectedLabel].ContainsKey(nDetectedLabel))
2241 rgDetectedCounts[nExpectedLabel].Add(nDetectedLabel, 0);
2243 rgDetectedCounts[nExpectedLabel][nDetectedLabel]++;
2245 nTotalCount++;
2246 }
2247 else
2248 {
2249 if (!rgMissedThreshold.ContainsKey(nExpectedLabel))
2250 rgMissedThreshold.Add(nExpectedLabel, 0);
2252 rgMissedThreshold[nExpectedLabel]++;
2253 }
2254 }
2256 double dfPct = ((double)i / (double)nCount);
2257 m_log.Progress = dfPct;
2259 if (sw.ElapsedMilliseconds > 1000)
2260 {
2261 m_log.WriteLine("processing test many at " + dfPct.ToString("P"));
2262 sw.Stop();
2263 sw.Reset();
2264 sw.Start();
2265 }
2266 }
2268 // Resize inputs back to unpadded.
2269 if (rgOriginalRunNetInputShape != null)
2270 {
2271 if (m_net.input_blobs != null && m_net.input_blobs.Count > 0)
2272 m_net.input_blobs[0].Reshape(rgOriginalRunNetInputShape);
2273 }
2274 }
2275 finally
2276 {
2277 if (blobData != null)
2278 {
2279 blobData.Dispose();
2280 blobData = null;
2281 }
2282 }
// Report overall accuracy, then per-label accuracy.
2284 double dfCorrectPct = (nTotalCount == 0) ? 0 : ((double)nCorrectCount / (double)nTotalCount);
2286 m_log.WriteLine("Test Many Completed.");
2287 m_log.WriteLine(" " + dfCorrectPct.ToString("P") + " correct detections.");
2288 m_log.WriteLine(" " + (nTotalCount - nCorrectCount).ToString("N") + " incorrect detections.");
2290 foreach (KeyValuePair<int, int> kv in rgCorrectCounts.OrderBy(p => p.Key).ToList())
2291 {
2292 nCount = 0;
2294 foreach (KeyValuePair<int, int> kv1 in rgLabelTotals)
2295 {
2296 if (kv1.Key == kv.Key)
2297 {
2298 nCount = kv1.Value;
2299 break;
2300 }
2301 }
2303 if (nCount > 0)
2304 {
2305 string strSecondDetection = "";
2306 if (rgDetectedCounts.ContainsKey(kv.Key))
2307 {
2308 List<KeyValuePair<int, int>> rgDetectedCountsSorted = rgDetectedCounts[kv.Key].OrderByDescending(p => p.Value).ToList();
2309 if (rgDetectedCountsSorted.Count > 1)
2310 {
2311 strSecondDetection = " (secondary detections: " + rgDetectedCountsSorted[1].Key.ToString();
2313 if (rgDetectedCountsSorted.Count > 2)
2314 strSecondDetection += " and " + rgDetectedCountsSorted[2].Key.ToString();
2316 strSecondDetection += ")";
2317 }
2318 }
2320 dfCorrectPct = ((double)kv.Value / (double)nCount);
2321 m_log.WriteLine("Label #" + kv.Key.ToString() + " had " + dfCorrectPct.ToString("P") + " correct detections out of " + nCount.ToString("N0") + " items with this label." + strSecondDetection);
2322 }
2323 }
// Split accuracy into labels below and above the midpoint index (the midpoint label itself is excluded).
// rgLabelTotals and rgCorrectCounts always receive the same keys, so their key-sorted lists line up index by index.
2325 if (nMidPoint > 0)
2326 {
2327 int nTotalBelow = 0;
2328 int nCorrectBelow = 0;
2329 int nTotalAbove = 0;
2330 int nCorrectAbove = 0;
2331 int nTotalBelowAndAbove = 0;
2332 int nCorrectBelowAndAbove = 0;
2334 List<KeyValuePair<int, int>> rgLabelTotalsList = rgLabelTotals.OrderBy(p => p.Key).ToList();
2335 List<KeyValuePair<int, int>> rgCorrectCountsList = rgCorrectCounts.OrderBy(p => p.Key).ToList();
2337 for (int i = 0; i < rgLabelTotalsList.Count; i++)
2338 {
2339 if (i < nMidPoint)
2340 {
2341 nTotalBelow += rgLabelTotalsList[i].Value;
2342 nCorrectBelow += rgCorrectCountsList[i].Value;
2343 nTotalBelowAndAbove += rgLabelTotalsList[i].Value;
2344 nCorrectBelowAndAbove += rgCorrectCountsList[i].Value;
2345 }
2346 else if (i > nMidPoint)
2347 {
2348 nTotalAbove += rgLabelTotalsList[i].Value;
2349 nCorrectAbove += rgCorrectCountsList[i].Value;
2350 nTotalBelowAndAbove += rgLabelTotalsList[i].Value;
2351 nCorrectBelowAndAbove += rgCorrectCountsList[i].Value;
2352 }
2353 }
2355 dfCorrectPct = (nTotalBelow == 0) ? 0 : nCorrectBelow / (double)nTotalBelow;
2356 m_log.WriteLine("Correct below midpoint of " + nMidPoint.ToString() + " = " + dfCorrectPct.ToString("P"));
2357 dfCorrectPct = (nTotalAbove == 0) ? 0 : nCorrectAbove / (double)nTotalAbove;
2358 m_log.WriteLine("Correct above midpoint of " + nMidPoint.ToString() + " = " + dfCorrectPct.ToString("P"));
2359 dfCorrectPct = (nTotalBelowAndAbove == 0) ? 0 : nCorrectBelowAndAbove / (double)nTotalBelowAndAbove;
2360 m_log.WriteLine("Correct below and above midpoint of " + nMidPoint.ToString() + " = " + dfCorrectPct.ToString("P"));
2361 }
// rgMissedThreshold is only filled when dfThreshold has a value, so dfThreshold.Value is safe below.
2363 if (rgMissedThreshold.Count > 0)
2364 {
2365 m_log.WriteLine("---Missed Threshold Items---");
2367 int nTotal = 0;
2368 foreach (KeyValuePair<int, int> kv in rgMissedThreshold)
2369 {
2370 m_log.WriteLine("Expected Label " + kv.Key.ToString() + ": " + kv.Value.ToString() + " items missed threshold (" + ((double)kv.Value/nImgCount).ToString("P") + ").");
2371 nTotal += kv.Value;
2372 }
2374 m_log.WriteLine("A total of " + nTotal.ToString() + " items did not meet the threshold of " + dfThreshold.Value.ToString() + ", (" + ((double)nTotal / nImgCount).ToString("P") + ")");
2375 }
2377 return rgrgResults;
2378 }
/// <summary>
/// Runs the run net on the item at nImageIdx after applying the data transformer's label transform.
/// The item presumably comes from the loaded dataset's testing source — confirm.
/// </summary>
/// <param name="nImageIdx">Index of the item to run.</param>
/// <param name="bPad">Forwarded padding flag; results are requested with bSort = true.</param>
/// <returns>The results for the item.</returns>
2386 public ResultCollection Run(int nImageIdx, bool bPad = true)
2387 {
2389 m_dataTransformer.TransformLabel(sd);
2390 return Run(sd, true, bPad);
2391 }
/// <summary>
/// Runs a batch of items selected by index through the run net.
/// Each item gets the data transformer's label transform before the batch is run.
/// </summary>
/// <param name="rgImageIdx">Indexes of the items to run.</param>
/// <param name="blob">
/// Input blob reused for the batch. It is created when null and returned through this ref parameter;
/// the caller is responsible for disposing it.
/// </param>
/// <returns>One ResultCollection per item.</returns>
2399 public List<ResultCollection> Run(List<int> rgImageIdx, ref Blob<T> blob)
2400 {
2401 List<SimpleDatum> rgSd = new List<SimpleDatum>();
2403 foreach (int nImageIdx in rgImageIdx)
2404 {
2406 m_dataTransformer.TransformLabel(sd);
2407 rgSd.Add(sd);
2408 }
2410 return Run(rgSd, ref blob, false, int.MaxValue);
2411 }
/// <summary>
/// Runs a batch of items selected by index through the run net using a temporary input blob.
/// </summary>
/// <param name="rgImageIdx">Indexes of the items to run.</param>
/// <returns>One ResultCollection per item.</returns>
/// <exception cref="Exception">No dataset is loaded (e.g. after LoadLite or LoadToRun).</exception>
2418 public List<ResultCollection> Run(List<int> rgImageIdx)
2419 {
2420 List<SimpleDatum> rgSd = new List<SimpleDatum>();
2422 if (m_dataSet == null)
2423 throw new Exception("Running on indexes requires a full Load that includes loading the dataset.");
2425 foreach (int nImageIdx in rgImageIdx)
2426 {
2428 m_dataTransformer.TransformLabel(sd);
2429 rgSd.Add(sd);
2430 }
// NOTE(review): the temporary blob is not disposed if Run throws; a try/finally would avoid leaking it.
2432 Blob<T> blob = null;
2433 List<ResultCollection> rgRes = Run(rgSd, ref blob);
2435 if (blob != null)
2436 blob.Dispose();
2438 return rgRes;
2439 }
/// <summary>
/// Multiplies the dimensions of a shape together to get its element count.
/// </summary>
/// <param name="rg">Shape dimensions.</param>
/// <returns>The product of all dimensions (1 for an empty list).</returns>
private int getCount(List<int> rg)
{
    int nProduct = 1;

    for (int i = 0; i < rg.Count; i++)
        nProduct *= rg[i];

    return nProduct;
}
/// <summary>
/// Creates (or refills) an input blob for a single datum.
/// </summary>
/// <remarks>
/// Without a data transformer the datum is copied into the blob directly.
/// Otherwise the datum is transformed into the shared m_rgRunData buffer. When <paramref name="bPad"/> is true,
/// the batch dimension is set to 2 and the blob is marked Padded. Only the transformed data's length is copied
/// to the start of the buffer; any remaining elements keep their previous contents.
/// </remarks>
/// <param name="d">Datum to load.</param>
/// <param name="blob">Optional blob to reuse; a new blob is created when null.</param>
/// <param name="bPad">When true (and a transformer exists) the blob is padded to a batch of two.</param>
/// <returns>The filled blob.</returns>
public Blob<T> CreateDataBlob(SimpleDatum d, Blob<T> blob = null, bool bPad = true)
{
    if (m_dataTransformer == null)
    {
        if (blob == null)
            return new Blob<T>(m_cuda, m_log, d, true);

        blob.SetData(d, true);
        return blob;
    }

    Blob<T> blobOut = blob ?? new Blob<T>(m_cuda, m_log);

    Datum datum = new Datum(d);
    List<int> rgShape = m_dataTransformer.InferBlobShape(datum);

    if (bPad)
        rgShape[0] = 2;

    int nElements = getCount(rgShape);
    blobOut.Reshape(rgShape);
    blobOut.Padded = bPad;

    // Reuse the run buffer unless its size no longer matches.
    bool bNeedBuffer = (m_rgRunData == null) || (m_rgRunData.Length != nElements);
    if (bNeedBuffer)
        m_rgRunData = new T[nElements];

    T[] rgTransformed = m_dataTransformer.Transform(datum);
    Array.Copy(rgTransformed, 0, m_rgRunData, 0, rgTransformed.Length);

    blobOut.mutable_cpu_data = m_rgRunData;
    m_dataTransformer.SetRange(blobOut);

    return blobOut;
}
/// <summary>
/// Runs a single datum through either the run net or the solver's training net.
/// </summary>
/// <param name="d">Datum to run; converted to an input blob with CreateDataBlob.</param>
/// <param name="bSort">Sort flag (NOTE(review): confirm how it is applied to the results).</param>
/// <param name="bUseSolverNet">When true the solver's TrainingNet is used instead of the run net.</param>
/// <param name="bPad">Forwarded to CreateDataBlob and Forward; a padded output loses its extra item.</param>
/// <returns>The results for the datum.</returns>
/// <exception cref="Exception">The run net has not been created.</exception>
2507 public ResultCollection Run(SimpleDatum d, bool bSort, bool bUseSolverNet, bool bPad = true)
2508 {
2509 if (m_net == null)
2510 throw new Exception("The Run net has not been created!");
2512 ResultCollection result = null;
2513 Blob<T> blob = null;
2515 try
2516 {
2517 blob = CreateDataBlob(d, null, bPad);
2518 BlobCollection<T> colBottom = new BlobCollection<T>() { blob };
2519 double dfLoss = 0;
2521 BlobCollection<T> colResults;
2522 LayerParameter.LayerType lastLayerType;
2524 if (bUseSolverNet)
2525 {
2526 lastLayerType = m_solver.TrainingNet.layers[m_solver.TrainingNet.layers.Count - 1].type;
2527 colResults = m_solver.TrainingNet.Forward(colBottom, out dfLoss, bPad);
2528 }
2529 else
2530 {
2531 lastLayerType = m_net.layers[m_net.layers.Count - 1].type;
2532 colResults = m_net.Forward(colBottom, out dfLoss, bPad);
2533 }
// Drop the padding item's output by shrinking the batch dimension by one (never below 1).
2535 if (blob.Padded)
2536 {
2537 List<int> rgShape = Utility.Clone<int>(colResults[0].shape());
2538 rgShape[0]--;
2539 if (rgShape[0] <= 0)
2540 rgShape[0] = 1;
2541 colResults[0].Reshape(rgShape);
2542 }
2544 List<Result> rgResults = new List<Result>();
2545 float[] rgData = Utility.ConvertVecF<T>(colResults[0].update_cpu_data());
// MULTIBBOX output is read as rows of 7 values: [?, label, score, xmin, ymin, xmax, ymax].
// The first value of each row is read into 'i' but not used.
2547 if (colResults[0].type == BLOB_TYPE.MULTIBBOX)
2548 {
2549 int nNum = rgData.Length / 7;
2551 for (int n = 0; n < nNum; n++)
2552 {
2553 int i = (int)rgData[(n * 7)];
2554 int nLabel = (int)rgData[(n * 7) + 1];
2555 double dfScore = rgData[(n * 7) + 2];
2556 double[] rgExtra = new double[4];
2557 rgExtra[0] = rgData[(n * 7) + 3]; // xmin
2558 rgExtra[1] = rgData[(n * 7) + 4]; // ymin
2559 rgExtra[2] = rgData[(n * 7) + 5]; // xmax
2560 rgExtra[3] = rgData[(n * 7) + 6]; // ymax
2562 rgResults.Add(new Result(nLabel, dfScore, rgExtra));
2563 }
2564 }
2565 else
2566 {
// Otherwise every output value becomes a Result whose label is its index.
2567 for (int i = 0; i < rgData.Length; i++)
2568 {
2569 double dfProb = rgData[i];
2570 rgResults.Add(new Result(i, dfProb));
2571 }
2572 }
2574 result = new ResultCollection(rgResults, lastLayerType);
2575 if (m_db != null && m_db.GetVersion() != DB_VERSION.TEMPORAL)
2577 }
// NOTE(review): 'throw excpt;' resets the stack trace; 'throw;' (or removing this catch) would preserve it.
2578 catch (Exception excpt)
2579 {
2580 throw excpt;
2581 }
2582 finally
2583 {
2584 if (blob != null)
2585 blob.Dispose();
2586 }
2588 return result;
2589 }
/// <summary>
/// Runs a list of datums as a single batch.
/// Each datum is transformed with the data transformer (which must be set) and stacked into <paramref name="blob"/>.
/// The output is then split into one ResultCollection per datum.
/// </summary>
/// <param name="rgSd">Datums to run; the first datum's channels/height/width size the input blob.</param>
/// <param name="blob">Input blob; created when null and returned through this ref parameter (caller disposes it).</param>
/// <param name="bUseSolverNet">When true the solver's TrainingNet is run with passthrough enabled instead of the run net.</param>
/// <param name="nMax">Maximum number of datums placed in the batch.</param>
/// <returns>One ResultCollection per processed datum.</returns>
/// <exception cref="Exception">The run net has not been created.</exception>
2599 public List<ResultCollection> Run(List<SimpleDatum> rgSd, ref Blob<T> blob, bool bUseSolverNet = false, int nMax = int.MaxValue)
2600 {
2601 m_log.CHECK(m_dataTransformer != null, "The data transformer is not initialized!");
2603 if (m_net == null)
2604 throw new Exception("The Run net has not been created!");
2606 List<ResultCollection> rgFinalResults = new List<ResultCollection>();
2607 int nBatchSize = rgSd.Count;
2608 int nChannels = rgSd[0].Channels;
2609 int nHeight = rgSd[0].Height;
2610 int nWidth = rgSd[0].Width;
2611 List<T> rgDataInput = new List<T>();
2613 if (blob == null)
2614 blob = new common.Blob<T>(m_cuda, m_log, nBatchSize, nChannels, nHeight, nWidth, false);
2616 int nCount = 0;
2617 for (int i=0; i<rgSd.Count && i < nMax; i++)
2618 {
2619 rgDataInput.AddRange(m_dataTransformer.Transform(rgSd[i]));
2620 nCount++;
2621 }
2623 blob.Reshape(nCount, nChannels, nHeight, nWidth);
2624 blob.mutable_cpu_data = rgDataInput.ToArray();
2625 m_dataTransformer.SetRange(blob);
2627 BlobCollection<T> colBottom = new BlobCollection<T>() { blob };
2628 double dfLoss = 0;
2630 BlobCollection<T> colResults;
2631 LayerParameter.LayerType lastLayerType;
// NOTE(review): the solver-net branch indexes TrainingNet.layers with m_net.layers.Count - 1; if the two nets
// differ in layer count this picks the wrong layer (or throws). TrainingNet.layers.Count - 1 looks intended — confirm.
2633 if (bUseSolverNet)
2634 {
2635 lastLayerType = m_solver.TrainingNet.layers[m_net.layers.Count - 1].type;
2636 m_solver.TrainingNet.SetEnablePassthrough(true);
2637 colResults = m_solver.TrainingNet.Forward(colBottom, out dfLoss);
2638 m_solver.TrainingNet.SetEnablePassthrough(false);
2639 }
2640 else
2641 {
2642 lastLayerType = m_net.layers[m_net.layers.Count - 1].type;
2643 colResults = m_net.Forward(colBottom, out dfLoss, true);
2644 }
// NOTE(review): the output is split by rgSd.Count, but only nCount (<= nMax) items were placed in the blob;
// dividing by nCount looks intended when nMax < rgSd.Count — confirm.
2646 T[] rgDataOutput = colResults[0].update_cpu_data();
2647 int nOutputCount = rgDataOutput.Length / rgSd.Count;
2649 for (int i = 0; i < rgSd.Count && i < nMax; i++)
2650 {
2651 List<Result> rgResults = new List<Result>();
2653 for (int j = 0; j < nOutputCount; j++)
2654 {
2655 int nIdx = i * nOutputCount + j;
2656 double dfProb = (double)Convert.ChangeType(rgDataOutput[nIdx], typeof(double));
2657 rgResults.Add(new Result(j, dfProb));
2658 }
2660 ResultCollection result = new ResultCollection(rgResults, lastLayerType);
2662 if (m_db != null && m_dataSet != null && m_db.GetVersion() != DB_VERSION.TEMPORAL)
2665 rgFinalResults.Add(result);
2666 }
2668 return rgFinalResults;
2669 }
/// <summary>
/// Runs a pre-filled input blob through the run net (or the solver's training net) and returns one
/// ResultCollection per batch item.
/// </summary>
/// <remarks>
/// The blob's channels, height and width are validated against the loaded dataset's testing source.
/// When no dataset is loaded, they are validated against the shape given to LoadToRun instead.
/// Height and width are adjusted by the transformer's resize parameter when it is active.
/// </remarks>
/// <param name="blob">Input blob; when Padded, the last batch item is treated as padding and skipped.</param>
/// <param name="bSort">Sort flag (NOTE(review): confirm how it is applied to the results).</param>
/// <param name="bUseSolverNet">When true the solver's TrainingNet is run with passthrough enabled instead of the run net.</param>
/// <param name="nMax">Maximum number of batch items converted to results.</param>
/// <param name="rgIgnoreLabels">Output indices whose value is replaced: double.MaxValue for DISTANCES results, otherwise 0.</param>
/// <returns>One ResultCollection per (non-padding) batch item.</returns>
/// <exception cref="Exception">The run net is missing, the shape cannot be determined, or the blob size does not match.</exception>
2680 public List<ResultCollection> Run(Blob<T> blob, bool bSort = true, bool bUseSolverNet = false, int nMax = int.MaxValue, List<int> rgIgnoreLabels = null)
2681 {
2682 m_log.CHECK(m_dataTransformer != null, "The data transformer is not initialized!");
2684 if (m_net == null)
2685 throw new Exception("The Run net has not been created!");
2687 if (m_dataSet == null && (m_loadToRunShape == null || m_loadToRunShape.dim.Count < 4))
2688 throw new Exception("Cannot determine the blob shape, you must either load with a database, or use LoadToRun before calling Run with a Blob. When using LoadToRun, the shape must have at least 4 dimensions.");
2690 List<ResultCollection> rgFinalResults = new List<ResultCollection>();
2691 int nBatchSize = blob.num;
2692 int nChannels = (m_dataSet != null) ? m_dataSet.TestingSource.Channels : m_loadToRunShape.dim[1];
// NOTE(review): the message below dereferences m_dataSet, which is null after LoadToRun; building it from nChannels would avoid a NullReferenceException.
2693 if (blob.channels != nChannels)
2694 throw new Exception("The blob channels must match those of the testing dataset which has channels = " + m_dataSet.TestingSource.Channels.ToString());
2696 int nHeight = (m_dataSet != null) ? m_dataSet.TestingSource.Height : m_loadToRunShape.dim[2];
2697 int nWidth = (m_dataSet != null) ? m_dataSet.TestingSource.Width : m_loadToRunShape.dim[3];
2699 if (m_dataTransformer.param.resize_param != null && m_dataTransformer.param.resize_param.Active)
2700 {
2701 List<int> rgShape = m_dataTransformer.InferBlobShape(nChannels, nWidth, nHeight);
2702 nHeight = rgShape[2];
2703 nWidth = rgShape[3];
2704 }
2706 if (blob.height != nHeight)
2707 throw new Exception("The blob height must match those of the testing dataset which has height = " + nHeight.ToString());
2709 if (blob.width != nWidth)
2710 throw new Exception("The blob width must match those of the testing dataset which as width = " + nWidth.ToString());
2712 m_dataTransformer.SetRange(blob);
2714 BlobCollection<T> colBottom = new BlobCollection<T>() { blob };
2715 double dfLoss = 0;
2717 BlobCollection<T> colResults;
2718 LayerParameter.LayerType lastLayerType;
// NOTE(review): the solver-net branch indexes TrainingNet.layers with m_net.layers.Count - 1; TrainingNet.layers.Count - 1 looks intended — confirm.
2720 if (bUseSolverNet)
2721 {
2722 lastLayerType = m_solver.TrainingNet.layers[m_net.layers.Count - 1].type;
2723 m_solver.TrainingNet.SetEnablePassthrough(true);
2724 colResults = m_solver.TrainingNet.Forward(colBottom, out dfLoss);
2725 m_solver.TrainingNet.SetEnablePassthrough(false);
2726 }
2727 else
2728 {
2729 lastLayerType = m_net.layers[m_net.layers.Count - 1].type;
2730 colResults = m_net.Forward(colBottom, out dfLoss, true);
2731 }
2733 float[] rgData = Utility.ConvertVecF<T>(colResults[0].update_cpu_data());
2734 int nOutputCount = rgData.Length / blob.num;
// A padded blob carries one extra batch item whose output is skipped.
2736 int nNum = blob.num;
2737 if (blob.Padded)
2738 nNum--;
// NOTE(review): in the MULTIBBOX case each iteration reads only detection row n (7 values per row),
// so at most nNum detections are returned — confirm this matches the detection output layout.
2742 for (int n = 0; n < nNum && n < nMax; n++)
2743 {
2744 List<Result> rgResults = new List<Result>();
2746 if (colResults[0].type == BLOB_TYPE.MULTIBBOX)
2747 {
2748 int i = (int)rgData[(n * 7)];
2749 int nLabel = (int)rgData[(n * 7) + 1];
2750 double dfScore = rgData[(n * 7) + 2];
2751 double[] rgExtra = new double[4];
2752 rgExtra[0] = rgData[(n * 7) + 3]; // xmin
2753 rgExtra[1] = rgData[(n * 7) + 4]; // ymin
2754 rgExtra[2] = rgData[(n * 7) + 5]; // xmax
2755 rgExtra[3] = rgData[(n * 7) + 6]; // ymax
2757 rgResults.Add(new Result(nLabel, dfScore, rgExtra));
2758 }
2759 else
2760 {
2761 for (int j = 0; j < nOutputCount; j++)
2762 {
2763 int nIdx = n * nOutputCount + j;
2764 double dfProb = rgData[nIdx];
2766 if (rgIgnoreLabels != null && rgIgnoreLabels.Contains(j))
2767 {
2768 if (resType == ResultCollection.RESULT_TYPE.DISTANCES)
2769 dfProb = double.MaxValue;
2770 else
2771 dfProb = 0;
2772 }
2774 rgResults.Add(new Result(j, dfProb));
2775 }
2776 }
2778 ResultCollection result = new ResultCollection(rgResults, lastLayerType);
2779 if (m_db != null && m_db.GetVersion() != DB_VERSION.TEMPORAL)
2782 rgFinalResults.Add(result);
2783 }
2785 return rgFinalResults;
2786 }
/// <summary>
/// Runs the Run net on a single bitmap. The bitmap is converted into a Datum
/// (double or float, matching T) using the channel count of the input shape.
/// </summary>
/// <param name="img">Specifies the bitmap to run.</param>
/// <param name="bSort">Specifies whether or not to sort the results.</param>
/// <param name="bPad">Specifies whether or not to pad the data.</param>
/// <returns>The results of running the bitmap through the Run net.</returns>
public ResultCollection Run(Bitmap img, bool bSort = true, bool bPad = true)
{
    if (m_net == null)
        throw new Exception("The Run net has not been created!");

    // Channel count is taken from the input shape (axis 1).
    int nInputChannels = m_inputShape.dim[1];

    Datum datum = (typeof(T) == typeof(double))
        ? ImageData.GetImageDataD(img, nInputChannels, false, -1)
        : ImageData.GetImageDataF(img, nInputChannels, false, -1);

    return Run(datum, bSort, bPad);
}
/// <summary>
/// Runs a single datum through the Run net (never the solver's training net).
/// </summary>
/// <param name="d">Specifies the datum to run.</param>
/// <param name="bSort">Specifies whether or not to sort the results.</param>
/// <param name="bPad">Specifies whether or not to pad the data.</param>
/// <returns>The results of the run.</returns>
public ResultCollection Run(SimpleDatum d, bool bSort = true, bool bPad = true)
{
    return Run(d, bSort, bUseSolverNet: false, bPad: bPad);
}
{
    // Optional "Phase" property selects the internal net to run; defaults to TRAIN.
    Phase phase = Phase.TRAIN;

    if (customInput != null)
    {
        string strPhase = customInput.GetProperty("Phase", false);
        if (!string.IsNullOrEmpty(strPhase))
        {
            if (strPhase == Phase.TRAIN.ToString())
                phase = Phase.TRAIN;
            else if (strPhase == Phase.TEST.ToString())
                phase = Phase.TEST;
            else if (strPhase == Phase.RUN.ToString())
                phase = Phase.RUN;
        }
    }

    m_log.WriteLine("INFO: Running TestMany with the " + phase.ToString() + " phase.");
    Net<T> net = GetInternalNet(phase);

    // GetInternalNet returns null for TRAIN/TEST when no solver is loaded; fail with a clear message.
    if (net == null)
        throw new Exception("The " + phase.ToString() + " net is not available - make sure the control is loaded for that phase.");

    BlobCollection<T> colTop = net.Forward();
    PropertySet res = new PropertySet();

    // Each top blob is stored under its name as a byte array, plus its blob type as an int.
    foreach (Blob<T> blob in colTop)
    {
        string strName = blob.Name;
        res.SetPropertyBlob(strName, blob.ToByteArray());
        res.SetPropertyInt(strName, (int)blob.type);
    }

    return res;
}
{
    // Resolve the phase from the optional "Phase" property; anything unrecognized keeps TRAIN.
    Phase phaseSelected = Phase.TRAIN;

    if (customInput != null)
    {
        switch (customInput.GetProperty("Phase", false))
        {
            case nameof(Phase.TRAIN):
                phaseSelected = Phase.TRAIN;
                break;

            case nameof(Phase.TEST):
                phaseSelected = Phase.TEST;
                break;

            case nameof(Phase.RUN):
                phaseSelected = Phase.RUN;
                break;
        }
    }

    m_log.WriteLine("INFO: Running TestMany with the " + phaseSelected.ToString() + " phase.");

    Net<T> netPhase = GetInternalNet(phaseSelected);
    return netPhase.Forward();
}
/// <summary>
/// Runs the model on custom input data. The "InputData" property may hold several
/// inputs separated by '|'; each produces one output, and the outputs are returned
/// '|'-separated (with ';' replaced by ' ') in the "Results" property.
/// </summary>
/// <param name="customInput">Specifies the custom input; must contain "InputData".</param>
/// <param name="nK">Specifies the K used for beam search / logits post-processing (must be >= 1).</param>
/// <param name="dfThreshold">Specifies the beam search threshold.</param>
/// <param name="nMax">Specifies the maximum number of generation steps.</param>
/// <param name="bBeamSearch">Specifies whether to use beam search instead of step-wise generation.</param>
/// <returns>A PropertySet with the "Results" property.</returns>
public PropertySet Run(PropertySet customInput, int nK = 1, double dfThreshold = 0.01, int nMax = 80, bool bBeamSearch = false)
{
    m_log.CHECK_GE(nK, 1, "The K must be >= 1!");

    BlobCollection<T> colBottom = null;
    Layer<T> layerInput = null;
    string strInput = customInput.GetProperty("InputData");
    string[] rgstrInput = strInput.Split('|');
    List<string> rgstrOutput = new List<string>();

    foreach (string strInput1 in rgstrInput)
    {
        PropertySet input = new PropertySet("InputData=" + strInput1);
        string strOut = "\n";

        if (!bBeamSearch)
        {
            // Find the first layer able to convert the custom input into bottom blobs.
            foreach (Layer<T> layer in m_net.layers)
            {
                int nSeqLen1;
                colBottom = layer.PreProcessInput(input, out nSeqLen1, colBottom);
                if (colBottom != null)
                {
                    layerInput = layer;
                    break;
                }
            }

            if (colBottom == null)
                throw new Exception("At least one layer must support the 'PreprocessInput' method!");

            double dfLoss;
            int nAxis = 1;
            BlobCollection<T> colTop = m_net.Forward(colBottom, out dfLoss, layerInput.SupportsPostProcessingLogits);
            Blob<T> blobTop = colTop[0];
            Layer<T> softmax = null;

            // A softmax, if present, is looked for in the last or second-to-last layer.
            // The Count > 1 guard avoids indexing layer -1 on single-layer nets.
            int nLayerCount = m_net.layers.Count;
            if (m_net.layers[nLayerCount - 1].layer_param.type == LayerParameter.LayerType.SOFTMAX)
                softmax = m_net.layers[nLayerCount - 1];
            else if (nLayerCount > 1 && m_net.layers[nLayerCount - 2].layer_param.type == LayerParameter.LayerType.SOFTMAX)
                softmax = m_net.layers[nLayerCount - 2];

            if (softmax != null)
                nAxis = softmax.layer_param.softmax_param.axis;

            List<string> rgOutput = new List<string>();
            List<Tuple<string, int, double>> res;
            int nCount = 0;
            Stopwatch sw = Stopwatch.StartNew();

            for (int i = 0; i < nMax; i++)
            {
                if (layerInput.SupportsPostProcessingLogits)
                {
                    // NOTE(review): the logits path overrides the caller's nK with a fixed 10.
                    nK = 10;
                    blobTop = m_net.FindBlob("logits");
                    if (blobTop == null)
                        throw new Exception("Could not find the 'logits' blob!");
                    res = layerInput.PostProcessLogitsOutput(i, blobTop, softmax, nAxis, nK);
                }
                else
                {
                    res = layerInput.PostProcessOutput(blobTop);
                }

                // Feed the selected token back in; stop when the input layer signals the end.
                if (!layerInput.PreProcessInput(null, res[0].Item2, colBottom))
                    break;

                rgOutput.Add(res[0].Item1);

                colTop = m_net.Forward(colBottom, out dfLoss, layerInput.SupportsPostProcessingLogits);
                blobTop = colTop[0];
                nCount++;

                // Report progress about once per second (restart so it does not log every step afterwards).
                if (sw.Elapsed.TotalMilliseconds > 1000)
                {
                    double dfPct = (double)nCount / nMax;
                    m_log.WriteLine("Generating response at " + dfPct.ToString("P") + "...");
                    sw.Restart();
                }
            }

            strOut = string.Concat(rgOutput);
        }
        else
        {
            BeamSearch<T> search = new BeamSearch<T>(m_net);

            List<Tuple<double, bool, List<Tuple<string, int, double>>>> res = search.Search(input, nK, dfThreshold, nMax);

            // Only the first beam result is used; its tokens are space-separated.
            if (res.Count > 0)
            {
                foreach (Tuple<string, int, double> token in res[0].Item3)
                {
                    strOut += token.Item1.ToString() + " ";
                }
            }

            strOut = strOut.Trim();
        }

        rgstrOutput.Add(strOut);
    }

    StringBuilder sbFinal = new StringBuilder();
    foreach (string str in rgstrOutput)
    {
        sbFinal.Append(str);
        sbFinal.Append('|');
    }

    string strFinal = clean(sbFinal.ToString());
    return new PropertySet("Results=" + strFinal);
}
{
    // Forward pass through the Run net; the loss output is not reported to the caller.
    double dfLossOut;
    BlobCollection<T> colTop = m_net.Forward(colBottom, out dfLossOut, true);
    return colTop;
}
/// <summary>
/// Replaces every ';' in the text with a space.
/// Uses string.Replace instead of per-character concatenation (which was O(n^2)).
/// </summary>
/// <param name="strFinal">Specifies the text to clean (must not be null).</param>
/// <returns>The cleaned text.</returns>
private string clean(string strFinal)
{
    return strFinal.Replace(';', ' ');
}
// Returns an image from the training source (Phase.TRAIN) or the testing source
// (any other phase) of the current dataset, together with its label and label name.
// Only non-temporal (image) databases are supported.
// NOTE(review): the line that declares/queries `sd` (listing line 3066) is missing
// from this listing - confirm the selection method against the real source.
3060 public Bitmap GetTestImage(Phase phase, out int nLabel, out string strLabel)
3061 {
3062 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3063 throw new Exception("The GetTestImage only works with non-temporal databases.");
// TRAIN uses the training source; every other phase uses the testing source.
3065 int nSrcId = (phase == Phase.TRAIN) ? m_dataSet.TrainingSource.ID : m_dataSet.TestingSource.ID;
// Apply any label transformation before reporting the label.
3067 m_dataTransformer.TransformLabel(sd);
3069 nLabel = sd.Label;
3070 strLabel = ((IXImageDatabaseBase)m_db).GetLabelName(nSrcId, nLabel);
// Fall back to the numeric label when the database has no name for it.
3072 if (strLabel == null || strLabel.Length == 0)
3073 strLabel = nLabel.ToString();
3075 return new Bitmap(ImageData.GetImage(new Datum(sd), null));
3076 }
// Returns an image from the training source (Phase.TRAIN) or the testing source
// (any other phase); only non-temporal (image) databases are supported.
// NOTE(review): nLabel is not referenced by any visible line - presumably it is used
// by the missing query line (listing line 3090) that declares `sd`; confirm.
3084 public Bitmap GetTestImage(Phase phase, int nLabel)
3085 {
3086 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3087 throw new Exception("The GetTestImage only works with non-temporal databases.");
3089 int nSrcId = (phase == Phase.TRAIN) ? m_dataSet.TrainingSource.ID : m_dataSet.TestingSource.ID;
// Apply any label transformation to the datum before converting it to an image.
3091 m_dataTransformer.TransformLabel(sd);
3093 return new Bitmap(ImageData.GetImage(new Datum(sd), null));
3094 }
// Returns the image at index nIdx of data source nSrcId, with its label, label name,
// and data criteria (bytes + format). Only non-temporal (image) databases are supported.
// NOTE(review): the lines that load `sd` (listing lines 3110-3111) are missing here.
// NOTE(review): the exception text below names GetTestImage although this is GetTargetImage.
3106 public Bitmap GetTargetImage(int nSrcId, int nIdx, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
3107 {
3108 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3109 throw new Exception("The GetTestImage only works with non-temporal databases.");
3112 m_dataTransformer.TransformLabel(sd);
3114 nLabel = sd.Label;
3115 strLabel = ((IXImageDatabaseBase)m_db).GetLabelName(nSrcId, nLabel);
// Fall back to the numeric label when the database has no name for it.
3117 if (strLabel == null || strLabel.Length == 0)
3118 strLabel = nLabel.ToString();
3120 rgCriteria = sd.DataCriteria;
3121 fmtCriteria = sd.DataCriteriaFormat;
3123 return new Bitmap(ImageData.GetImage(new Datum(sd), null));
3124 }
// Returns the image with ID nImageID, with its label, data criteria (bytes + format),
// and label name. The label name is looked up in the testing source of the dataset.
// Only non-temporal (image) databases are supported.
// NOTE(review): the lines that load `d` (listing lines 3139-3141) are missing here.
// NOTE(review): the exception text below names GetTestImage although this is GetTargetImage.
3135 public Bitmap GetTargetImage(int nImageID, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
3136 {
3137 if (m_db.GetVersion() == DB_VERSION.TEMPORAL)
3138 throw new Exception("The GetTestImage only works with non-temporal databases.");
3142 nLabel = d.Label;
3143 strLabel = ((IXImageDatabaseBase)m_db).GetLabelName(m_dataSet.TestingSource.ID, nLabel);
// Fall back to the numeric label when the database has no name for it.
3145 if (strLabel == null || strLabel.Length == 0)
3146 strLabel = nLabel.ToString();
3148 rgCriteria = d.DataCriteria;
3149 fmtCriteria = d.DataCriteriaFormat;
3151 return new Bitmap(ImageData.GetImage(new Datum(d), null));
3152 }
// Returns the item mean, looked up in the database by the source name of the solver's
// data. Requires a database and a solver loaded for training.
// NOTE(review): the expressions on listing lines 3166 and 3169 were lost in this listing,
// presumably a solver-net reference and its data source name - confirm in the real source.
3159 {
3160 if (m_db == null)
3161 throw new Exception("The image database is null!");
3163 if (m_solver == null)
3164 throw new Exception("The solver is null - make sure that you are loaded for training.");
3166 if ( == null)
3167 throw new Exception("The solver net is null - make sure that you are loaded for training.");
3169 string strSrc =;
// Resolve the source name to its database ID, then fetch that source's mean.
3170 int nSrcId = m_db.GetSourceID(strSrc);
3172 return m_db.GetItemMean(nSrcId);
3173 }
{
    // The dataset descriptor of the currently loaded dataset (null if none).
    DatasetDescriptor ds = m_dataSet;
    return ds;
}
// Returns the serialized network weights. When a Run net exists, it first calls
// ShareTrainedLayersWith on it, presumably to pull in the trained layers, and then
// saves its weights using m_persist.
// NOTE(review): the ShareTrainedLayersWith argument (3192) and the else-branch return
// value (3197) were lost from this listing - confirm against the real source.
3188 public byte[] GetWeights()
3189 {
3190 if (m_net != null)
3191 {
3192 m_net.ShareTrainedLayersWith(;
3193 return m_net.SaveWeights(m_persist);
3194 }
3195 else
3196 {
3197 return;
3198 }
3199 }
// Copies the learnable parameters from the solver's net into the Run net (m_net).
// Optionally silences the log while copying and optionally compares the weights afterwards.
// NOTE(review): several expressions (listing lines 3224, 3226, 3228, 3241, 3251, 3256,
// 3262) were lost from this listing - presumably references to the solver's training net.
3206 public void UpdateRunWeights(bool bOutputStatus = false, bool bVerifyWeights = true)
3207 {
// Holds the log's prior enabled state only when it was changed, so finally can restore it.
3208 bool? bLogEnabled = null;
3210 try
3211 {
3212 if (!bOutputStatus)
3213 {
3214 bLogEnabled = m_log.IsEnabled;
3215 m_log.Enable = false;
3216 }
// Only a Run net owned by this control is refreshed.
3218 if (m_net != null && m_bOwnRunNet)
3219 {
3220 try
3221 {
3222 int nCopyCount = 0;
// Fast path: when the (missing) source count equals the Run net's parameter count,
// copy blob-by-blob by index; names must match.
3224 if ( == m_net.learnable_parameters.Count)
3225 {
3226 for (int i = 0; i <; i++)
3227 {
3228 Blob<T> b =[i];
3229 Blob<T> bRun = m_net.learnable_parameters[i];
3231 m_log.CHECK(b.Name == bRun.Name, "The learnable parameter names do not match!");
3232 if (bRun.CopyFrom(b, false, true) != 0)
3233 nCopyCount++;
3234 }
3235 }
3236 else
3237 {
// Otherwise each Run blob is looked up (lookup expression missing at 3241) and copied.
3238 for (int i = 0; i < m_net.learnable_parameters.Count; i++)
3239 {
3240 Blob<T> bRun = m_net.learnable_parameters[i];
3241 Blob<T> b =;
3243 if (b == null)
3244 m_log.FAIL("Could not find the run blob '" + bRun.Name + "' in the solver net!");
3246 bRun.CopyFrom(b, false, true);
3247 }
3248 }
// NOTE(review): nCopyCount is only incremented on the index path, so after the
// lookup path this fallback load always runs as well.
3250 if (nCopyCount == 0)
3251 loadWeights(m_net,;
3252 }
// Any failure above falls back to the legacy (slower) weight load.
3253 catch (Exception excpt)
3254 {
3255 m_log.WriteLine("WARNING: " + excpt.Message + ", attempting to load with legacy (slower method)...");
3256 loadWeights(m_net,;
3257 }
3258 }
// NOTE(review): verification runs even when m_net is null or not owned by this control.
3260 if (bVerifyWeights)
3261 {
3262 if (!CompareWeights(m_net,
3263 m_log.WriteLine("WARNING: The run weights differ from the training weights!");
3264 }
3265 }
3266 finally
3267 {
// Restore the log's enabled state if it was disabled above.
3268 if (bLogEnabled.HasValue)
3269 m_log.Enable = bLogEnabled.Value;
3270 }
3271 }
/// <summary>
/// Loads new weights into the Run net (when present) and into the solver's training net
/// (when a solver is loaded).
/// </summary>
/// <remarks>
/// A control loaded without a solver (m_solver == null) previously hit a
/// NullReferenceException here, after the Run net had already been updated.
/// In that case the Run net is now updated and the solver update is skipped.
/// </remarks>
/// <param name="rgWeights">Specifies the serialized weights.</param>
public void UpdateWeights(byte[] rgWeights)
{
    if (m_net != null)
        loadWeights(m_net, rgWeights);

    if (m_solver == null)
    {
        m_log.WriteLine("No solver loaded, only the run net weights were updated.");
        return;
    }

    m_log.WriteLine("Updating weights in solver.");

    // Expected shapes let the persist layer validate the incoming weights against the training net.
    List<string> rgExpectedShapes = new List<string>();

    foreach (Blob<T> b in m_solver.TrainingNet.learnable_parameters)
    {
        rgExpectedShapes.Add(b.shape_string);
    }

    bool bLoadDiffs;
    m_persist.LoadWeights(rgWeights, rgExpectedShapes, m_solver.TrainingNet.learnable_parameters, false, out bLoadDiffs);

    m_solver.WeightsUpdated = true;
    m_log.WriteLine("Solver weights updated.");
}
// Creates a new Net, loads rgWeights into it and returns it.
// cudaOverride: the CudaDnn<T> connection to use; defaults to this control's m_cuda.
// The NetParameter comes from the Run net when it exists.
// NOTE(review): the alternative parameter source (listing line 3309, after ':') was lost
// from this listing - confirm against the real source.
3304 public Net<T> CreateNet(byte[] rgWeights, CudaDnn<T> cudaOverride = null)
3305 {
3306 if (cudaOverride == null)
3307 cudaOverride = m_cuda;
3309 NetParameter p = (m_net != null) ? m_net.ToProto(false) :;
3310 Net<T> net = new Net<T>(cudaOverride, m_log, p, m_evtCancel, m_db);
3311 loadWeights(net, rgWeights);
3312 return net;
3313 }
/// <summary>
/// Returns the internal net for a phase: TEST and TRAIN map to the solver's testing
/// and training nets (null when no solver is loaded); anything else maps to the Run net.
/// Phase.ALL means "the last phase run", and Phase.NONE falls back to RUN.
/// </summary>
/// <param name="phase">Specifies the phase whose net is returned.</param>
/// <returns>The matching net, or null.</returns>
public Net<T> GetInternalNet(Phase phase = Phase.RUN)
{
    Phase phaseResolved = (phase == Phase.ALL) ? m_lastPhaseRun : phase;

    if (phaseResolved == Phase.NONE)
        phaseResolved = Phase.RUN;

    switch (phaseResolved)
    {
        case Phase.TEST:
            return (m_solver == null) ? null : m_solver.TestingNet;

        case Phase.TRAIN:
            return (m_solver == null) ? null : m_solver.TrainingNet;

        default:
            return m_net;
    }
}
{
    // May be null when the control was not loaded with a solver.
    Solver<T> solver = m_solver;
    return solver;
}
/// <summary>
/// Forces the solver to take a snapshot now.
/// </summary>
/// <param name="bUpdateDatabase">Specifies whether the snapshot also updates the database.</param>
public void Snapshot(bool bUpdateDatabase = true)
{
    // m_solver must already be loaded; the call is delegated unchanged.
    Solver<T> solver = m_solver;
    solver.Snapshot(true, false, bUpdateDatabase);
}
// Intended to reset the device with the given device ID.
// The body is empty, so calling this currently has no effect.
3371 public static void ResetDevice(int nDeviceID)
3372 {
3373 }
/// <summary>
/// Returns the MyCaffe license text with its macros expanded and line endings
/// normalized to CRLF.
/// </summary>
/// <remarks>
/// Previously a null/empty strOtherLicenses left the raw "$$OTHERLICENSES$$"
/// placeholder in the returned text. It is now always replaced, with an empty
/// string when there are no other licenses.
/// </remarks>
/// <param name="strOtherLicenses">Specifies additional license text to insert (may be null or empty).</param>
/// <returns>The license text.</returns>
public static string GetLicenseTextEx(string strOtherLicenses)
{
    string str = Properties.Resources.LICENSE;
    int nYear = DateTime.Now.Year;

    // Presumably 2016 is the base year in the license text; later years are appended as "-<year>".
    string strYear = (nYear > 2016) ? "-" + nYear.ToString(CultureInfo.InvariantCulture) : "";
    str = replaceMacro(str, "$$YEAR$$", strYear);

    str = replaceMacro(str, "$$OTHERLICENSES$$", string.IsNullOrEmpty(strOtherLicenses) ? "" : strOtherLicenses);

    return fixupReturns(str);
}
/// <summary>
/// Instance wrapper around the static GetLicenseTextEx.
/// </summary>
/// <param name="strOtherLicenses">Specifies additional license text to insert.</param>
/// <returns>The license text.</returns>
public string GetLicenseText(string strOtherLicenses)
{
    string strLicense = GetLicenseTextEx(strOtherLicenses);
    return strLicense;
}
/// <summary>
/// Verifies that a device's compute capability meets the minimum required by the
/// CudaDnn DLL in use.
/// </summary>
/// <param name="strExtra">Specifies optional text appended to the error message.</param>
/// <param name="nDeviceID">Specifies the device to check; -1 = the current device.</param>
/// <param name="bThrowException">
/// When true (default), failures throw. When false, failures are logged and false is
/// returned (previously this flag was ignored).
/// </param>
/// <returns>true when the device meets the requirement, otherwise false (only when bThrowException = false).</returns>
public bool VerifyCompute(string strExtra = null, int nDeviceID = -1, bool bThrowException = true)
{
    if (m_cuda == null)
        throw new Exception("You must initialize the MyCaffeControl with an instance of CudaDnn<T>, or Load a new project.");

    int nMinMajor;
    int nMinMinor;
    string strDll = m_cuda.GetRequiredCompute(out nMinMajor, out nMinMinor);

    if (nDeviceID == -1)
        nDeviceID = m_cuda.GetDeviceID();

    // The compute version is parsed from the text between "compute " and ")" in the device name.
    string strDevName = m_cuda.GetDeviceName(nDeviceID);
    string strCompute = parse(strDevName, "compute ", ")");
    string[] rgstr = (strCompute != null) ? strCompute.Split('.') : new string[0];

    // Previously a missing "compute" tag or '.' crashed with Null/IndexOutOfRange exceptions.
    int nMajor;
    int nMinor;
    if (rgstr.Length < 2 ||
        !int.TryParse(rgstr[0], NumberStyles.Integer, CultureInfo.InvariantCulture, out nMajor) ||
        !int.TryParse(rgstr[1], NumberStyles.Integer, CultureInfo.InvariantCulture, out nMinor))
    {
        string strParseErr = "Could not find the current device's major and minor version information!";
        if (!bThrowException)
        {
            m_log.WriteLine(strParseErr);
            return false;
        }

        throw new Exception(strParseErr);
    }

    if (nMajor < nMinMajor || (nMajor == nMinMajor && nMinor < nMinMinor))
    {
        string strErr = "The device " + nDeviceID.ToString() + " - '" + strDevName + "' does not meet the minimum compute of '" + nMinMajor.ToString() + "." + nMinMinor.ToString() + "' required by the CudaDnnDll used ('" + strDll + "')!";
        if (!string.IsNullOrEmpty(strExtra))
            strErr += strExtra;

        if (!bThrowException)
        {
            m_log.WriteLine(strErr);
            return false;
        }

        throw new Exception(strErr);
    }

    return true;
}
/// <summary>
/// Returns the trimmed text between the first occurrence of strT1 and the next
/// occurrence of strT2 after it, or null when str is null or either marker is missing.
/// </summary>
/// <remarks>
/// Uses ordinal comparison: the input is machine text (device names), so
/// culture-sensitive matching is not wanted.
/// </remarks>
/// <param name="str">Specifies the text to search.</param>
/// <param name="strT1">Specifies the start marker.</param>
/// <param name="strT2">Specifies the end marker.</param>
/// <returns>The enclosed text, or null.</returns>
private string parse(string str, string strT1, string strT2)
{
    if (str == null)
        return null;

    int nPos = str.IndexOf(strT1, StringComparison.Ordinal);
    if (nPos < 0)
        return null;

    string strTail = str.Substring(nPos + strT1.Length);
    int nEnd = strTail.IndexOf(strT2, StringComparison.Ordinal);
    if (nEnd < 0)
        return null;

    return strTail.Substring(0, nEnd).Trim();
}
/// <summary>
/// Replaces the first occurrence of strMacro in str with strReplacement.
/// Returns str unchanged when the macro is absent.
/// </summary>
/// <param name="str">Specifies the source text.</param>
/// <param name="strMacro">Specifies the macro token to find.</param>
/// <param name="strReplacement">Specifies the replacement text (null acts as empty).</param>
/// <returns>The text with the first macro occurrence replaced.</returns>
private static string replaceMacro(string str, string strMacro, string strReplacement)
{
    int nIdx = str.IndexOf(strMacro);

    if (nIdx >= 0)
    {
        string strHead = str.Substring(0, nIdx);
        string strTail = str.Substring(nIdx + strMacro.Length);
        return string.Concat(strHead, strReplacement, strTail);
    }

    return str;
}
/// <summary>
/// Normalizes line endings to Windows CRLF.
/// </summary>
/// <remarks>
/// Existing CRLF pairs are first collapsed to LF, so text that already uses CRLF is no
/// longer turned into "\r\r\n". This also replaces the O(n^2) per-character concatenation.
/// </remarks>
/// <param name="str">Specifies the text to normalize (must not be null).</param>
/// <returns>The text with every line break as "\r\n".</returns>
private static string fixupReturns(string str)
{
    return str.Replace("\r\n", "\n").Replace("\n", "\r\n");
}
/// <summary>
/// Creates an unsized blob on this control's CudaDnn connection and names it.
/// </summary>
/// <param name="strName">Specifies the blob name.</param>
/// <returns>The new blob.</returns>
public Blob<T> CreateBlob(string strName)
{
    Blob<T> blobNew = new Blob<T>(m_cuda, m_log)
    {
        Name = strName
    };

    return blobNew;
}
/// <summary>
/// Loads an extension DLL through the CudaDnn connection.
/// </summary>
/// <param name="strExtensionDLLPath">Specifies the path of the extension DLL.</param>
/// <returns>The handle of the loaded extension.</returns>
public long CreateExtension(string strExtensionDLLPath)
{
    long hExtension = m_cuda.CreateExtension(strExtensionDLLPath);
    return hExtension;
}
/// <summary>
/// Frees (unloads) an extension previously created with CreateExtension.
/// </summary>
/// <param name="hExtension">Specifies the extension handle.</param>
public void FreeExtension(long hExtension)
{
    CudaDnn<T> cuda = m_cuda;
    cuda.FreeExtension(hExtension);
}
/// <summary>
/// Runs a function of a loaded extension using the base type T.
/// </summary>
/// <param name="hExtension">Specifies the extension handle.</param>
/// <param name="lfnIdx">Specifies the function index within the extension.</param>
/// <param name="rgParam">Specifies the function parameters.</param>
/// <returns>The function results.</returns>
public T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
{
    T[] rgResult = m_cuda.RunExtension(hExtension, lfnIdx, rgParam);
    return rgResult;
}
/// <summary>
/// Runs a function of a loaded extension, converting parameters and results
/// between double and the base type T. A null input or result passes through as null.
/// </summary>
/// <param name="hExtension">Specifies the extension handle.</param>
/// <param name="lfnIdx">Specifies the function index within the extension.</param>
/// <param name="rgParam">Specifies the function parameters (may be null).</param>
/// <returns>The function results as doubles, or null.</returns>
public double[] RunExtensionD(long hExtension, long lfnIdx, double[] rgParam)
{
    T[] rgIn = null;
    if (rgParam != null)
        rgIn = Utility.ConvertVec<T>(rgParam);

    T[] rgOut = m_cuda.RunExtension(hExtension, lfnIdx, rgIn);

    return (rgOut != null) ? Utility.ConvertVec<T>(rgOut) : null;
}
/// <summary>
/// Runs a function of a loaded extension, converting parameters and results
/// between float and the base type T. A null input or result passes through as null.
/// </summary>
/// <param name="hExtension">Specifies the extension handle.</param>
/// <param name="lfnIdx">Specifies the function index within the extension.</param>
/// <param name="rgParam">Specifies the function parameters (may be null).</param>
/// <returns>The function results as floats, or null.</returns>
public float[] RunExtensionF(long hExtension, long lfnIdx, float[] rgParam)
{
    T[] rgIn = null;
    if (rgParam != null)
        rgIn = Utility.ConvertVec<T>(rgParam);

    T[] rgOut = m_cuda.RunExtension(hExtension, lfnIdx, rgIn);

    return (rgOut != null) ? Utility.ConvertVecF<T>(rgOut) : null;
}
3566 }
The MyCaffeControl is the main object used to manage all training, testing and running of the MyCaffe...
DatasetDescriptor m_dataSet
The dataset descriptor of the dataset used in the image database.
string GetDeviceName(int nDeviceID)
Returns the device name of a given device ID.
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
static FileVersionInfo Version
Get the file version of the MyCaffe assembly running.
long CreateExtension(string strExtensionDLLPath)
Create and load a new extension DLL.
ResultCollection Run(SimpleDatum d, bool bSort=true, bool bPad=true)
Run on a given Datum.
void Snapshot(bool bUpdateDatabase=true)
The Snapshot function forces a snapshot to occur.
void RemoveCancelOverrideByName(string strEvtCancel)
Remove a cancel override.
bool Load(Phase phase, string strSolver, string strModel, byte[] rgWeights, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride=null, bool bResetFirst=false, IXDatabaseBase db=null, bool bUseDb=true, bool bCreateRunNet=true, string strStage=null, bool bEnableMemTrace=false)
Load a project and optionally the MyCaffeImageDatabase.
bool EnableTesting
Enable/disable testing. For example reinforcement learning does not use testing.
bool? EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed Nan (and Infinity) detection is pero...
void LoadToRun(string strModel, byte[] rgWeights, BlobShape shape, SimpleDatum sdMean=null, TransformationParameter transParam=null, bool bForceBackward=false, bool bConvertToRunNet=true)
The LoadToRun method loads the MyCaffeControl for running only (e.g. deployment).
static void ResetDevice(int nDeviceID)
Reset the device at the given device ID.
void dispose()
Releases all GPU and Host resources used by the CaffeControl.
List< int > ActiveGpus
Returns a list of the active GPUs used by the control.
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
AutoResetEvent m_evtForceSnapshot
An auto-reset event used to force a snapshot.
MyCaffeControl(SettingsCaffe settings, Log log, CancelEvent evtCancel, AutoResetEvent evtSnapshot=null, AutoResetEvent evtForceTest=null, ManualResetEvent evtPause=null, List< int > rgGpuId=null, string strCudaPath="", bool bCreateCudaDnn=false, ConnectInfo ci=null)
The MyCaffeControl constructor.
SettingsCaffe Settings
Returns the settings used to create the control.
PropertySet TestMany(PropertySet customInput)
Test on custom input data.
Bitmap GetTargetImage(int nSrcId, int nIdx, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
Retrieves the image at a given index within the Testing data set.
bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
Re-initializes each of the specified layers by re-running the filler (if any) specified by the layer....
ProjectEx m_project
The active project (if any).
ResultCollection Run(Bitmap img, bool bSort=true, bool bPad=true)
Run on a given bitmap image.
void SetOnTestingStartOverride(EventHandler onTestingStart)
Sets the root solver's onTestingStart event function triggered on the start of each testing pass.
static NetParameter CreateNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE, bool bSkipLossLayer=false, bool bMaintainBatchSize=false)
Creates a net parameter for the RUN phase.
int GetDeviceCount()
Returns the total number of devices installed on this computer.
int CurrentIteration
Returns the current iteration.
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires each time a snap-shot is taken.
void SetOnTestOverride(EventHandler< TestArgs > onTest)
Sets the root solver's onTest event function.
ResultCollection Run(SimpleDatum d, bool bSort, bool bUseSolverNet, bool bPad=true)
Run on a given Datum.
void Train(int nIterationOverride=-1, int nTrainingTimeLimitInMinutes=0, TRAIN_STEP step=TRAIN_STEP.NONE, double dfLearningRateOverride=0, bool bReset=false)
Train the network a set number of iterations.
bool? EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
string LabelQueryEpochs
Returns a string describing the label query epochs observed during training.
void AddCancelOverride(CancelEvent evtCancel)
Adds a cancel override.
SimpleDatum GetItemMean()
Returns the item (e.g., image or temporal item) mean used by the solver network used during training.
Net< T > GetInternalNet(Phase phase=Phase.RUN)
Returns the internal net based on the Phase specified: TRAIN, TEST or RUN.
SettingsCaffe m_settings
The settings used to configure the control.
void UpdateRunWeights(bool bOutputStatus=false, bool bVerifyWeights=true)
Loads the weights from the training net into the Net used for running.
void CopyWeightsFrom(MyCaffeControl< T > src)
Copy the learnable parameter data from the source MyCaffeControl into this one.
Blob< T > CreateBlob(string strName)
Create an unsized blob and set its name.
void FreeExtension(long hExtension)
Free an existing extension and unload it.
byte[] GetWeights()
Retrieves the weights of the training network.
ManualResetEvent m_evtPause
A manual-reset event used to pause training.
NetParameter createNetParameterForRunning(ProjectEx p, out TransformationParameter transform_param)
Creates a net parameter for the RUN phase.
ConnectInfo DatasetConnectInfo
Returns the dataset connection information, if used (default = null).
bool CompareWeights(Net< T > net1, Net< T > net2)
The CompareWeights method compares the weights held in two different Net objects.
string CurrentDevice
Returns the name of the current device used.
Solver< T > GetInternalSolver()
Get the internal solver.
bool VerifyCompute(string strExtra=null, int nDeviceID=-1, bool bThrowException=true)
VerifyCompute compares the current compute of the current device (or device specified) against the re...
IXDatabaseBase m_db
The image database.
void CopyGradientsFrom(MyCaffeControl< T > src)
Copy the learnable parameter diffs from the source MyCaffeControl into this one.
void UpdateWeights(byte[] rgWeights)
Loads the training Net with new weights.
List< ResultCollection > Run(Blob< T > blob, bool bSort=true, bool bUseSolverNet=false, int nMax=int.MaxValue, List< int > rgIgnoreLabels=null)
Run on a Blob of data.
string GetLicenseText(string strOtherLicenses)
Returns the license text for MyCaffe.
string ActiveLabelCounts
Returns a string describing the active label counts observed during training.
float[] RunExtensionF(long hExtension, long lfnIdx, float[] rgParam)
Run a function on an existing extension using the float base type.
BlobCollection< T > Run(BlobCollection< T > colBottom)
Run the network forward on the bottom blobs.
void PrepareImageMeans(ProjectEx prj)
Prepare the testing image mean by copying the training image mean if the testing image mean is missin...
Bitmap GetTestImage(Phase phase, int nLabel)
Retrieves a random image from either the training or test set depending on the Phase specified.
ResultCollection Run(int nImageIdx, bool bPad=true)
Run on a given image in the MyCaffeImageDatabase based on its image index.
Bitmap GetTargetImage(int nImageID, out int nLabel, out string strLabel, out byte[] rgCriteria, out SimpleDatum.DATA_FORMAT fmtCriteria)
Retrieves the image with a given ID.
double ApplyUpdate(int nIteration)
Directs the solver to apply the learned blob diffs to the weights using the solver's learning rate an...
AutoResetEvent m_evtForceTest
An auto-reset event used to force a test cycle.
MyCaffeControl< T > Clone(int nGpuID)
Clone the current instance of the MyCaffeControl creating a second instance.
Net< T > CreateNet(byte[] rgWeights, CudaDnn< T > cudaOverride=null)
Creates a new Net, loads the weights specified into it and returns it.
IXPersist< T > Persist
Returns the persist used to load and save weights.
BlobCollection< T > TestManyEx(PropertySet customInput)
Test on custom input data.
bool Load(Phase phase, ProjectEx p, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? itemSelectionOverride=null, bool bResetFirst=false, IXDatabaseBase db=null, bool bUseDb=true, bool bCreateRunNet=true, string strStage=null, bool bEnableMemTrace=false)
Load a project and optionally the MyCaffeImageDatabase.
DataTransformer< T > m_dataTransformer
The data transformer used to transform data.
CancelEvent m_evtCancel
The CancelEvent used to cancel training and testing operations.
void AddCancelOverrideByName(string strEvtCancel)
Adds a cancel override.
string LabelQueryHitPercents
Returns a string describing the label query hit percentages observed during training.
CudaDnn< T > Cuda
Returns the CudaDnn connection used.
ProjectEx CurrentProject
Returns the currently loaded project.
void SetOnTrainingStartOverride(EventHandler onTrainingStart)
Sets the root solver's onStart event function triggered on the start of each training pass.
Log m_log
The log used for output.
NetParameter createNetParameterForRunning(DatasetDescriptor ds, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
Blob< T > CreateDataBlob(SimpleDatum d, Blob< T > blob=null, bool bPad=true)
Create a data blob from a SimpleDatum by transforming the data and placing the results in the blob re...
Bitmap GetTestImage(Phase phase, out int nLabel, out string strLabel)
Retrieves a random image from either the training or test set depending on the Phase specified.
void Unload(bool bUnloadImageDb=true, bool bIgnoreExceptions=false)
Unload the currently loaded project, if any.
static string GetLicenseTextEx(string strOtherLicenses)
Returns the license text for MyCaffe.
string CurrentStage
Returns the stage under which the project was loaded, if any.
NetParameter createNetParameterForRunning(SimpleDatum sdMean, string strModel, out TransformationParameter transform_param, out int nC, out int nH, out int nW, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
void RemoveCancelOverride(CancelEvent evtCancel)
Remove a cancel override.
bool EnableVerboseStatus
Get/set whether or not to use verbose status. When enabled, the full status is output when loading a ...
double[] RunExtensionD(long hExtension, long lfnIdx, double[] rgParam)
Run a function on an existing extension using the double base type.
BlobCollection< T > RunModelEx(PropertySet customInput)
Run the model using data from the model itself - requires a Data layer with the RUN phase.
bool m_bDbOwner
Whether or not the control owns the image database.
List< ResultCollection > Run(List< SimpleDatum > rgSd, ref Blob< T > blob, bool bUseSolverNet=false, int nMax=int.MaxValue)
Run on a given list of Datum.
PropertySet RunModel(PropertySet customInput)
Run the model using data from the model itself - requires a Data layer with the RUN phase.
List< Tuple< SimpleDatum, ResultCollection > > TestMany(int nCount, bool bOnTrainingSet, bool bOnTargetSet=false, DB_ITEM_SELECTION_METHOD imgSelMethod=DB_ITEM_SELECTION_METHOD.RANDOM, int nImageStartIdx=0, DateTime? dtImageStartTime=null, double? dfThreshold=null)
Test on a number of images by selecting random images from the database, running them through the Run...
double Test(int nIterationOverride=-1)
Test the network a given number of iterations.
bool? EnableBreakOnFirstNaN
Enable/disable break training after first detecting a NaN.
NetParameter createNetParameterForRunning(BlobShape shape, string strModel, out TransformationParameter transform_param, Stage stage=Stage.NONE)
Creates a net parameter for the RUN phase.
string m_strCudaPath
The low-level path of the underlying CudaDnn DLL.
Phase LastPhase
Returns the last phase run (TRAIN, TEST or RUN).
bool? EnableBlobDebugging
Enable/disable blob debugging.
List< ResultCollection > Run(List< int > rgImageIdx, ref Blob< T > blob)
Run on a set of images in the MyCaffeImageDatabase based on their image indexes.
T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
Run a function on an existing extension.
bool? EnableSingleStep
Enable/disable single step training.
PropertySet Run(PropertySet customInput, int nK=1, double dfThreshold=0.01, int nMax=80, bool bBeamSearch=false)
Run the model on custom input data.
List< ResultCollection > Run(List< int > rgImageIdx)
Run on a set of images in the MyCaffeImageDatabase based on their image indexes.
bool LoadLite(Phase phase, string strSolver, string strModel, byte[] rgWeights=null, bool bResetFirst=false, bool bCreateRunNet=true, SimpleDatum sdMean=null, string strStage=null, bool bEnableMemTrace=false)
Load a solver and model without using the MyCaffeImageDatabase.
int MaximumIteration
Returns the maximum iteration.
List< int > m_rgGpu
A list of the device IDs used for training.
DatasetDescriptor GetDataset()
Returns the current dataset used when training and testing.
void Add(AnnotationGroupCollection col)
Add another AnnotationGroupCollection to this one.
Definition: Annotation.cs:369
int Count
Specifies the number of items in the collection.
Definition: Annotation.cs:350
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
bool RemoveCancelOverride(string strName)
Remove a cancel override.
Definition: CancelEvent.cs:167
void AddCancelOverride(string strName)
Add a new cancel override.
Definition: CancelEvent.cs:83
string Name
Return the name of the cancel event.
Definition: CancelEvent.cs:263
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
void Set()
Sets the event to the signaled state.
Definition: CancelEvent.cs:270
The ConnectInfo class specifies the server, database and username/password used to connect to a datab...
Definition: ConnectInfo.cs:14
The Datum class is a simple wrapper to the SimpleDatum class to ensure compatibility with the origina...
Definition: Datum.cs:12
The ImageData class is a helper class used to convert between Datum, other raw data,...
Definition: ImageData.cs:14
static Bitmap GetImage(SimpleDatum d, ColorMapper clrMap=null, List< int > rgClrOrder=null)
Converts a SimpleDatum (or Datum) into an image, optionally using a ColorMapper.
Definition: ImageData.cs:506
static Datum GetImageDataF(Bitmap bmp, int nChannels, bool bDataIsReal, int nLabel, bool bUseLockBitmap=true, int[] rgFocusMap=null)
The GetImageDataF function converts a Bitmap into a Datum using the float type for real data.
Definition: ImageData.cs:181
static Datum GetImageDataD(Bitmap bmp, int nChannels, bool bDataIsReal, int nLabel, bool bUseLockBitmap=true, int[] rgFocusMap=null)
The GetImageDataD function converts a Bitmap into a Datum using the double type for real data.
Definition: ImageData.cs:44
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK(bool b, string str)
Test a flag for true.
Definition: Log.cs:227
bool IsEnabled
Returns whether or not the Log is enabled.
Definition: Log.cs:50
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
Definition: Log.cs:42
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
Definition: Log.cs:394
double Progress
Get/set the progress associated with the Log.
Definition: Log.cs:147
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Definition: Log.cs:239
void WriteHeader(string str)
Write a header as output.
Definition: Log.cs:109
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
Definition: Log.cs:299
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
Definition: Log.cs:287
The ProjectEx class manages a project containing the solver description, model description,...
Definition: ProjectEx.cs:15
int TargetDatasetID
Get/set the dataset ID of the target dataset (if exists), otherwise return 0.
Definition: ProjectEx.cs:915
RawProto CreateModelForRunning(string strName, int nNum, int nChannels, int nHeight, int nWidth, out RawProto protoTransform, out bool bSkipTransformParam, Stage stage=Stage.NONE, bool bSkipLossLayer=false)
Create a model description as a RawProto for running the Project.
Definition: ProjectEx.cs:1041
int ID
Returns the ID of the Project in the database.
Definition: ProjectEx.cs:533
ParameterDescriptorCollection Parameters
Returns any project parameters that may exist (if any).
Definition: ProjectEx.cs:812
DatasetDescriptor Dataset
Return the descriptor of the dataset used.
Definition: ProjectEx.cs:896
string? ModelDescription
Get/set the model description script used by the Project.
Definition: ProjectEx.cs:757
byte[] WeightsState
Get/set the weight state.
Definition: ProjectEx.cs:873
DatasetDescriptor DatasetTarget
Returns the target dataset (if exists) or null if it does not.
Definition: ProjectEx.cs:907
byte[] SolverState
Get/set the solver state.
Definition: ProjectEx.cs:864
Stage Stage
Return the stage under which the project was opened.
Definition: ProjectEx.cs:603
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
int GetPropertyAsInt(string strName, int nDefault=0)
Returns a property as an integer value.
Definition: PropertySet.cs:287
bool GetPropertyAsBool(string strName, bool bDefault=false)
Returns a property as a boolean value.
Definition: PropertySet.cs:267
double GetPropertyAsDouble(string strName, double dfDefault=0)
Returns a property as a double value.
Definition: PropertySet.cs:307
void SetProperty(string strName, string strVal)
Sets a property in the property set to a value if it exists, otherwise it adds the new property.
Definition: PropertySet.cs:211
The RawProtoCollection class is a list of RawProto objects.
int Count
Returns the number of items in the collection.
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
string Value
Get/set the value of the node.
Definition: RawProto.cs:79
RawProto FindChild(string strName)
Searches for a given node.
Definition: RawProto.cs:231
override string ToString()
Returns the RawProto as its full prototxt string.
Definition: RawProto.cs:681
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
RawProtoCollection FindChildren(params string[] rgstrName)
Searches for all children with a given name in this node's children.
Definition: RawProto.cs:263
The Result class contains a single result.
Definition: Result.cs:14
int Label
Returns the label.
Definition: Result.cs:36
The SettingsCaffe defines the settings used by the MyCaffe CaffeControl.
Get/set the snapshot update method.
SettingsCaffe Clone()
Returns a copy of the SettingsCaffe object.
Get/set the image database loading method.
bool ItemDbLoadDebugData
Specifies whether or not to load the debug data from file (default = false).
string GpuIds
Get/set the default GPU ID's to use when training.
int MaximumIterationOverride
Get/set the maximum iteration override. When set, this overrides the training iterations specified in...
Get/set the version of the MyCaffeImageDatabase to use.
bool ItemDbLoadDataCriteria
Specifies whether or not to load the image criteria data from file (default = false).
int TestingIterationOverride
Get/set the testing iteration override. When set, this overrides the testing iterations specified in ...
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
bool GetDataValid(bool bByType=true)
Returns true if the ByteData or RealDataD or RealDataF are not null, false otherwise.
int Channels
Return the number of channels of the data.
AnnotationGroupCollection annotation_group
When using annotations, each annotation group contains an annotation for a particular class used with ...
byte[] DataCriteria
Get/set data criteria associated with the data.
Defines the data format of the DebugData and DataCriteria when specified.
Definition: SimpleDatum.cs:223
int Width
Return the width of the data.
int ImageID
Returns the ID of the image in the database.
int Height
Return the height of the data.
DATA_FORMAT DataCriteriaFormat
Get/set the data format of the data criteria.
int Label
Return the known label of the data.
The Utility class provides general utility functions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of double.
Definition: Utility.cs:550
int ID
Get/set the database ID of the item.
string Name
Get/set the name of the item.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
bool? IsGym
Returns whether or not this dataset is from a Gym.
SourceDescriptor TrainingSource
Get/set the training data source.
string? TrainingSourceName
Returns the training source name, or null if not specified.
bool? IsModelData
Returns whether or not this dataset is from the model itself.
SourceDescriptor TestingSource
Get/set the testing data source.
string? TestingSourceName
Returns the testing source name or null if not specified.
ParameterDescriptor Find(string strName)
Searches for a parameter by name in the collection.
The ParameterDescriptor class describes a parameter in the database.
override string ToString()
Creates the string representation of the descriptor.
string Value
Get/set the value of the item.
The SourceDescriptor class contains all information describing a data source.
override string ToString()
Return a string representation of the SourceDescriptor.
int Height
Returns the height of each data item in the data source.
int Width
Returns the width of each data item in the data source.
int Channels
Returns the item colors - 1 channel = black/white, 3 channels = RGB color.
The BeamSearch uses the softmax output from the network and continually runs the net on each output (...
Definition: BeamSearch.cs:19
List< Tuple< double, bool, List< Tuple< string, int, double > > > > Search(PropertySet input, int nK, double dfThreshold=0.01, int nMax=80)
Perform the beam-search.
Definition: BeamSearch.cs:56
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
int channels
DEPRECATED; legacy shape accessor channels: use shape(1) instead.
Definition: Blob.cs:800
int height
DEPRECATED; legacy shape accessor height: use shape(2) instead.
Definition: Blob.cs:808
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1487
string shape_string
Returns a string describing the Blob's shape.
Definition: Blob.cs:657
T[] mutable_cpu_data
Get data from the GPU and bring it over to the host, or Set data from the Host and send it over to th...
Definition: Blob.cs:1461
Returns the BLOB_TYPE of the Blob.
Definition: Blob.cs:2761
byte[] ToByteArray()
Saves this Blob to a byte array.
Definition: Blob.cs:2436
void CopyFrom(Blob< T > src, int nSrcOffset, int nDstOffset, int nCount, bool bCopyData, bool bCopyDiff)
Copy from a source Blob.
Definition: Blob.cs:903
int width
DEPRECATED; legacy shape accessor width: use shape(3) instead.
Definition: Blob.cs:816
T asum_data()
Compute the sum of absolute values (L1 norm) of the data.
Definition: Blob.cs:1706
int count()
Returns the total number of items in the Blob.
Definition: Blob.cs:739
void ReshapeLike(Blob< T > b, bool? bUseHalfSize=null)
Reshape this Blob to have the same shape as another Blob.
Definition: Blob.cs:648
string Name
Get/set the name of the Blob.
Definition: Blob.cs:2184
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
Definition: Blob.cs:402
bool Padded
Get/set the padding state of the blob.
Definition: Blob.cs:284
int num
DEPRECATED; legacy shape accessor num: use shape(0) instead.
Definition: Blob.cs:792
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1479
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:969
int GetDeviceID()
Returns the current device id set within Cuda.
Definition: CudaDnn.cs:2013
string GetRequiredCompute(out int nMinMajor, out int nMinMinor)
The GetRequiredCompute function returns the Major and Minor compute values required by the current Cu...
Definition: CudaDnn.cs:2216
void sub(int n, long hA, long hB, long hY, int nAOff=0, int nBOff=0, int nYOff=0, int nB=0)
Subtracts B from A and places the result in Y.
Definition: CudaDnn.cs:7312
string Path
Specifies the file path used to load the Low-Level Cuda DNN Dll file.
Definition: CudaDnn.cs:1924
void FreeExtension(long hExtension)
Free an instance of an Extension.
Definition: CudaDnn.cs:3474
void FreeHostBuffer(long hMem)
Free previously allocated host memory.
Definition: CudaDnn.cs:2602
int GetDeviceCount()
Query the number of devices (gpu's) installed.
Definition: CudaDnn.cs:2127
string GetDeviceName(int nDeviceID)
Query the name of a device.
Definition: CudaDnn.cs:2035
T[] RunExtension(long hExtension, long lfnIdx, T[] rgParam)
Run a function on the extension specified.
Definition: CudaDnn.cs:3489
long CreateExtension(string strExtensionDllPath)
Create an instance of an Extension DLL.
Definition: CudaDnn.cs:3456
virtual void Dispose(bool bDisposing)
Disposes this instance freeing up all of its host and GPU memory.
Definition: CudaDnn.cs:1612
The NCCL class manages the multi-GPU operations using the low-level NCCL functionality provided by th...
Definition: Parallel.cs:267
void Run(List< int > rgGpus, int nIterationOverride=-1)
Run the root Solver and coordinate with all other Solver's participating in the multi-GPU training.
Definition: Parallel.cs:351
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter.
Definition: Net.cs:23
List< Layer< T > > layers
Returns the layers.
Definition: Net.cs:2003
BlobCollection< T > parameters
Returns the parameters.
Definition: Net.cs:2085
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1445
BlobCollection< T > input_blobs
Returns the collection of input Blobs.
Definition: Net.cs:2201
virtual void Dispose(bool bDisposing)
Releases all resources (GPU and Host) used by the Net.
Definition: Net.cs:184
void LoadWeights(byte[] rgWeights, IXPersist< T > persist, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into the Net.
Definition: Net.cs:2510
NetParameter ToProto(bool bIncludeBlobs)
Writes the net to a proto.
Definition: Net.cs:1865
Blob< T > FindBlob(string strName)
Finds a Blob in the Net by name.
Definition: Net.cs:2592
byte[] SaveWeights(IXPersist< T > persist, bool bSaveDiff=false)
Save the weights to a byte array.
Definition: Net.cs:2541
BlobCollection< T > learnable_parameters
Returns the learnable parameters.
Definition: Net.cs:2117
void ShareTrainedLayersWith(Net< T > srcNet, bool bEnableLog=false)
For an already initialized net, implicitly copies (i.e., using no additional memory) the pre-trained...
Definition: Net.cs:1653
bool ReInitializeParameters(WEIGHT_TARGET target, params string[] rgstrLayers)
Re-initializes the blobs and each of the specified layers by re-running the filler (if any) specified...
Definition: Net.cs:2729
The PersistCaffe class is used to load and save weight files in the .caffemodel format.
Definition: PersistCaffe.cs:20
PersistCaffe(Log log, bool bFailOnFirstTry)
The PersistCaffe constructor.
Definition: PersistCaffe.cs:30
BlobCollection< T > LoadWeights(byte[] rgWeights, List< string > rgExpectedShapes, BlobCollection< T > colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into a BlobCollection
The ResultCollection contains the result of a given CaffeControl::Run.
Returns the result type of the result data: PROBABILITIES (Sigmoid), DISTANCES (Decode),...
List< Result > ResultsSorted
Returns the original results in sorted order.
double DetectedLabelOutput
Returns the detected label output depending on the result type (distance or probability) with a defau...
static RESULT_TYPE GetResultType(LayerParameter.LayerType type)
Get the result type based on the layer-type used.
void SetLabels(List< LabelDescriptor > rgLabels)
Sets the label names in the label dictionary lookup.
int DetectedLabel
Returns the detected label depending on the result type (distance or probability) with a default type...
Defines the type of result.
List< Result > ResultsOriginal
Returns the original results.
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
Definition: EventArgs.cs:416
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
Definition: EventArgs.cs:216
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
Definition: EventArgs.cs:264
Applies common transformations to the input data, such as scaling, mirroring, subtracting the image m...
The Database class manages the actual connection to the physical database using Entity Framework from...
Definition: Database.cs:23
The DatasetFactory manages the connection to the Database object.
SimpleDatum QueryImageMean(int nSrcId=0)
Return the SimpleDatum for the image mean from the open data source.
int GetRawImageMeanID(int nSrcId=0)
Returns the raw image ID for the image mean associated with a data source.
DatasetDescriptor LoadDataset(string strDataset, ConnectInfo ci=null)
Load a dataset descriptor from a dataset name.
bool CopyImageMean(string strSrcSrc, string strDstSrc)
Copy the raw image mean from one source to another.
SourceDescriptor LoadSource(string strSource)
Load the source descriptor from a data source name.
[V2 Image Database] The MyCaffeImageDatabase2 provides an enhanced in-memory image database used for ...
The MyCaffeImageDatabase provides an enhanced in-memory image database used for quick image retrieval...
static Tuple< DB_LABEL_SELECTION_METHOD, DB_ITEM_SELECTION_METHOD > GetSelectionMethod(SettingsCaffe s)
Returns the label/image selection methods based on the SettingsCaffe settings.
An interface for the units of computation which can be composed into a Net.
Definition: Layer.cs:31
LayerParameter.LayerType type
Returns the LayerType of this Layer.
Definition: Layer.cs:927
LayerParameter layer_param
Returns the LayerParameter for this Layer.
Definition: Layer.cs:899
virtual BlobCollection< T > PreProcessInput(PropertySet customInput, out int nSeqLen, BlobCollection< T > colBottom=null)
The PreprocessInput allows derivative data layers to convert a property set of input data into the bo...
Definition: Layer.cs:294
Specifies the parameters for the AccuracyLayer.
Specifies the shape of a Blob.
Definition: BlobShape.cs:15
List< int > dim
The blob shape dimensions.
Definition: BlobShape.cs:93
string source
When used with the DATA parameter, specifies the data 'source' within the database....
DEPRECATED (use DataLayer DataLabelMappingParameter instead) Specifies the parameters for the Lab...
Specifies the base parameter for all layers.
void PrepareRunModel()
Prepare the layer settings for a run model.
LayerType type
Specifies the type of this LayerParameter.
SoftmaxParameter softmax_param
Returns the parameter set when initialized with LayerType.SOFTMAX
List< NetStateRule > include
Specifies the NetStateRule's for which this LayerParameter should be included.
AccuracyParameter accuracy_param
Returns the parameter set when initialized with LayerType.ACCURACY
string PrepareRunModelInputs()
Prepare model inputs for the run-net (if any are needed for the layer).
TransformationParameter transform_param
Returns the parameter set when initialized with LayerType.TRANSFORM
DataParameter data_param
Returns the parameter set when initialized with LayerType.DATA
Specifies the layer type.
LabelMappingParameter labelmapping_param
Returns the parameter set when initialized with LayerType.LABELMAPPING
Specifies the parameters use to create a Net
Definition: NetParameter.cs:18
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
bool force_backward
Whether the network will force every layer to carry out backward operation. If set False,...
override RawProto ToProto(string strName)
Converts the parameter into a RawProto with the given name.
List< LayerParameter > layer
The layers that make up the net. Each of their configurations, including connectivity and behavior,...
static Dictionary< string, BlobShape > InputFromProto(RawProto rp)
Collect the inputs from the RawProto.
Phase phase
Specifies the Phase of the NetState.
Definition: NetState.cs:63
List< string > stage
Specifies the stages of the NetState.
Definition: NetState.cs:83
Specifies a NetStateRule used to determine whether a Net falls within a given include or exclude patt...
Definition: NetStateRule.cs:20
Phase phase
Set phase to require the NetState to have a particular phase (TRAIN or TEST) to meet this rule.
Definition: NetStateRule.cs:99
int axis
The axis along which to perform the softmax – may be negative to index from the end (e....
The SolverParameter is a parameter for the solver, specifying the train and test networks.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
Stores parameters used to apply transformation to the data layer's data.
bool use_imagedb_mean
Specifies whether to subtract the mean image from the image database, subtract the mean values,...
static TransformationParameter FromProto(RawProto rp)
Parses the parameter from a RawProto.
Specifies the parameters for the DecodeLayer and the AccuracyEncodingLayer.
An interface for classes that perform optimization on Nets - this class serves as the base class for ...
Definition: Solver.cs:28
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
Definition: Solver.cs:218
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
Definition: Solver.cs:134
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXDatabaseBase db, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1889
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
Definition: Solver.cs:1097
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
Definition: Solver.cs:1115
int MaximumIteration
Returns the maximum training iterations.
Definition: Solver.cs:700
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
Definition: Solver.cs:130
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is set extra debugging information describing the state o...
Definition: Solver.cs:361
Net< T > net
Returns the main training Net.
Definition: Solver.cs:1229
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
Definition: Solver.cs:236
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
Definition: Solver.cs:227
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
Definition: Solver.cs:352
Net< T > TrainingNet
Returns the training Net used by the solver.
Definition: Solver.cs:445
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
Definition: Solver.cs:138
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stops training upon detecti...
Definition: Solver.cs:382
Get/set the snapshot weight update method.
Definition: Solver.cs:301
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
Definition: Solver.cs:292
void Reset()
Reset the iterations of the net.
Definition: Solver.cs:478
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
Definition: Solver.cs:1322
string LabelQueryEpochs
Return the label query epochs for the active datasource.
Definition: Solver.cs:684
int TrainingIterationOverride
Get/set the training iteration override.
Definition: Solver.cs:1179
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Definition: Solver.cs:150
bool WeightsUpdated
Get/set when the weights have been updated.
Definition: Solver.cs:413
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed Nan (and Infinity) detection is pero...
Definition: Solver.cs:395
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
Definition: Solver.cs:146
int TestingIterationOverride
Get/set the testing iteration override.
Definition: Solver.cs:1188
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
Definition: Solver.cs:744
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
Definition: Solver.cs:118
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
Definition: Solver.cs:668
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
Definition: Solver.cs:404
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Definition: Solver.cs:676
Net< T > TestingNet
Returns the testing Net used by the solver.
Definition: Solver.cs:431
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
Definition: Solver.cs:373
int CurrentIteration
Returns the current training iteration.
Definition: Solver.cs:692
The Component class is a standard Microsoft.NET class that implements the IComponent interface and is...
Definition: Component.cs:18
The IXDatabaseBase interface defines the general interface to the in-memory database.
Definition: Interfaces.cs:444
SimpleDatum GetItem(int nItemID, params int[] rgSrcId)
Get the item (e.g., image or temporal item) with a given Raw Item ID.
SimpleDatum GetItemMean(int nSrcId)
Returns the item (e.g., image or temporal item) mean for a data source.
int GetSourceID(string strSrc)
Returns a data source ID given its name.
SimpleDatum QueryItem(int nSrcId, int nIdx, DB_LABEL_SELECTION_METHOD? labelSelectionOverride=null, DB_ITEM_SELECTION_METHOD? imageSelectionOverride=null, int? nLabel=null, bool bLoadDataCriteria=false, bool bLoadDebugData=false)
Query an image in a given data source.
Returns the label and image selection method used.
SimpleDatum QueryItemMean(int nSrcId)
Queries the item (e.g., image or temporal item) mean for a data source from the database on disk.
Sets the label and image selection methods.
void CleanUp(int nDsId=0, bool bForce=false)
Releases the image database, and if this is the last instance using the in-memory database,...
DB_VERSION GetVersion()
Returns the version of the MyCaffe Image Database being used.
List< SimpleDatum > GetItemsFromTime(int nSrcId, DateTime dtStart, int nQueryCount=int.MaxValue, string strFilterVal=null, int? nBoostVal=null, bool bBoostValIsExact=false)
Returns the array of items (e.g., images or temporal items) in the item set, possibly filtered with t...
The IXImageDatabase interface defines the general interface to the in-memory image database.
Definition: Interfaces.cs:1004
The IXImageDatabase2 interface defines the general interface to the in-memory image database (v2).
Definition: Interfaces.cs:1092
The IXImageDatabaseBase interface defines the general interface to the in-memory image database.
Definition: Interfaces.cs:878
The IXMyCaffeExtension interface allows for easy extension management of the low-level software that ...
Definition: Interfaces.cs:614
The IXMyCaffe interface contains functions used to perform MyCaffe operations that work with the MyCa...
Definition: Interfaces.cs:410
The IXMyCaffeNoDb interface contains functions used to perform MyCaffe operations that run in a light...
Definition: Interfaces.cs:564
The IXMyCaffeState interface contains functions related to the MyCaffeComponent state.
Definition: Interfaces.cs:258
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:187
The descriptors namespace contains all descriptor used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Defines the item (e.g., image or temporal item) selection method.
Definition: Interfaces.cs:278
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:61
Defines the image database version to use.
Definition: Interfaces.cs:397
Defines the label selection method.
Definition: Interfaces.cs:333
Specifies the stage under which to run a custom trainer.
Definition: Interfaces.cs:88
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
Specifies the initialization flags used when initializing CUDA.
Definition: CudaDnn.cs:207
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
Defines the training stepping method (if any).
Definition: Interfaces.cs:131
Defines the type of weight to target in re-initializations.
Definition: Interfaces.cs:38
The namespace contains dataset creators used to create common testing datasets such as M...
Definition: BinaryFile.cs:16
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param.beta parameters are used by the MyCaffe.layer.beta layers.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12