MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
PersistCaffe.cs
1using MyCaffe.basecode;
2using System;
3using System.Collections.Generic;
4using System.Diagnostics;
5using System.IO;
6using System.Linq;
7using System.Text;
8using System.Threading.Tasks;
9using Google.Protobuf;
10using System.Collections;
11using MyCaffe.param;
12
13namespace MyCaffe.common
14{
19 public class PersistCaffe<T> : IXPersist<T>
20 {
// Log that receives the status, progress and failure messages written by this class.
21 Log m_log;
// Set by the constructor; when true, LoadWeights throws as soon as a load attempt returns no blobs.
22 bool m_bFailOnFirstTry = false;
// ASCII tag that SaveWeights writes into the trailer and that IsMyCaffe looks for at the end of the weight data.
23 const string m_strWeightMyCaffeTag = "mycaffe.ai";
24
/// <summary>
/// Creates the persistence object.
/// </summary>
/// <param name="log">Log that receives the status output of the load/save methods.</param>
/// <param name="bFailOnFirstTry">When true, LoadWeights throws as soon as a load attempt returns no blobs.</param>
public PersistCaffe(Log log, bool bFailOnFirstTry)
{
    this.m_bFailOnFirstTry = bFailOnFirstTry;
    this.m_log = log;
}
35
/// <summary>
/// Returns the tag text that marks weight data saved in the MyCaffe format.
/// </summary>
public string MyCaffeTag
{
    get
    {
        string strTag = m_strWeightMyCaffeTag;
        return strTag;
    }
}
44
/// <summary>
/// Determines whether the weight data ends with the MyCaffe trailer and, if so, reads its version text.
/// </summary>
/// <remarks>
/// The data must end with an 8-byte offset followed by the 10-byte tag. The offset is treated as the
/// position of an earlier copy of the tag, and the version text is read from the bytes right after it
/// (the layout SaveWeights writes). Truncated trailers and out-of-range offsets are reported as
/// 'not MyCaffe' instead of raising an exception.
/// </remarks>
/// <param name="rgWeights">Weight bytes to inspect; may be null.</param>
/// <param name="strVer">Receives the first five ASCII characters of the version field when the trailer is valid; otherwise null.</param>
/// <returns>true when a complete, in-range trailer is present; otherwise false.</returns>
public bool IsMyCaffe(byte[] rgWeights, out string strVer)
{
    const int nVerLen = 5;
    int nTagLen = m_strWeightMyCaffeTag.Length;

    strVer = null;

    if (rgWeights == null || rgWeights.Length < nTagLen)
        return false;

    string strTag = Encoding.ASCII.GetString(rgWeights, rgWeights.Length - nTagLen, nTagLen);
    if (strTag != m_strWeightMyCaffeTag)
        return false;

    // The 64-bit start offset sits directly in front of the closing tag; data that is
    // too short to hold it carries a truncated trailer.
    int nOffsetPos = rgWeights.Length - (nTagLen + sizeof(long));
    if (nOffsetPos < 0)
        return false;

    long lCaffeStart = BitConverter.ToInt64(rgWeights, nOffsetPos);

    // Reject offsets whose version field would fall outside the buffer; this also keeps
    // the int cast below from wrapping on corrupt data.
    long lMaxStart = (long)rgWeights.Length - nTagLen - nVerLen;
    if (lCaffeStart < 0 || lCaffeStart > lMaxStart)
        return false;

    strVer = Encoding.ASCII.GetString(rgWeights, (int)lCaffeStart + nTagLen, nVerLen);
    return true;
}
68
// Body that serializes a solver state into protobuf bytes.
// NOTE(review): the declaration line that belongs to this body is not present in this listing;
// 'state' and 'type' are presumably its parameters - confirm against the declaration.
76 {
77 FieldDescriptor fd = FieldDescriptor.CreateSolverStateFieldDesc();
78 ProtoBufWriter writer = new ProtoBufWriter(m_log);
79
80 m_log.WriteLine("Saving state...");
81
// The iteration counter and the current step are always written, each as a one-element int array.
82 writer.WriteField(fd, "iter", new int[] { state.iter });
83 writer.WriteField(fd, "current_step", new int[] { state.current_step });
84
// 'start' and 'end' are only written for the LBFGS solver type.
85 if (type == SolverParameter.SolverType.LBFGS)
86 {
87 writer.WriteField(fd, "start", new int[] { state.start });
88 writer.WriteField(fd, "end", new int[] { state.end });
89 }
90
// Each history entry is written as its own 'history' field.
91 for (int i = 0; i < state.history.Count; i++)
92 {
93 writer.WriteField(fd, "history", saveBlobProto(fd.FindFirstChild("history"), state.history[i]));
94 }
95
// LBFGS additionally stores the s_history entries plus the gradients and direction blobs.
96 if (type == SolverParameter.SolverType.LBFGS)
97 {
98 for (int i = 0; i < state.s_history.Count; i++)
99 {
100 writer.WriteField(fd, "s_history", saveBlobProto(fd.FindFirstChild("s_history"), state.s_history[i]));
101 }
102
// NOTE(review): the descriptor lookup below uses "gradient" while the field is written here, and looked up
// in the solver-state loading code, as "gradients" - confirm which child name the descriptor defines.
103 writer.WriteField(fd, "gradients", saveBlobProto(fd.FindFirstChild("gradient"), state.gradients));
104 writer.WriteField(fd, "direction", saveBlobProto(fd.FindFirstChild("direction"), state.direction));
105 }
106
// GetBytes(true) flushes the coded stream before the bytes are taken.
107 return writer.GetBytes(true);
108 }
109
// Body that rebuilds a SolverState from protobuf bytes; returns null when no fields can be read.
// NOTE(review): the declaration line that belongs to this body is not present in this listing;
// 'rgState' and 'type' are presumably its parameters - confirm against the declaration.
117 {
118 SolverState state = new SolverState();
119 FieldDescriptor fd = FieldDescriptor.CreateSolverStateFieldDesc();
120 ProtoBufReader reader = new ProtoBufReader(rgState);
121 ProtoBufFieldCollection fields = reader.ReadFields(fd, false);
// NOTE(review): 'sw' is created but never used in this body.
122 Stopwatch sw = new Stopwatch();
123
124 m_log.WriteLine("Loading the Solver state...");
125
126 if (fields == null || fields.Count == 0)
127 return null;
128
129
130 //---------------------------------------------
131 // Load the state.
132 //---------------------------------------------
133
// A missing or empty 'iter' field falls back to 0.
134 ProtoBufField pbIter = fields.FindFirstChild("iter");
135 state.iter = (pbIter == null || pbIter.IntValues == null || pbIter.IntValues.Length == 0) ? 0 : pbIter.IntValues[0];
136
// A missing or empty 'current_step' field falls back to 1.
137 ProtoBufField pbCurStep = fields.FindFirstChild("current_step");
138 state.current_step = (pbCurStep == null || pbCurStep.IntValues == null || pbCurStep.IntValues.Length == 0) ? 1 : pbCurStep.IntValues[0];
139
// 'start' (fallback 0) and 'end' (fallback 1) are only read for the LBFGS solver type.
140 if (type == SolverParameter.SolverType.LBFGS)
141 {
142 ProtoBufField pbStart = fields.FindFirstChild("start");
143 state.start = (pbStart == null || pbStart.IntValues == null || pbStart.IntValues.Length == 0) ? 0 : pbStart.IntValues[0];
144
145 ProtoBufField pbEnd = fields.FindFirstChild("end");
146 state.end = (pbEnd == null || pbEnd.IntValues == null || pbEnd.IntValues.Length == 0) ? 1 : pbEnd.IntValues[0];
147 }
148
// Every 'history' field is decoded into a BlobProto and appended in the order found.
149 ProtoBufFieldCollection col = fields.FindAllChildren("history");
150 if (col != null && col.Count > 0)
151 {
152 FieldDescriptor fdHist = fd.FindFirstChild("history");
153
154 for (int i = 0; i < col.Count; i++)
155 {
156 state.history.Add(LoadBlobProto(col[i].Bytes, fdHist.FieldId));
157 }
158 }
159
// LBFGS additionally restores the s_history entries plus the gradients and direction blobs when present.
160 if (type == SolverParameter.SolverType.LBFGS)
161 {
162 ProtoBufFieldCollection colS = fields.FindAllChildren("s_history");
163 if (colS != null && colS.Count > 0)
164 {
165 FieldDescriptor fdHist = fd.FindFirstChild("s_history");
166
167 for (int i = 0; i < colS.Count; i++)
168 {
169 state.s_history.Add(LoadBlobProto(colS[i].Bytes, fdHist.FieldId));
170 }
171 }
172
173 ProtoBufField pbGrad = fields.FindFirstChild("gradients");
174 if (pbGrad != null)
175 {
176 FieldDescriptor fdGrad = fd.FindFirstChild("gradients");
177 state.gradients = LoadBlobProto(pbGrad.Bytes, fdGrad.FieldId);
178 }
179
180 ProtoBufField pbDir = fields.FindFirstChild("direction");
181 if (pbDir != null)
182 {
183 FieldDescriptor fdDir = fd.FindFirstChild("direction");
184 state.direction = LoadBlobProto(pbDir.Bytes, fdDir.FieldId);
185 }
186 }
187
188 return state;
189 }
190
/// <summary>
/// Loads the weights into the given blobs, trying the Caffe layout first when the data carries no MyCaffe tag
/// and the MyCaffe layout afterwards.
/// </summary>
/// <param name="rgWeights">Weight bytes to load.</param>
/// <param name="rgExpectedShapes">Expected shape strings, forwarded to the loaders.</param>
/// <param name="colBlobs">Blobs that receive the weights, forwarded to the loaders.</param>
/// <param name="bSizeToFit">Size-to-fit flag, forwarded to the loaders.</param>
/// <param name="bLoadedDiffs">Receives the value reported by the loader that ran last.</param>
/// <param name="inputWtInfo">Optional list forwarded to the loaders.</param>
/// <param name="targetWtInfo">Optional list forwarded to the loaders.</param>
/// <param name="strSkipBlobType">Optional blob type name forwarded to the loaders.</param>
/// <returns>The collection returned by the successful loader, or null when no loader produced one.</returns>
public BlobCollection<T> LoadWeights(byte[] rgWeights, List<string> rgExpectedShapes, BlobCollection<T> colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List<string> inputWtInfo = null, List<string> targetWtInfo = null, string strSkipBlobType = null)
{
    m_log.WriteLine("Attempting to load the weights in Caffe model format...");

    string strVersion;
    bool bTagged = IsMyCaffe(rgWeights, out strVersion);

    if (bTagged)
    {
        if (strVersion == "v.1.0")
            m_log.FAIL("Loading weights with 'depreciated' native v.1.0 format...");
    }
    else
    {
        BlobCollection<T> colCaffe = loadFromCaffe(rgWeights, rgExpectedShapes, colBlobs, bSizeToFit, out bLoadedDiffs, inputWtInfo, targetWtInfo, strSkipBlobType);
        if (colCaffe != null)
        {
            m_log.WriteLine("Weights loaded in Caffe model format.");
            return colCaffe;
        }

        if (m_bFailOnFirstTry)
            throw new Exception("Failed to load the weights from the caffe model.");
    }

    m_log.WriteLine("Attempting to load weights in MyCaffe model format...");

    BlobCollection<T> colMyCaffe = loadFromMyCaffe(rgWeights, rgExpectedShapes, colBlobs, bSizeToFit, out bLoadedDiffs, inputWtInfo, targetWtInfo, strSkipBlobType);
    if (colMyCaffe == null)
    {
        if (m_bFailOnFirstTry)
            throw new Exception("Failed to load the weights from the MyCaffe model.");

        m_log.FAIL("Loading weights with 'depreciated' native format...");
        return null;
    }

    m_log.WriteLine("Weights loaded in MyCaffe model format.");
    return colMyCaffe;
}
246
/// <summary>
/// Reads the blob names and shapes stored in the weight data without loading the weights themselves.
/// </summary>
/// <param name="rgWeights">Weight bytes to inspect.</param>
/// <returns>The weight information produced by the reader that matches the detected format.</returns>
public WeightInfo<T> LoadWeightInfo(byte[] rgWeights)
{
    string strIgnoredVer;
    bool bTagged = IsMyCaffe(rgWeights, out strIgnoredVer);

    return bTagged ? loadInfoFromMyCaffe(rgWeights) : loadInfoFromCaffe(rgWeights);
}
261
// Body that builds a WeightInfo<T> by adding every blob of 'colBlobs' to it, in enumeration order.
// NOTE(review): the declaration line that belongs to this body is not present in this listing;
// 'colBlobs' is presumably its parameter - confirm against the declaration.
268 {
269 WeightInfo<T> info = new common.WeightInfo<T>();
270
271 foreach (Blob<T> b in colBlobs)
272 {
273 info.AddBlob(b);
274 }
275
276 return info;
277 }
278
/// <summary>
/// Serializes the blobs, grouped by the layer name held in each blob's Tag, and appends the MyCaffe trailer.
/// </summary>
/// <remarks>
/// The trailer consists of 256 zero bytes, the tag, a 32-byte version field, the 64-bit offset of that
/// first tag, and the tag again.
/// </remarks>
/// <param name="colBlobs">Blobs to save; null entries are skipped.</param>
/// <param name="bSaveDiffs">Not read by this method.</param>
/// <returns>The serialized weights including the trailer.</returns>
public byte[] SaveWeights(BlobCollection<T> colBlobs, bool bSaveDiffs = false)
{
    FieldDescriptor fdNet = FieldDescriptor.CreateNetworkParamFieldDesc();
    ProtoBufWriter pbWriter = new ProtoBufWriter(m_log);
    Dictionary<string, BlobCollection<T>> rgByLayer = new Dictionary<string, BlobCollection<T>>();

    // Group the blobs by the layer name stored in their Tag.
    foreach (Blob<T> b in colBlobs)
    {
        if (b == null)
            continue;

        string strLayerName = (string)b.Tag;
        if (strLayerName == null || strLayerName.Length == 0)
            throw new Exception("Invalid blob specification - missing layer name.");

        BlobCollection<T> colLayer;
        if (!rgByLayer.TryGetValue(strLayerName, out colLayer))
        {
            colLayer = new BlobCollection<T>();
            rgByLayer.Add(strLayerName, colLayer);
        }

        colLayer.Add(b);
    }

    pbWriter.WriteField(fdNet, "name", "");

    foreach (KeyValuePair<string, BlobCollection<T>> kvLayer in rgByLayer)
    {
        m_log.WriteLine("Saving layer '" + kvLayer.Key + "'...");

        byte[] rgLayer = saveLayerParameter(fdNet.FindFirstChild("LayerParameter"), kvLayer.Key, kvLayer.Value);
        pbWriter.WriteField(fdNet, "LayerParameter", rgLayer);
    }

    // Everything must be in the memory stream before its length is used as the trailer offset.
    pbWriter.Flush();

    byte[] rgPadding = new byte[256];
    long lTagOffset = (long)pbWriter.Length + rgPadding.Length;

    using (BinaryWriter bwTrailer = new BinaryWriter(pbWriter.Stream))
    {
        bwTrailer.Write(rgPadding);

        byte[] rgTag = Encoding.ASCII.GetBytes(MyCaffeTag);

        // The version text occupies the start of a fixed 32-byte, zero-filled field.
        byte[] rgVersionText = Encoding.ASCII.GetBytes("1.0.1");
        byte[] rgVersionField = new byte[32];
        Array.Copy(rgVersionText, rgVersionField, rgVersionText.Length);

        bwTrailer.Write(rgTag);
        bwTrailer.Write(rgVersionField);
        bwTrailer.Write(lTagOffset);
        bwTrailer.Write(rgTag);
    }

    return pbWriter.GetBytes(false);
}
346
/// <summary>
/// Decodes a BlobProto (shape plus float or double data) from protobuf bytes.
/// </summary>
/// <param name="rg">Encoded blob bytes.</param>
/// <param name="nFieldId">Field id used to create the blob descriptor.</param>
/// <returns>The decoded BlobProto, or null when no fields could be read.</returns>
/// <exception cref="Exception">Thrown when a shape field has an unexpected type or no data field exists.</exception>
public BlobProto LoadBlobProto(byte[] rg, int nFieldId)
{
    FieldDescriptor fdBlob = FieldDescriptor.CreateBlobProtoDesc(nFieldId);
    ProtoBufReader pbReader = new ProtoBufReader(rg);
    ProtoBufFieldCollection colFields = pbReader.ReadFields(fdBlob, false);

    if (colFields == null || colFields.Count == 0)
        return null;

    for (int nIdx = 0; nIdx < colFields.Count; nIdx++)
    {
        colFields[nIdx].LoadSubFields(0, 4);
    }

    List<int> rgDims = new List<int>();
    ProtoBufField pbShape = colFields.FindFirstChild("shape");

    if (pbShape == null)
    {
        // Without a 'shape' field the dimensions come from num, channels, height and width,
        // read in that order and stopping at the first one that is missing.
        string[] rgAxisNames = new string[] { "num", "channels", "height", "width" };

        foreach (string strAxis in rgAxisNames)
        {
            ProtoBufField pbAxis = colFields.FindFirstChild(strAxis);
            if (pbAxis == null)
                break;

            if (pbAxis.Type != ProtoBufField.TYPE.BIT32)
                throw new Exception("Invalid proto buf: invalid type '" + strAxis + "'");

            rgDims.Add(pbAxis.IntValue);
        }
    }
    else
    {
        if (pbShape.Type != ProtoBufField.TYPE.ARRAY)
            throw new Exception("Invalid proto buf: invalid type 'shape'");

        ProtoBufField pbDim = pbShape.Array.FindFirstChild("dim");
        if (pbDim == null || pbDim.Type != ProtoBufField.TYPE.LONG_ARRAY)
            throw new Exception("Invalid proto buf: missing 'dim' type.");

        foreach (long lDim in pbDim.LongValues)
        {
            rgDims.Add((int)lDim);
        }
    }

    // 'data' is preferred; 'double_data' is only consulted when 'data' is absent.
    ProtoBufField pbData = colFields.FindFirstChild("data");
    if (pbData == null)
        pbData = colFields.FindFirstChild("double_data");

    if (pbData == null)
        throw new Exception("Invalid proto buf: missing 'data' or 'double_data'");

    BlobProto protoResult = new param.BlobProto(rgDims);

    if (pbData.Type == ProtoBufField.TYPE.FLOAT_ARRAY)
    {
        protoResult.data = new List<float>(pbData.FloatValues);
    }
    else if (pbData.Type == ProtoBufField.TYPE.DOUBLE_ARRAY)
    {
        protoResult.double_data = new List<double>(pbData.DoubleValues);
    }
    else
    {
        throw new Exception("Invalid proto buf: invalid data type '" + pbData.Type.ToString() + "'.");
    }

    return protoResult;
}
443
/// <summary>
/// Reads the whole file into memory and decodes it as a BlobProto.
/// </summary>
/// <param name="strFile">Path of the file to read.</param>
/// <param name="nFieldId">Field id used to create the blob descriptor.</param>
/// <returns>The result of decoding the file content with the byte-array overload.</returns>
public BlobProto LoadBlobProto(string strFile, int nFieldId)
{
    byte[] rgContent;

    using (FileStream fsIn = new FileStream(strFile, FileMode.Open, FileAccess.Read))
    using (BinaryReader brIn = new BinaryReader(fsIn))
    {
        int nLength = (int)fsIn.Length;
        rgContent = brIn.ReadBytes(nLength);
    }

    return LoadBlobProto(rgContent, nFieldId);
}
464
/// <summary>
/// Encodes one layer: its name followed by one 'blobs' field per blob in the collection.
/// </summary>
/// <param name="fd">Descriptor of the layer parameter.</param>
/// <param name="strName">Layer name written to the 'name' field.</param>
/// <param name="col">Blobs belonging to the layer.</param>
/// <returns>The encoded layer bytes.</returns>
private byte[] saveLayerParameter(FieldDescriptor fd, string strName, BlobCollection<T> col)
{
    ProtoBufWriter pbLayer = new common.ProtoBufWriter(m_log);
    pbLayer.WriteField(fd, "name", strName);

    foreach (Blob<T> b in col)
    {
        FieldDescriptor fdBlobs = fd.FindFirstChild("blobs");
        byte[] rgBlob = saveBlobProto(fdBlobs, b);

        pbLayer.WriteField(fd, "blobs", rgBlob);
        m_log.WriteLine(" - saved blob '" + b.Name + "'");
    }

    return pbLayer.GetBytes();
}
479
/// <summary>
/// Encodes a BlobProto: its shape followed by either its double data or its float data.
/// </summary>
/// <param name="fd">Descriptor of the blob field.</param>
/// <param name="bp">BlobProto to encode.</param>
/// <returns>The encoded blob bytes.</returns>
private byte[] saveBlobProto(FieldDescriptor fd, BlobProto bp)
{
    ProtoBufWriter pbBlob = new ProtoBufWriter(m_log);

    byte[] rgShape = saveBlobShape(fd.FindFirstChild("shape"), bp.shape.dim);
    pbBlob.WriteField(fd, "shape", rgShape);

    // The float 'data' field is written whenever no double values are available.
    bool bHasDoubles = (bp.double_data != null && bp.double_data.Count > 0);

    if (!bHasDoubles)
        pbBlob.WriteField(fd, "data", bp.data.ToArray());
    else
        pbBlob.WriteField(fd, "double_data", bp.double_data.ToArray());

    return pbBlob.GetBytes();
}
493
/// <summary>
/// Encodes a blob: its shape followed by its data, written as 'double_data' when T is double and as 'data' otherwise.
/// </summary>
/// <param name="fd">Descriptor of the blob field.</param>
/// <param name="blob">Blob whose shape and data are encoded.</param>
/// <returns>The encoded blob bytes.</returns>
private byte[] saveBlobProto(FieldDescriptor fd, Blob<T> blob)
{
    ProtoBufWriter pbBlob = new ProtoBufWriter(m_log);

    byte[] rgShape = saveBlobShape(fd.FindFirstChild("shape"), blob.shape());
    pbBlob.WriteField(fd, "shape", rgShape);

    T[] rgRaw = blob.update_cpu_data();
    bool bIsDouble = (typeof(T) == typeof(double));

    if (!bIsDouble)
    {
        float[] rgFloats = (float[])Convert.ChangeType(rgRaw, typeof(float[]));
        pbBlob.WriteField(fd, "data", rgFloats);
    }
    else
    {
        double[] rgDoubles = (double[])Convert.ChangeType(rgRaw, typeof(double[]));
        pbBlob.WriteField(fd, "double_data", rgDoubles);
    }

    return pbBlob.GetBytes();
}
515
/// <summary>
/// Encodes a shape as a 'dim' field holding the dimensions widened to 64-bit values.
/// </summary>
/// <param name="fd">Descriptor of the shape field.</param>
/// <param name="rg">Dimensions to encode.</param>
/// <returns>The encoded shape bytes.</returns>
private byte[] saveBlobShape(FieldDescriptor fd, List<int> rg)
{
    ProtoBufWriter pbShape = new ProtoBufWriter(m_log);
    long[] rgDim = new long[rg.Count];

    for (int nAxis = 0; nAxis < rgDim.Length; nAxis++)
    {
        rgDim[nAxis] = rg[nAxis];
    }

    pbShape.WriteField(fd, "dim", rgDim);

    return pbShape.GetBytes();
}
530
/// <summary>
/// Loads MyCaffe-tagged weights by forwarding every argument unchanged to loadFromCaffe.
/// </summary>
/// <returns>Whatever loadFromCaffe returns for the same arguments.</returns>
private BlobCollection<T> loadFromMyCaffe(byte[] rgWeights, List<string> rgExpectedShapes, BlobCollection<T> colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List<string> inputWtInfo = null, List<string> targetWtInfo = null, string strSkipBlobType = null)
{
    return loadFromCaffe(rgWeights, rgExpectedShapes, colBlobs, bSizeToFit, out bLoadedDiffs, inputWtInfo, targetWtInfo, strSkipBlobType);
}
536
/// <summary>
/// Reads the weight information of MyCaffe-tagged weights by forwarding to loadInfoFromCaffe.
/// </summary>
/// <param name="rgWeights">Weight bytes to inspect.</param>
/// <returns>Whatever loadInfoFromCaffe returns for the same bytes.</returns>
private WeightInfo<T> loadInfoFromMyCaffe(byte[] rgWeights)
{
    WeightInfo<T> wtInfo = loadInfoFromCaffe(rgWeights);
    return wtInfo;
}
541
/// <summary>
/// Decodes the layer blobs stored in protobuf encoded weights and copies them into the matching blobs of <paramref name="colBlobs"/>.
/// </summary>
/// <remarks>
/// Weight blobs and destination blobs are walked in parallel. For each destination blob the next weight blob
/// is inspected; when its shape is acceptable the data is copied, otherwise the blob is reported as not loaded.
/// </remarks>
/// <param name="rgWeights">Protobuf encoded weights.</param>
/// <param name="rgExpectedShapes">Shape strings indexed in parallel with <paramref name="colBlobs"/>.</param>
/// <param name="colBlobs">Destination blobs; an entry may be replaced by a resized blob.</param>
/// <param name="bSizeToFit">Relaxes the shape match and suppresses the element-count check.</param>
/// <param name="bLoadedDiffs">Always set to false by this method.</param>
/// <param name="inputWtInfo">When not null, a weight blob is only considered if its name equals the next entry of this list.</param>
/// <param name="targetWtInfo">When not null, destination blobs are skipped until one matches the next entry of this list.</param>
/// <param name="strSkipBlobType">When not null, destination blobs whose type text equals this value are not loaded.</param>
/// <returns>null when no fields could be read; otherwise <paramref name="colBlobs"/>.</returns>
private BlobCollection<T> loadFromCaffe(byte[] rgWeights, List<string> rgExpectedShapes, BlobCollection<T> colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List<string> inputWtInfo = null, List<string> targetWtInfo = null, string strSkipBlobType = null)
{
    FieldDescriptor fd = FieldDescriptor.CreateNetworkParamFieldDesc();
    ProtoBufReader reader = new ProtoBufReader(rgWeights);
    ProtoBufFieldCollection fields = reader.ReadFields(fd, true);
    Stopwatch sw = new Stopwatch();
    BlobName name = new BlobName();

    bLoadedDiffs = false;

    if (fields == null || fields.Count == 0)
        return null;

    sw.Start();

    for (int i = 0; i < fields.Count; i++)
    {
        ProtoBufField field = fields[i];
        field.LoadSubFields(0, 4);

        // Report progress at most about once per second.
        if (sw.Elapsed.TotalMilliseconds > 1000)
        {
            m_log.Progress = (double)i / (double)fields.Count;
            m_log.WriteLine("(" + m_log.Progress.ToString("P") + ") loading fields...");
            sw.Restart();
        }
    }

    //---------------------------------------------
    //  Find all the blobs containing learnable
    //  parameters.
    //---------------------------------------------

    ProtoBufFieldCollection colFieldBlobs = new common.ProtoBufFieldCollection();
    int nLayerIdx = 0;

    for (int i = 0; i < fields.Count; i++)
    {
        if (fields[i].FieldDesc != null) // Ignore null entries which can occur in V1.
        {
            if (fields[i].FieldDesc.Name == "LayerParameter")
            {
                ProtoBufField pbName = fields[i].Array.FindFirstChild("name");
                ProtoBufFieldCollection col = fields[i].Array.FindAllChildren("blobs");
                string strName = (pbName != null) ? pbName.StringValue : ("layer_" + nLayerIdx.ToString());

                if (col != null && col.Count > 0)
                {
                    col.SetTag(strName);
                    colFieldBlobs.AddRange(col);
                }

                nLayerIdx++;
            }
            else if (fields[i].FieldDesc.Name == "V1LayerParameter")
            {
                ProtoBufField pbName = fields[i].Array.FindFirstChild("name");
                ProtoBufFieldCollection col = fields[i].Array.FindAllChildren("blobs");
                string strName = (pbName != null) ? pbName.StringValue : ("layer_" + nLayerIdx.ToString());

                if (col != null && col.Count > 0)
                {
                    col.SetTag(strName);
                    col.SetLegacy(true);
                    colFieldBlobs.AddRange(col);
                }

                nLayerIdx++;
            }
        }
    }

    //---------------------------------------------
    //  Find the first learnable parameter that
    //  matches the size of the first colBlob.
    //---------------------------------------------

    m_log.Progress = 0;
    m_log.WriteLine("Loading the weights...");

    if (colBlobs.Count != colFieldBlobs.Count)
        m_log.WriteLine("The number of learnable blobs within the weights does not match the number within the network, attempting to load by size...");

    int nFieldIdx = 0;
    int nBlobIdx = 0;
    int nInfoIdx = 0;
    int nTargetIdx = 0;

    List<long> rgBlobShape = null;

    while (nFieldIdx < colFieldBlobs.Count && nBlobIdx < colBlobs.Count)
    {
        Blob<T> blob = colBlobs[nBlobIdx];
        string strName = name.GetName(blob.Name);

        if (targetWtInfo != null)
        {
            // Skip destination blobs until the next requested target name is found.
            while (strName != targetWtInfo[nTargetIdx] && nBlobIdx < colBlobs.Count)
            {
                nBlobIdx++;

                if (nBlobIdx == colBlobs.Count)
                    break;

                blob = colBlobs[nBlobIdx];
                strName = name.GetName(blob.Name);
            }

            // NOTE(review): when the target is not found, nBlobIdx equals colBlobs.Count here and is
            // used as an index right below - confirm that WriteError does not return in that case.
            if (nBlobIdx == colBlobs.Count)
                m_log.WriteError(new Exception("Could not find the target blob '" + targetWtInfo[nTargetIdx] + "'!"));

            nTargetIdx++;
        }

        string strShapeB = rgExpectedShapes[nBlobIdx];
        bool bSizeToFitWts = colBlobs[nBlobIdx].reshape_when_sharing;
        string strShapeW = "";
        long lCount = 0;
        bool bResizeNeeded = false;
        bool bMisSized = false;

        //-----------------------------------------
        //  Find the first matching size.
        //-----------------------------------------
        while (nFieldIdx < colFieldBlobs.Count)
        {
            strName = null;

            ProtoBufField pbName = colFieldBlobs[nFieldIdx].Array.FindFirstChild("name");
            if (pbName != null && pbName.Type == ProtoBufField.TYPE.STRING)
            {
                strName = pbName.StringValue;
            }
            else
            {
                // Unnamed weight blobs get a synthetic name built from their type (when present) and index.
                ProtoBufField pbType = colFieldBlobs[nFieldIdx].Array.FindFirstChild("type");
                if (pbType != null && pbType.Type == ProtoBufField.TYPE.STRING)
                    strName = pbType.StringValue + "_" + nFieldIdx.ToString();
                else
                    strName = "blob_" + nFieldIdx.ToString();
            }

            if (inputWtInfo == null || strName == inputWtInfo[nInfoIdx])
            {
                nInfoIdx++;

                ProtoBufField pbShape = colFieldBlobs[nFieldIdx].Array.FindFirstChild("shape");
                if (pbShape != null && pbShape.Type == ProtoBufField.TYPE.ARRAY)
                {
                    ProtoBufField pbDim = pbShape.Array.FindFirstChild("dim");
                    if (pbDim != null && pbDim.Type == ProtoBufField.TYPE.LONG_ARRAY)
                    {
                        strShapeW = createShapeString(pbDim.LongValues, out lCount);

                        if (compareShapes(strShapeB, strShapeW) || bSizeToFitWts)
                        {
                            rgBlobShape = new List<long>(pbDim.LongValues);
                            bResizeNeeded = bSizeToFitWts;
                            break;
                        }

                        if (bSizeToFit && compareShapes(strShapeB, strShapeW, 2))
                        {
                            rgBlobShape = new List<long>(pbDim.LongValues);
                            break;
                        }

                        bMisSized = true;
                        break;
                    }
                }
                else
                {
                    // No 'shape' field: build the shape from num, channels, height and width.
                    ProtoBufField pbNum = colFieldBlobs[nFieldIdx].Array.FindFirstChild("num");
                    if (pbNum != null && pbNum.Type == ProtoBufField.TYPE.BIT32)
                    {
                        List<long> rgShape = new List<long>();
                        rgShape.Add(pbNum.IntValue);

                        ProtoBufField pbChannels = colFieldBlobs[nFieldIdx].Array.FindFirstChild("channels");
                        if (pbChannels != null && pbChannels.Type == ProtoBufField.TYPE.BIT32)
                        {
                            rgShape.Add(pbChannels.IntValue);

                            ProtoBufField pbHeight = colFieldBlobs[nFieldIdx].Array.FindFirstChild("height");
                            if (pbHeight != null && pbHeight.Type == ProtoBufField.TYPE.BIT32)
                            {
                                rgShape.Add(pbHeight.IntValue);

                                ProtoBufField pbWidth = colFieldBlobs[nFieldIdx].Array.FindFirstChild("width");
                                if (pbWidth != null && pbWidth.Type == ProtoBufField.TYPE.BIT32)
                                {
                                    rgShape.Add(pbWidth.IntValue);
                                }
                            }
                        }

                        strShapeW = createShapeString(rgShape.ToArray(), out lCount);

                        if (compareShapes(strShapeB, strShapeW) || (bSizeToFit || bSizeToFitWts))
                        {
                            rgBlobShape = rgShape;
                            break;
                        }

                        // NOTE(review): this branch cannot be entered - whenever (bSizeToFit || bSizeToFitWts)
                        // is true the test above has already taken the break.
                        if ((bSizeToFit || bSizeToFitWts) && compareShapes(strShapeB, strShapeW, 2))
                        {
                            rgBlobShape = rgShape;
                            bResizeNeeded = true;
                            break;
                        }

                        bMisSized = true;
                        break;
                    }
                }
            }

            nFieldIdx++;
        }

        // All weight blobs were consumed without a candidate; the outer loop condition then ends the walk.
        if (nFieldIdx == colFieldBlobs.Count)
            continue;

        //-----------------------------------------
        //  Copy the data, but only for blobs
        //  that are not missized and ones that do
        //  not match the skip type, if specified.
        //-----------------------------------------

        if (!bMisSized && (strSkipBlobType == null || blob.type.ToString() != strSkipBlobType))
        {
            ProtoBufField pbData = colFieldBlobs[nFieldIdx].Array.FindFirstChild("data");
            FieldDescriptor.TYPE type = FieldDescriptor.TYPE.FLOAT;
            long lDataCount = 0;
            if (pbData == null)
            {
                pbData = colFieldBlobs[nFieldIdx].Array.FindFirstChild("double_data");
                type = FieldDescriptor.TYPE.DOUBLE;

                // Only read the length when the field exists, so that a blob with neither
                // 'data' nor 'double_data' reaches the failure report below.
                if (pbData != null)
                    lDataCount = pbData.DoubleValues.Length;
            }
            else
            {
                lDataCount = pbData.FloatValues.Length;
            }

            if (pbData == null || (lDataCount != lCount && !bSizeToFit && !bSizeToFitWts))
                m_log.FAIL("Could not find the weights matching the data size '" + strShapeB + "'!");


            if (bSizeToFitWts)
            {
                if (bResizeNeeded)
                {
                    List<int> rgNewShape = parseShape(strShapeW);

                    // parseShape drops dimensions that are not greater than 1; pad back to the weight blob's rank.
                    while (rgNewShape.Count < rgBlobShape.Count)
                    {
                        rgNewShape.Add(1);
                    }

                    blob.Reshape(rgNewShape);

                    for (int i = 0; i < rgNewShape.Count; i++)
                    {
                        rgBlobShape[i] = rgNewShape[i];
                    }
                }

                T[] rgData = copyData(pbData, type, lDataCount, rgBlobShape);

                if (blob.HalfSize)
                {
                    // Half-size blobs receive the data through a temporary blob and CopyFrom.
                    Blob<T> blobTemp = new Blob<T>(blob.Cuda, blob.Log, false, false);
                    blobTemp.ReshapeLike(blob);
                    blobTemp.mutable_cpu_data = rgData;
                    blob.CopyFrom(blobTemp);
                    blobTemp.Dispose();
                }
                else
                {
                    blob.mutable_cpu_data = rgData;
                }
            }
            else
            {
                if (bSizeToFit && !compareShapes(strShapeB, strShapeW, 4))
                    m_log.FAIL("Could not find the weights matching the first two items of the shape '" + strShapeB + "'!");

                T[] rgData = copyData(pbData, type, lDataCount, rgBlobShape);

                if (blob.HalfSize)
                {
                    // Half-size blobs receive the data through a temporary blob and CopyFrom.
                    Blob<T> blobTemp = new Blob<T>(blob.Cuda, blob.Log, false, false);
                    blobTemp.ReshapeLike(blob);
                    blobTemp.mutable_cpu_data = rgData;
                    blob.CopyFrom(blobTemp);
                    blobTemp.Dispose();
                }
                else
                {
                    blob.mutable_cpu_data = rgData;
                }

                if (bSizeToFit && bResizeNeeded)
                {
                    List<int> rgNewShape = parseShape(strShapeB);
                    Blob<T> blobResized = blob.Resize(rgNewShape);
                    blob.Dispose();
                    colBlobs[nBlobIdx] = blobResized;
                }
            }

            blob.Tag = colFieldBlobs[nFieldIdx].Tag;

            m_log.WriteLine("(" + m_log.Progress.ToString("P") + ") loaded blob '" + colBlobs[nBlobIdx].Name + "' size = " + strShapeB);
        }
        else
        {
            m_log.WriteLine("WARNING: did NOT load blob '" + colBlobs[nBlobIdx].Name + "' size = " + strShapeB);
        }

        m_log.Progress = (double)nBlobIdx / (double)colBlobs.Count;

        nFieldIdx++;
        nBlobIdx++;

        // Stop once every requested target or input name has been consumed.
        if ((targetWtInfo != null && nTargetIdx == targetWtInfo.Count) ||
            (inputWtInfo != null && nInfoIdx == inputWtInfo.Count))
            break;
    }

    return colBlobs;
}
876
/// <summary>
/// Collects the name and shape of every layer blob stored in protobuf encoded weights.
/// </summary>
/// <param name="rgWeights">Protobuf encoded weights.</param>
/// <returns>null when no fields could be read; otherwise the collected weight information.</returns>
private WeightInfo<T> loadInfoFromCaffe(byte[] rgWeights)
{
    WeightInfo<T> wtInfo = new common.WeightInfo<T>();
    FieldDescriptor fdNet = FieldDescriptor.CreateNetworkParamFieldDesc();
    ProtoBufReader pbReader = new ProtoBufReader(rgWeights);
    ProtoBufFieldCollection colFields = pbReader.ReadFields(fdNet, true);
    Stopwatch swProgress = new Stopwatch();

    if (colFields == null || colFields.Count == 0)
        return null;

    swProgress.Start();

    for (int nIdx = 0; nIdx < colFields.Count; nIdx++)
    {
        colFields[nIdx].LoadSubFields(0, 4);

        // Report progress at most about once per second.
        if (swProgress.Elapsed.TotalMilliseconds <= 1000)
            continue;

        m_log.Progress = (double)nIdx / (double)colFields.Count;
        m_log.WriteLine("(" + m_log.Progress.ToString("P") + ") loading fields...");
        swProgress.Restart();
    }

    //---------------------------------------------
    //  Gather the blob fields of every layer,
    //  tagging each with its layer name.
    //---------------------------------------------

    ProtoBufFieldCollection colBlobFields = new common.ProtoBufFieldCollection();
    int nLayer = 0;

    for (int nIdx = 0; nIdx < colFields.Count; nIdx++)
    {
        ProtoBufField pbLayer = colFields[nIdx];

        // Entries without a descriptor can occur in V1 data and are ignored.
        if (pbLayer.FieldDesc == null)
            continue;

        string strKind = pbLayer.FieldDesc.Name;
        bool bLegacy = (strKind == "V1LayerParameter");

        if (!bLegacy && strKind != "LayerParameter")
            continue;

        ProtoBufField pbLayerName = pbLayer.Array.FindFirstChild("name");
        ProtoBufFieldCollection colLayerBlobs = pbLayer.Array.FindAllChildren("blobs");
        string strLayerName = (pbLayerName != null) ? pbLayerName.StringValue : ("layer_" + nLayer.ToString());

        if (colLayerBlobs != null && colLayerBlobs.Count > 0)
        {
            colLayerBlobs.SetTag(strLayerName);

            if (bLegacy)
                colLayerBlobs.SetLegacy(true);

            colBlobFields.AddRange(colLayerBlobs);
        }

        nLayer++;
    }

    //---------------------------------------------
    //  Record the name and shape of each blob.
    //---------------------------------------------

    m_log.Progress = 0;

    string[] rgAxisNames = new string[] { "num", "channels", "height", "width" };

    for (int nFieldIdx = 0; nFieldIdx < colBlobFields.Count; nFieldIdx++)
    {
        ProtoBufField pbBlob = colBlobFields[nFieldIdx];
        string strBlobName;

        ProtoBufField pbName = pbBlob.Array.FindFirstChild("name");
        if (pbName != null && pbName.Type == ProtoBufField.TYPE.STRING)
        {
            strBlobName = pbName.StringValue;
        }
        else
        {
            // Unnamed blobs get a synthetic name built from their type (when present) and index.
            ProtoBufField pbType = pbBlob.Array.FindFirstChild("type");
            if (pbType != null && pbType.Type == ProtoBufField.TYPE.STRING)
                strBlobName = pbType.StringValue + "_" + nFieldIdx.ToString();
            else
                strBlobName = "blob_" + nFieldIdx.ToString();
        }

        List<int> rgShape = new List<int>();

        ProtoBufField pbShape = pbBlob.Array.FindFirstChild("shape");
        if (pbShape != null && pbShape.Type == ProtoBufField.TYPE.ARRAY)
        {
            ProtoBufField pbDim = pbShape.Array.FindFirstChild("dim");
            if (pbDim != null && pbDim.Type == ProtoBufField.TYPE.LONG_ARRAY)
            {
                foreach (long lDim in pbDim.LongValues)
                {
                    rgShape.Add((int)lDim);
                }
            }
        }
        else
        {
            // No 'shape' field: read num, channels, height and width in order,
            // stopping at the first one that is missing or not a 32-bit value.
            foreach (string strAxis in rgAxisNames)
            {
                ProtoBufField pbAxis = pbBlob.Array.FindFirstChild(strAxis);
                if (pbAxis == null || pbAxis.Type != ProtoBufField.TYPE.BIT32)
                    break;

                rgShape.Add(pbAxis.IntValue);
            }
        }

        wtInfo.AddBlob(strBlobName, rgShape, BLOB_TYPE.UNKNOWN);
    }

    return wtInfo;
}
1022
/// <summary>
/// Produces a T[] holding the values of a decoded data field.
/// </summary>
/// <param name="pb">Field whose FloatValues (FLOAT) or DoubleValues (any other type) are read.</param>
/// <param name="type">FLOAT selects the float values; any other value selects the double values.</param>
/// <param name="lCount">Number of elements to copy on the Array.Copy paths.</param>
/// <param name="rgBlobShape">Not read by this method.</param>
/// <returns>
/// A new array of <paramref name="lCount"/> elements, or - for double values loaded into a non-double T -
/// the result of Utility.ConvertVec over all double values.
/// </returns>
private T[] copyData(ProtoBufField pb, FieldDescriptor.TYPE type, long lCount, List<long> rgBlobShape)
{
    // Double values going into a non-double T are converted wholesale, so no destination
    // buffer is allocated for that path.
    if (type != FieldDescriptor.TYPE.FLOAT && typeof(T) != typeof(double))
        return Utility.ConvertVec<T>(pb.DoubleValues);

    T[] rgData = new T[lCount];

    if (type == FieldDescriptor.TYPE.FLOAT)
        Array.Copy(pb.FloatValues, rgData, lCount);
    else
        Array.Copy(pb.DoubleValues, rgData, lCount);

    return rgData;
}
1039
/// <summary>
/// Parses a space separated shape string into its dimensions, keeping only values greater than 1.
/// </summary>
/// <remarks>
/// The last space separated token is never parsed (createShapeString ends its output with a "(rank)" token).
/// </remarks>
/// <param name="strShape">Shape string to parse.</param>
/// <param name="nCount">Maximum number of leading tokens to inspect.</param>
/// <returns>The inspected dimensions that are greater than 1, in their original order.</returns>
private List<int> parseShape(string strShape, int nCount = int.MaxValue)
{
    string[] rgTokens = strShape.Split(' ');
    int nInspect = Math.Min(rgTokens.Length - 1, nCount);

    return rgTokens
        .Take(nInspect)
        .Select(strToken => int.Parse(strToken))
        .Where(nDim => nDim > 1)
        .ToList();
}
1055
/// <summary>
/// Compares two shape strings, first textually and then by their parsed dimensions.
/// </summary>
/// <param name="strA">First shape string.</param>
/// <param name="strB">Second shape string.</param>
/// <param name="nCount">Maximum number of leading tokens of each string to compare.</param>
/// <returns>
/// true when the strings are identical, or when both parse to the same non-empty list of dimensions;
/// otherwise false.
/// </returns>
private bool compareShapes(string strA, string strB, int nCount = int.MaxValue)
{
    if (strA == strB)
        return true;

    List<int> rgDimsA = parseShape(strA, nCount);
    List<int> rgDimsB = parseShape(strB, nCount);

    // The strings differ at this point, so two empty dimension lists count as a mismatch.
    if (rgDimsA.Count == 0 || rgDimsA.Count != rgDimsB.Count)
        return false;

    return rgDimsA.SequenceEqual(rgDimsB);
}
1083
/// <summary>
/// Builds a shape string of the form "d0 d1 ... (rank)" and computes the element count.
/// </summary>
/// <param name="rg">Dimensions; values less than 1 are left out of both the text and the count.</param>
/// <param name="lCount">Receives the product of the dimensions that are at least 1 (1 when there are none).</param>
/// <returns>The listed dimensions, each followed by a space, then the array length in parentheses.</returns>
private string createShapeString(long[] rg, out long lCount)
{
    StringBuilder sbShape = new StringBuilder();
    long lTotal = 1;

    foreach (long lDim in rg)
    {
        if (lDim < 1)
            continue;

        sbShape.Append(lDim.ToString());
        sbShape.Append(" ");
        lTotal *= lDim;
    }

    sbShape.Append("(" + rg.Length.ToString() + ")");

    lCount = lTotal;
    return sbShape.ToString();
}
1103 }
1104
1105 class ProtoBufWriter : IDisposable
1106 {
1107 MemoryStream m_ms = null;
1108 CodedOutputStream m_strm = null;
1109 bool m_bOwnStream = true;
1110 Log m_log;
1111 static int m_nUnknownFieldID = 5000;
1112 Dictionary<string, int> m_rgUnknownFields = new Dictionary<string, int>();
1113
1114 public ProtoBufWriter(Log log)
1115 {
1116 m_log = log;
1117 m_ms = new MemoryStream();
1118 m_strm = new CodedOutputStream(m_ms);
1119 }
1120
1121 public ProtoBufWriter(Log log, CodedOutputStream strm)
1122 {
1123 m_strm = strm;
1124 m_bOwnStream = false;
1125 }
1126
1127 public void Dispose()
1128 {
1129 if (m_strm != null && m_bOwnStream)
1130 {
1131 m_strm.Dispose();
1132 m_strm = null;
1133 }
1134
1135 if (m_ms != null)
1136 {
1137 m_ms.Dispose();
1138 m_ms = null;
1139 }
1140 }
1141
1142 public int Length
1143 {
1144 get { return (int)m_ms.Length; }
1145 }
1146
1147 public byte[] GetBytes(bool bFlush = true)
1148 {
1149 if (m_strm != null && bFlush)
1150 m_strm.Flush();
1151
1152 byte[] rg = m_ms.ToArray();
1153 return rg;
1154 }
1155
1156 public void Flush()
1157 {
1158 m_strm.Flush();
1159 }
1160
1161 public MemoryStream Stream
1162 {
1163 get { return m_ms; }
1164 }
1165
1166 private int getFieldId(FieldDescriptor fd, string strName, out FieldDescriptor.TYPE type)
1167 {
1168 type = FieldDescriptor.TYPE.UNKNOWN;
1169
1170 fd = fd.FindFirstChild(strName);
1171 if (fd != null)
1172 {
1173 type = fd.Type;
1174 return fd.FieldId;
1175 }
1176
1177 if (m_rgUnknownFields.ContainsKey(strName))
1178 return m_rgUnknownFields[strName];
1179
1180 int nId = m_nUnknownFieldID;
1181 m_nUnknownFieldID++;
1182
1183 m_rgUnknownFields.Add(strName, nId);
1184
1185 return nId;
1186 }
1187
1188 public void WriteField(FieldDescriptor fd, string strName, string strVal)
1189 {
1190 FieldDescriptor.TYPE type;
1191 int nFieldId = getFieldId(fd, strName, out type);
1192 uint tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
1193
1194 m_strm.WriteUInt32(tag);
1195 m_strm.WriteString(strVal);
1196 }
1197
/// <summary>
/// Writes a raw byte payload as a length-delimited field.
/// </summary>
/// <param name="fd">Specifies the parent descriptor used to resolve the field id.</param>
/// <param name="strName">Specifies the field name.</param>
/// <param name="rg">Specifies the bytes to write.</param>
public void WriteField(FieldDescriptor fd, string strName, byte[] rg)
{
    FieldDescriptor.TYPE typeNotUsed;
    int nId = getFieldId(fd, strName, out typeNotUsed);
    ByteString bsPayload = ByteString.CopyFrom(rg);

    m_strm.WriteUInt32(WireFormat.MakeTag(nId, WireFormat.WireType.LengthDelimited));
    m_strm.WriteBytes(bsPayload);
}
1207
/// <summary>
/// Writes a single numeric value as a fixed-width field, converted to the type declared
/// by the descriptor.
/// </summary>
/// <remarks>
/// DOUBLE, LONG and ULONG use the 64-bit fixed wire type; FLOAT, INT and UINT use the
/// 32-bit fixed wire type.
/// </remarks>
/// <param name="fd">Specifies the parent descriptor used to resolve the field id and type.</param>
/// <param name="strName">Specifies the field name.</param>
/// <param name="dfVal">Specifies the value to write.</param>
/// <exception cref="Exception">Thrown when the declared type is not one of the numeric types above.</exception>
public void WriteField(FieldDescriptor fd, string strName, double dfVal)
{
    FieldDescriptor.TYPE type;
    int nId = getFieldId(fd, strName, out type);

    if (type == FieldDescriptor.TYPE.DOUBLE)
    {
        m_strm.WriteUInt32(WireFormat.MakeTag(nId, WireFormat.WireType.Fixed64));
        m_strm.WriteDouble(dfVal);
    }
    else if (type == FieldDescriptor.TYPE.FLOAT)
    {
        m_strm.WriteUInt32(WireFormat.MakeTag(nId, WireFormat.WireType.Fixed32));
        m_strm.WriteFloat((float)dfVal);
    }
    else if (type == FieldDescriptor.TYPE.LONG || type == FieldDescriptor.TYPE.ULONG)
    {
        m_strm.WriteUInt32(WireFormat.MakeTag(nId, WireFormat.WireType.Fixed64));
        m_strm.WriteFixed64((ulong)(long)dfVal);
    }
    else if (type == FieldDescriptor.TYPE.INT || type == FieldDescriptor.TYPE.UINT)
    {
        m_strm.WriteUInt32(WireFormat.MakeTag(nId, WireFormat.WireType.Fixed32));
        m_strm.WriteFixed32((uint)(int)dfVal);
    }
    else
    {
        throw new Exception("Unknown type '" + type.ToString() + "'");
    }
}
1246
/// <summary>
/// Writes an array of 64-bit integers as one length-delimited (packed) field.
/// </summary>
/// <param name="fd">Specifies the parent descriptor used to resolve the field id and type.</param>
/// <param name="strName">Specifies the field name.</param>
/// <param name="rgVal">Specifies the values to write.</param>
/// <exception cref="Exception">Thrown when the declared type is neither LONG nor ULONG.</exception>
public void WriteField(FieldDescriptor fd, string strName, long[] rgVal)
{
    FieldDescriptor.TYPE type;
    int nFieldId = getFieldId(fd, strName, out type);

    if (type != FieldDescriptor.TYPE.LONG &&
        type != FieldDescriptor.TYPE.ULONG)
        throw new Exception("Invalid type '" + type.ToString() + "'");

    // Encode the payload with a temporary writer BEFORE touching the output stream, so a
    // failure while encoding does not leave a dangling tag behind, and always release the
    // temporary writer (it was previously leaked).
    byte[] rg;
    ProtoBufWriter pbWriter = new ProtoBufWriter(m_log);

    try
    {
        rg = pbWriter.WriteArray(type, rgVal);
    }
    finally
    {
        pbWriter.Dispose();
    }

    uint tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
    m_strm.WriteUInt32(tag);
    m_strm.WriteBytes(ByteString.CopyFrom(rg));
}
1264
/// <summary>
/// Encodes each value as a varint and returns everything written to this writer so far.
/// </summary>
/// <param name="type">Specifies the element type; ULONG values are written unsigned, anything else signed.</param>
/// <param name="rgVal">Specifies the values to encode.</param>
/// <returns>The flushed buffer contents.</returns>
public byte[] WriteArray(FieldDescriptor.TYPE type, long[] rgVal)
{
    bool bUnsigned = (type == FieldDescriptor.TYPE.ULONG);

    for (int i = 0; i < rgVal.Length; i++)
    {
        // Reinterpret the full 64 bits; the previous (uint) cast silently dropped the
        // upper 32 bits of every unsigned value.
        if (bUnsigned)
            m_strm.WriteUInt64(unchecked((ulong)rgVal[i]));
        else
            m_strm.WriteInt64(rgVal[i]);
    }

    return GetBytes();
}
1277
/// <summary>
/// Writes an array of 32-bit integers as one length-delimited (packed) field.
/// </summary>
/// <param name="fd">Specifies the parent descriptor used to resolve the field id and type.</param>
/// <param name="strName">Specifies the field name.</param>
/// <param name="rgVal">Specifies the values to write.</param>
/// <exception cref="Exception">Thrown when the declared type is neither INT nor UINT.</exception>
public void WriteField(FieldDescriptor fd, string strName, int[] rgVal)
{
    FieldDescriptor.TYPE type;
    int nFieldId = getFieldId(fd, strName, out type);

    if (type != FieldDescriptor.TYPE.INT &&
        type != FieldDescriptor.TYPE.UINT)
        throw new Exception("Invalid type '" + type.ToString() + "'");

    // Encode the payload with a temporary writer BEFORE touching the output stream, so a
    // failure while encoding does not leave a dangling tag behind, and always release the
    // temporary writer (it was previously leaked).
    byte[] rg;
    ProtoBufWriter pbWriter = new ProtoBufWriter(m_log);

    try
    {
        rg = pbWriter.WriteArray(type, rgVal);
    }
    finally
    {
        pbWriter.Dispose();
    }

    uint tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
    m_strm.WriteUInt32(tag);
    m_strm.WriteBytes(ByteString.CopyFrom(rg));
}
1295
/// <summary>
/// Encodes each value as a varint and returns everything written to this writer so far.
/// </summary>
/// <param name="type">Specifies the element type; UINT values are written unsigned, anything else signed.</param>
/// <param name="rgVal">Specifies the values to encode.</param>
/// <returns>The flushed buffer contents.</returns>
public byte[] WriteArray(FieldDescriptor.TYPE type, int[] rgVal)
{
    bool bUnsigned = (type == FieldDescriptor.TYPE.UINT);

    foreach (int nVal in rgVal)
    {
        if (bUnsigned)
            m_strm.WriteUInt32((uint)nVal);
        else
            m_strm.WriteInt64(nVal);
    }

    return GetBytes();
}
1308
/// <summary>
/// Writes an array of doubles as one length-delimited (packed) field.
/// </summary>
/// <param name="fd">Specifies the parent descriptor used to resolve the field id and type.</param>
/// <param name="strName">Specifies the field name.</param>
/// <param name="rgVal">Specifies the values to write.</param>
/// <exception cref="Exception">Thrown when the declared type is neither DOUBLE nor FLOAT.</exception>
public void WriteField(FieldDescriptor fd, string strName, double[] rgVal)
{
    FieldDescriptor.TYPE type;
    int nFieldId = getFieldId(fd, strName, out type);

    if (type != FieldDescriptor.TYPE.DOUBLE &&
        type != FieldDescriptor.TYPE.FLOAT)
        throw new Exception("Invalid type '" + type.ToString() + "'");

    // Encode the payload with a temporary writer BEFORE touching the output stream, so a
    // failure while encoding does not leave a dangling tag behind, and always release the
    // temporary writer (it was previously leaked).
    byte[] rg;
    ProtoBufWriter pbWriter = new ProtoBufWriter(m_log);

    try
    {
        rg = pbWriter.WriteArray(type, rgVal);
    }
    finally
    {
        pbWriter.Dispose();
    }

    uint tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
    m_strm.WriteUInt32(tag);
    m_strm.WriteBytes(ByteString.CopyFrom(rg));
}
1326
/// <summary>
/// Encodes each value at the width of the declared element type and returns everything
/// written to this writer so far.
/// </summary>
/// <remarks>
/// The element type used to be ignored, so a FLOAT field received 8-byte doubles, which a
/// 4-byte float decoder cannot read back. FLOAT fields are now narrowed to 4-byte floats;
/// every other type keeps the original 8-byte double encoding.
/// </remarks>
/// <param name="type">Specifies the element type declared for the field.</param>
/// <param name="rgVal">Specifies the values to encode.</param>
/// <returns>The flushed buffer contents.</returns>
public byte[] WriteArray(FieldDescriptor.TYPE type, double[] rgVal)
{
    bool bAsFloat = (type == FieldDescriptor.TYPE.FLOAT);

    for (int i = 0; i < rgVal.Length; i++)
    {
        if (bAsFloat)
            m_strm.WriteFloat((float)rgVal[i]);
        else
            m_strm.WriteDouble(rgVal[i]);
    }

    return GetBytes();
}
1336
/// <summary>
/// Writes an array of floats as one length-delimited (packed) field.
/// </summary>
/// <param name="fd">Specifies the parent descriptor used to resolve the field id and type.</param>
/// <param name="strName">Specifies the field name.</param>
/// <param name="rgVal">Specifies the values to write.</param>
/// <exception cref="Exception">Thrown when the declared type is neither DOUBLE nor FLOAT.</exception>
public void WriteField(FieldDescriptor fd, string strName, float[] rgVal)
{
    FieldDescriptor.TYPE type;
    int nFieldId = getFieldId(fd, strName, out type);

    if (type != FieldDescriptor.TYPE.DOUBLE &&
        type != FieldDescriptor.TYPE.FLOAT)
        throw new Exception("Invalid type '" + type.ToString() + "'");

    // Encode the payload with a temporary writer BEFORE touching the output stream, so a
    // failure while encoding does not leave a dangling tag behind, and always release the
    // temporary writer (it was previously leaked).
    byte[] rg;
    ProtoBufWriter pbWriter = new ProtoBufWriter(m_log);

    try
    {
        rg = pbWriter.WriteArray(type, rgVal);
    }
    finally
    {
        pbWriter.Dispose();
    }

    uint tag = WireFormat.MakeTag(nFieldId, WireFormat.WireType.LengthDelimited);
    m_strm.WriteUInt32(tag);
    m_strm.WriteBytes(ByteString.CopyFrom(rg));
}
1354
/// <summary>
/// Encodes each value at the width of the declared element type and returns everything
/// written to this writer so far.
/// </summary>
/// <remarks>
/// The element type used to be ignored, so a DOUBLE field received 4-byte floats, which an
/// 8-byte double decoder cannot read back. DOUBLE fields are now widened to 8-byte doubles;
/// every other type keeps the original 4-byte float encoding.
/// </remarks>
/// <param name="type">Specifies the element type declared for the field.</param>
/// <param name="rgVal">Specifies the values to encode.</param>
/// <returns>The flushed buffer contents.</returns>
public byte[] WriteArray(FieldDescriptor.TYPE type, float[] rgVal)
{
    bool bAsDouble = (type == FieldDescriptor.TYPE.DOUBLE);

    for (int i = 0; i < rgVal.Length; i++)
    {
        if (bAsDouble)
            m_strm.WriteDouble(rgVal[i]);
        else
            m_strm.WriteFloat(rgVal[i]);
    }

    return GetBytes();
}
1364 }
1365
/// <summary>
/// The ProtoBufReader pulls fields, one tag at a time, out of a coded protobuf input stream.
/// </summary>
class ProtoBufReader : IDisposable
{
    CodedInputStream m_strm = null;
    bool m_bOwnStream = true;

    /// <summary>
    /// Creates a reader over a byte array; the reader owns the stream it creates.
    /// </summary>
    /// <param name="rg">Specifies the encoded bytes.</param>
    public ProtoBufReader(byte[] rg)
    {
        m_strm = new CodedInputStream(rg);
    }

    /// <summary>
    /// Creates a reader over an existing stream that stays owned by the caller.
    /// </summary>
    /// <param name="strm">Specifies the stream to read from.</param>
    public ProtoBufReader(CodedInputStream strm)
    {
        m_bOwnStream = false;
        m_strm = strm;
    }

    /// <summary>
    /// Disposes the stream, but only when this reader created it.
    /// </summary>
    public void Dispose()
    {
        if (!m_bOwnStream || m_strm == null)
            return;

        m_strm.Dispose();
        m_strm = null;
    }

    /// <summary>
    /// Reads fields until ReadField reports that no further field is available.
    /// </summary>
    /// <param name="fd">Specifies the parent descriptor used to describe each field (may be null).</param>
    /// <param name="bFirstRead">Passed through to every ReadField call.</param>
    /// <returns>The fields read; byte/string fields without a payload are left out.</returns>
    public ProtoBufFieldCollection ReadFields(FieldDescriptor fd, bool bFirstRead)
    {
        ProtoBufFieldCollection colFields = new ProtoBufFieldCollection();

        for (ProtoBufField fld = ReadField(fd, bFirstRead); fld != null; fld = ReadField(fd, bFirstRead))
        {
            bool bRawPayload = (fld.Type == ProtoBufField.TYPE.BYTES || fld.Type == ProtoBufField.TYPE.STRING);

            if (!bRawPayload || fld.Length > 0)
                colFields.Add(fld);
        }

        return colFields;
    }

    /// <summary>
    /// Reads the next tag and loads the field that follows it.
    /// </summary>
    /// <param name="fd">Specifies the parent descriptor; its child with the same field id describes the field.</param>
    /// <param name="bFirstRead">When true, only length-delimited fields are accepted.</param>
    /// <returns>
    /// The loaded field, or null when the stream is at its end, the field number is not
    /// positive, the wire type is rejected, or the field fails to load.
    /// </returns>
    public ProtoBufField ReadField(FieldDescriptor fd, bool bFirstRead)
    {
        if (m_strm.IsAtEnd)
            return null;

        uint nTag = m_strm.ReadUInt32();
        int nFieldId = WireFormat.GetTagFieldNumber(nTag);

        if (nFieldId <= 0)
            return null;

        WireFormat.WireType wireType = WireFormat.GetTagWireType(nTag);

        if (bFirstRead && wireType != WireFormat.WireType.LengthDelimited)
            return null;

        FieldDescriptor fdChild = (fd == null) ? null : fd.FindFirstChild(nFieldId);
        ProtoBufField fld = new ProtoBufField(m_strm, nFieldId, fdChild);

        return fld.Load(wireType) ? fld : null;
    }
}
1432
/// <summary>
/// The ProtoBufFieldCollection is an ordered, enumerable list of ProtoBufField items.
/// </summary>
class ProtoBufFieldCollection : IEnumerable<ProtoBufField>
{
    List<ProtoBufField> m_rgFields = new List<ProtoBufField>();

    /// <summary>
    /// Creates an empty collection.
    /// </summary>
    public ProtoBufFieldCollection()
    {
    }

    /// <summary>
    /// Returns the number of fields held.
    /// </summary>
    public int Count
    {
        get { return m_rgFields.Count; }
    }

    /// <summary>
    /// Returns the field at a given position.
    /// </summary>
    /// <param name="nIdx">Specifies the zero-based position.</param>
    public ProtoBufField this[int nIdx]
    {
        get { return m_rgFields[nIdx]; }
    }

    /// <summary>
    /// Assigns the same tag to every field in the collection.
    /// </summary>
    /// <param name="str">Specifies the tag.</param>
    public void SetTag(string str)
    {
        for (int i = 0; i < m_rgFields.Count; i++)
        {
            m_rgFields[i].Tag = str;
        }
    }

    /// <summary>
    /// Assigns the same legacy flag to every field in the collection.
    /// </summary>
    /// <param name="bLegacy">Specifies the flag value.</param>
    public void SetLegacy(bool bLegacy)
    {
        for (int i = 0; i < m_rgFields.Count; i++)
        {
            m_rgFields[i].Legacy = bLegacy;
        }
    }

    /// <summary>
    /// Appends one field.
    /// </summary>
    /// <param name="p">Specifies the field to append.</param>
    public void Add(ProtoBufField p)
    {
        m_rgFields.Add(p);
    }

    /// <summary>
    /// Appends every field of another collection.
    /// </summary>
    /// <param name="col">Specifies the collection to copy from.</param>
    public void AddRange(ProtoBufFieldCollection col)
    {
        m_rgFields.AddRange(col.m_rgFields);
    }

    /// <summary>
    /// Collects every field whose descriptor carries the given name.
    /// </summary>
    /// <param name="strName">Specifies the descriptor name to match.</param>
    /// <returns>A new collection with the matches (possibly empty).</returns>
    public ProtoBufFieldCollection FindAllChildren(string strName)
    {
        ProtoBufFieldCollection colMatches = new ProtoBufFieldCollection();

        for (int i = 0; i < m_rgFields.Count; i++)
        {
            ProtoBufField fld = m_rgFields[i];
            FieldDescriptor fd = fld.FieldDesc;

            // Fields without a descriptor have no name to compare.
            if (fd == null)
                continue;

            if (fd.Name == strName)
                colMatches.Add(fld);
        }

        return colMatches;
    }

    /// <summary>
    /// Finds the first field whose descriptor carries the given name.
    /// </summary>
    /// <param name="strName">Specifies the descriptor name to match.</param>
    /// <returns>The first match, or null when there is none.</returns>
    public ProtoBufField FindFirstChild(string strName)
    {
        return m_rgFields.Find(fld => fld.FieldDesc != null && fld.FieldDesc.Name == strName);
    }

    /// <summary>
    /// Returns the typed enumerator over the fields.
    /// </summary>
    public IEnumerator<ProtoBufField> GetEnumerator()
    {
        return m_rgFields.GetEnumerator();
    }

    /// <summary>
    /// Returns the untyped enumerator over the fields.
    /// </summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
1511
1512
/// <summary>
/// The ProtoBufField holds one field decoded from a protobuf stream together with the
/// interpretations of its payload that the wire type and the (optional) field descriptor allow.
/// </summary>
class ProtoBufField
{
    FieldDescriptor m_fd;
    CodedInputStream m_strm;
    byte[] m_rgBytes;
    string m_strVal;
    int m_nVal = 0;
    long m_lVal = 0;
    float m_fVal = 0;
    double m_dfVal = 0;
    int[] m_rgnVal = null;
    long[] m_rglVal = null;
    float[] m_rgfVal = null;
    double[] m_rgdfVal = null;
    string m_strTag = null;
    bool m_bLegacy = false;

    TYPE m_type = TYPE.BYTES;
    ProtoBufFieldCollection m_col = new ProtoBufFieldCollection();
    int m_nField;
    WireFormat.WireType m_wireType;

    /// <summary>
    /// Defines how the payload of the field has been interpreted.
    /// </summary>
    public enum TYPE
    {
        BYTES,
        STRING,
        BIT32,
        BIT64,
        ARRAY,
        FLOAT_ARRAY,
        DOUBLE_ARRAY,
        INT_ARRAY,
        LONG_ARRAY
    }

    /// <summary>
    /// Creates a field that will load its payload from the given stream.
    /// </summary>
    /// <param name="strm">Specifies the stream positioned just after the field tag.</param>
    /// <param name="nField">Specifies the field id taken from the tag.</param>
    /// <param name="fd">Specifies the descriptor of the field, or null when it is not described.</param>
    public ProtoBufField(CodedInputStream strm, int nField, FieldDescriptor fd)
    {
        m_nField = nField;
        m_fd = fd;
        m_strm = strm;
    }

    /// <summary>
    /// Reads the payload that follows the tag and decodes it according to the wire type.
    /// </summary>
    /// <param name="wireType">Specifies the wire type taken from the tag.</param>
    /// <returns>True when the wire type is supported and the payload was read, otherwise false.</returns>
    public bool Load(WireFormat.WireType wireType)
    {
        m_wireType = wireType;

        switch (wireType)
        {
            case WireFormat.WireType.Varint:
                // Read all 64 bits so LongValue is not truncated; IntValue keeps the low
                // 32 bits, which is what the previous 32-bit read produced.
                m_lVal = m_strm.ReadInt64();
                m_nVal = unchecked((int)m_lVal);
                m_type = TYPE.BIT32;
                break;

            case WireFormat.WireType.LengthDelimited:
                ByteString bs = m_strm.ReadBytes();
                if (bs.Length > 0)
                {
                    m_rgBytes = bs.ToByteArray();

                    // Undescribed fields and declared strings are probed for printable text.
                    if (m_fd == null || m_fd.Type == FieldDescriptor.TYPE.STRING)
                        m_strVal = getString(m_rgBytes, out m_type);

                    // Described primitive fields carry a packed array of that primitive.
                    if (m_type == TYPE.BYTES && m_fd != null && m_fd.Type != FieldDescriptor.TYPE.FIELDDESC)
                    {
                        switch (m_fd.Type)
                        {
                            case FieldDescriptor.TYPE.INT:
                            case FieldDescriptor.TYPE.UINT:
                                m_rgnVal = readIntArray(m_rgBytes, m_fd.Type);
                                m_type = TYPE.INT_ARRAY;
                                break;

                            case FieldDescriptor.TYPE.LONG:
                            case FieldDescriptor.TYPE.ULONG:
                                m_rglVal = readLongArray(m_rgBytes, m_fd.Type);
                                m_type = TYPE.LONG_ARRAY;
                                break;

                            case FieldDescriptor.TYPE.FLOAT:
                                m_rgfVal = readFloatArray(m_rgBytes);
                                m_type = TYPE.FLOAT_ARRAY;
                                break;

                            case FieldDescriptor.TYPE.DOUBLE:
                                m_rgdfVal = readDoubleArray(m_rgBytes);
                                m_type = TYPE.DOUBLE_ARRAY;
                                break;
                        }
                    }
                }
                break;

            case WireFormat.WireType.Fixed32:
                uint nRaw32 = m_strm.ReadFixed32();
                m_fVal = BitConverter.ToSingle(BitConverter.GetBytes(nRaw32), 0);
                // ProtoBufWriter.WriteField(double) stores INT/UINT fields as raw 32-bit
                // integers, so for integer fields the bits ARE the value; decoding them as
                // a float and truncating (the previous behavior) lost the value.
                m_nVal = isIntegerField() ? unchecked((int)nRaw32) : (int)m_fVal;
                m_type = TYPE.BIT32;
                break;

            case WireFormat.WireType.Fixed64:
                ulong nRaw64 = m_strm.ReadFixed64();
                m_dfVal = BitConverter.Int64BitsToDouble(unchecked((long)nRaw64));
                // Same reasoning as Fixed32: LONG/ULONG fields hold raw 64-bit integers.
                m_lVal = isIntegerField() ? unchecked((long)nRaw64) : (long)m_dfVal;
                m_type = TYPE.BIT64;
                break;

            default:
                return false;
        }

        return true;
    }

    /// <summary>
    /// Returns true when the descriptor declares the field as one of the integer types.
    /// </summary>
    private bool isIntegerField()
    {
        if (m_fd == null)
            return false;

        switch (m_fd.Type)
        {
            case FieldDescriptor.TYPE.INT:
            case FieldDescriptor.TYPE.UINT:
            case FieldDescriptor.TYPE.LONG:
            case FieldDescriptor.TYPE.ULONG:
                return true;

            default:
                return false;
        }
    }

    /// <summary>
    /// Decodes a packed array of 32-bit varints.
    /// </summary>
    /// <param name="rgBytes">Specifies the packed payload.</param>
    /// <param name="type">Specifies INT for signed decoding; anything else decodes unsigned.</param>
    private int[] readIntArray(byte[] rgBytes, FieldDescriptor.TYPE type)
    {
        CodedInputStream strm = new CodedInputStream(rgBytes);
        List<int> rg = new List<int>();

        while (!strm.IsAtEnd)
        {
            int lVal = (type == FieldDescriptor.TYPE.INT) ? (int)strm.ReadInt32() : (int)strm.ReadUInt32();
            rg.Add(lVal);
        }

        return rg.ToArray();
    }

    /// <summary>
    /// Decodes a packed array of 64-bit varints.
    /// </summary>
    /// <param name="rgBytes">Specifies the packed payload.</param>
    /// <param name="type">Specifies LONG for signed decoding; anything else decodes unsigned.</param>
    private long[] readLongArray(byte[] rgBytes, FieldDescriptor.TYPE type)
    {
        CodedInputStream strm = new CodedInputStream(rgBytes);
        List<long> rg = new List<long>();

        while (!strm.IsAtEnd)
        {
            long lVal = (type == FieldDescriptor.TYPE.LONG) ? (long)strm.ReadInt64() : (long)strm.ReadUInt64();
            rg.Add(lVal);
        }

        return rg.ToArray();
    }

    /// <summary>
    /// Decodes a packed array of 4-byte floats.
    /// </summary>
    /// <param name="rgBytes">Specifies the packed payload; its length must be a multiple of 4.</param>
    /// <exception cref="Exception">Thrown when the payload length is not a multiple of 4.</exception>
    private float[] readFloatArray(byte[] rgBytes)
    {
        int nCount = rgBytes.Length / sizeof(float);
        int nErr = rgBytes.Length % sizeof(float);

        if (nErr != 0)
            throw new Exception("Invalid " + m_fd.Type.ToString() + " data - not aligned.");

        CodedInputStream strm = new CodedInputStream(rgBytes);
        float[] rg = new float[nCount];

        for (int i = 0; i < nCount; i++)
        {
            rg[i] = strm.ReadFloat();
        }

        return rg;
    }

    /// <summary>
    /// Decodes a packed array of 8-byte doubles.
    /// </summary>
    /// <param name="rgBytes">Specifies the packed payload; its length must be a multiple of 8.</param>
    /// <exception cref="Exception">Thrown when the payload length is not a multiple of 8.</exception>
    private double[] readDoubleArray(byte[] rgBytes)
    {
        int nCount = rgBytes.Length / sizeof(double);
        int nErr = rgBytes.Length % sizeof(double);

        if (nErr != 0)
            throw new Exception("Invalid " + m_fd.Type.ToString() + " data - not aligned.");

        CodedInputStream strm = new CodedInputStream(rgBytes);
        double[] rg = new double[nCount];

        for (int i = 0; i < nCount; i++)
        {
            rg[i] = strm.ReadDouble();
        }

        return rg;
    }

    /// <summary>
    /// Parses the raw payload into sub-fields and recursively does the same for each of them.
    /// </summary>
    /// <param name="nDepth">Specifies the depth of this field; it is incremented before descending.</param>
    /// <param name="nMaxDepth">Specifies the depth at which descending stops.</param>
    /// <param name="rgIgnore">
    /// Optionally specifies (index, text) pairs; when the sub-field at an index is a string
    /// equal to the text, the sub-fields of this field are not descended into. The list is
    /// not passed on to the recursive calls.
    /// </param>
    public void LoadSubFields(int nDepth = 0, int nMaxDepth = int.MaxValue, List<KeyValuePair<int, string>> rgIgnore = null)
    {
        ProtoBufFieldCollection col = null;

        if (m_type == TYPE.BYTES)
        {
            // An empty length-delimited field never gets a payload; there is nothing to parse.
            if (m_rgBytes == null)
                return;

            ProtoBufReader reader = new ProtoBufReader(m_rgBytes);
            col = reader.ReadFields(m_fd, false);
            m_col = col;
            m_type = TYPE.ARRAY;
        }
        else if (m_type == TYPE.ARRAY)
        {
            col = m_col;
        }

        if (col == null || col.Count == 0)
            return;

        nDepth += 1;

        if (nDepth >= nMaxDepth)
            return;

        if (rgIgnore != null)
        {
            foreach (KeyValuePair<int, string> kv in rgIgnore)
            {
                // The index must lie inside the collection; the previous '<= Count' test
                // let an index equal to Count through and then indexed out of range.
                if (kv.Key >= 0 && kv.Key < m_col.Count &&
                    m_col[kv.Key].Type == TYPE.STRING &&
                    m_col[kv.Key].StringValue == kv.Value)
                    return;
            }
        }

        foreach (ProtoBufField field in m_col)
        {
            field.LoadSubFields(nDepth, nMaxDepth);
        }
    }

    /// <summary>
    /// Converts bytes to text, one character per byte, provided none is a control character.
    /// </summary>
    /// <param name="rg">Specifies the bytes to convert.</param>
    /// <param name="type">Returns STRING when the conversion succeeds, BYTES when a control character is found.</param>
    /// <returns>The text, or null when a control character is found or the input is empty.</returns>
    private string getString(byte[] rg, out TYPE type)
    {
        type = TYPE.BYTES;

        // Scan first so binary payloads are rejected without allocating anything.
        for (int i = 0; i < rg.Length; i++)
        {
            if (char.IsControl((char)rg[i]))
                return null;
        }

        type = TYPE.STRING;

        // Empty input has always produced a null string typed as STRING.
        if (rg.Length == 0)
            return null;

        // Build the text in one pass instead of concatenating a new string per character.
        char[] rgch = new char[rg.Length];

        for (int i = 0; i < rg.Length; i++)
        {
            rgch[i] = (char)rg[i];
        }

        return new string(rgch);
    }

    /// <summary>
    /// Converts text to bytes by narrowing each character to a byte.
    /// </summary>
    /// <param name="str">Specifies the text to convert.</param>
    /// <param name="type">Returns BYTES when the text holds a control character, otherwise STRING.</param>
    private byte[] getBytes(string str, out TYPE type)
    {
        byte[] rg = new byte[str.Length];

        type = TYPE.STRING;

        for (int i = 0; i < str.Length; i++)
        {
            rg[i] = (byte)str[i];

            if (char.IsControl(str[i]))
                type = TYPE.BYTES;
        }

        return rg;
    }

    /// <summary>
    /// Get/set the legacy flag attached to the field.
    /// </summary>
    public bool Legacy
    {
        get { return m_bLegacy; }
        set { m_bLegacy = value; }
    }

    /// <summary>
    /// Get/set the free-form tag attached to the field.
    /// </summary>
    public string Tag
    {
        get { return m_strTag; }
        set { m_strTag = value; }
    }

    /// <summary>
    /// Returns the raw payload of a length-delimited field (null when none was read).
    /// </summary>
    public byte[] Bytes
    {
        get { return m_rgBytes; }
    }

    /// <summary>
    /// Returns the size of the raw payload, or 0 when there is none.
    /// </summary>
    public int Length
    {
        get { return (m_rgBytes == null) ? 0 : m_rgBytes.Length; }
    }

    /// <summary>
    /// Returns how the payload has been interpreted.
    /// </summary>
    public TYPE Type
    {
        get { return m_type; }
    }

    /// <summary>
    /// Returns the text of a STRING field.
    /// </summary>
    public string StringValue
    {
        get { return m_strVal; }
    }

    /// <summary>
    /// Returns the 64-bit integer interpretation of a scalar field.
    /// </summary>
    public long LongValue
    {
        get { return m_lVal; }
    }

    /// <summary>
    /// Returns the values of a LONG_ARRAY field.
    /// </summary>
    public long[] LongValues
    {
        get { return m_rglVal; }
    }

    /// <summary>
    /// Returns the 32-bit integer interpretation of a scalar field.
    /// </summary>
    public int IntValue
    {
        get { return m_nVal; }
    }

    /// <summary>
    /// Returns the values of an INT_ARRAY field.
    /// </summary>
    public int[] IntValues
    {
        get { return m_rgnVal; }
    }

    /// <summary>
    /// Returns the float interpretation of a 32-bit fixed field.
    /// </summary>
    public float FloatValue
    {
        get { return m_fVal; }
    }

    /// <summary>
    /// Returns the values of a FLOAT_ARRAY field.
    /// </summary>
    public float[] FloatValues
    {
        get { return m_rgfVal; }
    }

    /// <summary>
    /// Returns the double interpretation of a 64-bit fixed field.
    /// </summary>
    public double DoubleValue
    {
        get { return m_dfVal; }
    }

    /// <summary>
    /// Returns the values of a DOUBLE_ARRAY field.
    /// </summary>
    public double[] DoubleValues
    {
        get { return m_rgdfVal; }
    }

    /// <summary>
    /// Returns the sub-fields produced by LoadSubFields (empty until then).
    /// </summary>
    public ProtoBufFieldCollection Array
    {
        get { return m_col; }
    }

    /// <summary>
    /// Returns the field id taken from the tag.
    /// </summary>
    public int FieldId
    {
        get { return m_nField; }
    }

    /// <summary>
    /// Returns the descriptor of the field, or null when it is not described.
    /// </summary>
    public FieldDescriptor FieldDesc
    {
        get { return m_fd; }
    }

    /// <summary>
    /// Returns a one-line description of the field for debugging.
    /// </summary>
    public override string ToString()
    {
        string strName = (m_fd == null) ? "NO FLDESC!" : m_fd.Name;
        string str = strName + "(" + m_nField.ToString() + ")[" + m_wireType.ToString() + "] " + m_type.ToString() + ": ";

        if (m_type == TYPE.STRING)
            return str + m_strVal;

        if (m_type == TYPE.BIT32)
            return str + m_nVal.ToString() + " (float = " + m_fVal.ToString() + ")";

        if (m_type == TYPE.BIT64)
            return str + m_lVal.ToString() + " (double = " + m_dfVal.ToString() + ")";

        if (m_type == TYPE.ARRAY)
            return str + " Count = " + m_col.Count.ToString();

        return str + " bytes = " + ((m_rgBytes == null) ? "0" : m_rgBytes.Length.ToString());
    }
}
1878
1879#pragma warning disable 1591
1880
1881 public class FieldDescriptor
1882 {
// Child descriptors searched by FindFirstChild(); replaced by the list given to the constructor, if any.
1883 List<FieldDescriptor> m_rgChildren = new List<FieldDescriptor>();
// Numeric field id, exposed through FieldId.
1884 int m_nFieldID = 0;
// Field name, exposed through Name.
1885 string m_strName = "";
// Declared value type, exposed through Type.
1886 TYPE m_type = TYPE.UNKNOWN;
1887
// Value types a descriptor can declare for its field.
1888 public enum TYPE
1889 {
// Initial value of a descriptor's type before one is assigned.
1890 UNKNOWN,
1891 STRING,
1892 BOOL,
1893 INT,
1894 UINT,
1895 LONG,
1896 ULONG,
1897 FLOAT,
1898 DOUBLE,
// Used by the load* tables for entries that describe a nested message (optionally with child descriptors).
1899 FIELDDESC
1900 }
1901
/// <summary>
/// Creates a descriptor for one field.
/// </summary>
/// <param name="nField">Specifies the numeric field id.</param>
/// <param name="strName">Specifies the field name.</param>
/// <param name="type">Specifies the declared value type.</param>
/// <param name="rgChildren">Optionally specifies the child descriptors; when null the descriptor keeps its empty child list.</param>
public FieldDescriptor(int nField, string strName, TYPE type, List<FieldDescriptor> rgChildren = null)
{
    m_type = type;
    m_strName = strName;
    m_nFieldID = nField;

    if (rgChildren == null)
        return;

    m_rgChildren = rgChildren;
}
1911
/// <summary>
/// Finds the first child descriptor with the given field id.
/// </summary>
/// <param name="nFieldId">Specifies the field id to look for.</param>
/// <returns>The first matching child, or null when there is none.</returns>
public FieldDescriptor FindFirstChild(int nFieldId)
{
    return m_rgChildren.Find(fdChild => fdChild.FieldId == nFieldId);
}
1922
/// <summary>
/// Finds the first child descriptor with the given name.
/// </summary>
/// <param name="strName">Specifies the name to look for.</param>
/// <returns>The first matching child, or null when there is none.</returns>
public FieldDescriptor FindFirstChild(string strName)
{
    return m_rgChildren.Find(fdChild => fdChild.Name == strName);
}
1933
/// <summary>
/// Returns the numeric field id.
/// </summary>
public int FieldId
{
    get
    {
        return m_nFieldID;
    }
}
1938
/// <summary>
/// Returns the field name.
/// </summary>
public string Name
{
    get
    {
        return m_strName;
    }
}
1943
/// <summary>
/// Returns the declared value type.
/// </summary>
public TYPE Type
{
    get
    {
        return m_type;
    }
}
1948
/// <summary>
/// Returns the list of child descriptors (the list itself, not a copy).
/// </summary>
public List<FieldDescriptor> Children
{
    get
    {
        return m_rgChildren;
    }
}
1953
/// <summary>
/// Returns the descriptor formatted as "name (id) - type".
/// </summary>
public override string ToString()
{
    string strId = m_nFieldID.ToString();
    string strType = m_type.ToString();

    return string.Concat(m_strName, " (", strId, ") - ", strType);
}
1958
/// <summary>
/// Builds the root descriptor, named "SolverState", over the solver state fields.
/// </summary>
/// <returns>A new descriptor with field id 0.</returns>
public static FieldDescriptor CreateSolverStateFieldDesc()
{
    List<FieldDescriptor> rgChildren = loadSolverState();
    return new FieldDescriptor(0, "SolverState", TYPE.FIELDDESC, rgChildren);
}
1963
/// <summary>
/// Builds the root descriptor, named "NetParameter", over the network parameter fields.
/// </summary>
/// <returns>A new descriptor with field id 0.</returns>
public static FieldDescriptor CreateNetworkParamFieldDesc()
{
    List<FieldDescriptor> rgChildren = loadNetParameter();
    return new FieldDescriptor(0, "NetParameter", TYPE.FIELDDESC, rgChildren);
}
1968
/// <summary>
/// Builds a descriptor, named "BlobProto", over the blob fields.
/// </summary>
/// <param name="nFieldId">Specifies the field id given to the descriptor.</param>
/// <returns>A new descriptor.</returns>
public static FieldDescriptor CreateBlobProtoDesc(int nFieldId)
{
    List<FieldDescriptor> rgChildren = loadBlobProto();
    return new FieldDescriptor(nFieldId, "BlobProto", TYPE.FIELDDESC, rgChildren);
}
1973
/// <summary>
/// Lists the described fields of the solver state: iter, history and current_step.
/// </summary>
private static List<FieldDescriptor> loadSolverState()
{
    return new List<FieldDescriptor>
    {
        new FieldDescriptor(1, "iter", TYPE.INT),
        new FieldDescriptor(3, "history", TYPE.FIELDDESC, loadBlobProto()),
        new FieldDescriptor(4, "current_step", TYPE.INT)
    };
}
1982
/// <summary>
/// Lists the described fields of the network parameter: its name and both layer formats.
/// </summary>
private static List<FieldDescriptor> loadNetParameter()
{
    return new List<FieldDescriptor>
    {
        new FieldDescriptor(1, "name", TYPE.STRING),
        new FieldDescriptor(100, "LayerParameter", TYPE.FIELDDESC, loadLayerParameter()),
        new FieldDescriptor(2, "V1LayerParameter", TYPE.FIELDDESC, loadV1LayerParameter())
    };
}
1991
/// <summary>
/// Lists the described fields of a LayerParameter message.
/// </summary>
/// <remarks>
/// Lookups by id return the FIRST entry with that id, so ids must be unique for every
/// entry to be reachable. PARAMETER_param was registered with id 121, the same id as
/// POOLING_param, which made field 121 resolve to the wrong name; it now uses 145, the
/// id Caffe's caffe.proto assigns to parameter_param.
/// NOTE(review): SCALAR_param still shares id 142 with SCALE_param, so a lookup of field
/// 142 always yields SCALE_param - confirm whether SCALAR should have its own id.
/// </remarks>
private static List<FieldDescriptor> loadLayerParameter()
{
    List<FieldDescriptor> rgF = new List<common.FieldDescriptor>();
    rgF.Add(new FieldDescriptor(1, "name", FieldDescriptor.TYPE.STRING));
    rgF.Add(new FieldDescriptor(2, "type", FieldDescriptor.TYPE.STRING));
    rgF.Add(new FieldDescriptor(3, "bottom", FieldDescriptor.TYPE.STRING));
    rgF.Add(new FieldDescriptor(4, "top", FieldDescriptor.TYPE.STRING));
    rgF.Add(new FieldDescriptor(10, "phase", FieldDescriptor.TYPE.INT));
    rgF.Add(new FieldDescriptor(5, "loss_weight", FieldDescriptor.TYPE.FLOAT));
    rgF.Add(new FieldDescriptor(6, "param", FieldDescriptor.TYPE.FIELDDESC, loadParamSpec()));
    rgF.Add(new FieldDescriptor(7, "blobs", FieldDescriptor.TYPE.FIELDDESC, loadBlobProto()));
    rgF.Add(new FieldDescriptor(11, "prop_down", FieldDescriptor.TYPE.BOOL));
    rgF.Add(new FieldDescriptor(8, "include", FieldDescriptor.TYPE.FIELDDESC, loadNetStateRule()));
    rgF.Add(new FieldDescriptor(9, "exclude", FieldDescriptor.TYPE.FIELDDESC, loadNetStateRule()));
    rgF.Add(new FieldDescriptor(100, LayerParameter.LayerType.TRANSFORM.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(101, LayerParameter.LayerType.LOSS.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));

    rgF.Add(new FieldDescriptor(102, LayerParameter.LayerType.ACCURACY.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(103, LayerParameter.LayerType.ARGMAX.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(139, LayerParameter.LayerType.BATCHNORM.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(141, LayerParameter.LayerType.BIAS.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(104, LayerParameter.LayerType.CONCAT.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(105, LayerParameter.LayerType.CONTRASTIVE_LOSS.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(106, LayerParameter.LayerType.CONVOLUTION.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC, loadConvolutionParam()));
    rgF.Add(new FieldDescriptor(144, LayerParameter.LayerType.CROP.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(107, LayerParameter.LayerType.DATA.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(108, LayerParameter.LayerType.DROPOUT.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(109, LayerParameter.LayerType.DUMMYDATA.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(110, LayerParameter.LayerType.ELTWISE.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(140, LayerParameter.LayerType.ELU.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(137, LayerParameter.LayerType.EMBED.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(111, LayerParameter.LayerType.EXP.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(135, LayerParameter.LayerType.FLATTEN.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(112, "hdf5_input_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(113, "hdf5_output_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(114, LayerParameter.LayerType.HINGE_LOSS.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(115, "image_data_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(116, LayerParameter.LayerType.INFOGAIN_LOSS.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(117, LayerParameter.LayerType.INNERPRODUCT.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(143, LayerParameter.LayerType.INPUT.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(134, LayerParameter.LayerType.LOG.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(118, LayerParameter.LayerType.LRN.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(119, LayerParameter.LayerType.MEMORYDATA.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(120, LayerParameter.LayerType.MVN.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    // parameter_param is field 145 in caffe.proto; 121 belongs to pooling_param below.
    rgF.Add(new FieldDescriptor(145, LayerParameter.LayerType.PARAMETER.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(121, LayerParameter.LayerType.POOLING.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(122, LayerParameter.LayerType.POWER.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(131, LayerParameter.LayerType.PRELU.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(130, "python_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(146, LayerParameter.LayerType.RECURRENT.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(136, LayerParameter.LayerType.REDUCTION.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(123, LayerParameter.LayerType.RELU.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(133, LayerParameter.LayerType.RESHAPE.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(142, LayerParameter.LayerType.SCALE.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(142, LayerParameter.LayerType.SCALAR.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(124, LayerParameter.LayerType.SIGMOID.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(125, LayerParameter.LayerType.SOFTMAX.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(132, LayerParameter.LayerType.SPP.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(126, LayerParameter.LayerType.SLICE.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(127, LayerParameter.LayerType.TANH.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(128, LayerParameter.LayerType.THRESHOLD.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(138, LayerParameter.LayerType.TILE.ToString() + "_param", FieldDescriptor.TYPE.FIELDDESC));
    rgF.Add(new FieldDescriptor(129, "window_data_param", FieldDescriptor.TYPE.FIELDDESC));

    return rgF;
}
2058
/// <summary>
/// Lists the described fields of a V1LayerParameter message.
/// </summary>
/// <remarks>
/// NOTE(review): 'param' (1001) is described here as a nested ParamSpec message; verify
/// this against the V1LayerParameter definition in caffe.proto.
/// </remarks>
private static List<FieldDescriptor> loadV1LayerParameter()
{
    const string strSuffix = "_param";

    return new List<FieldDescriptor>
    {
        new FieldDescriptor(2, "bottom", TYPE.STRING),
        new FieldDescriptor(3, "top", TYPE.STRING),
        new FieldDescriptor(4, "name", TYPE.STRING),
        new FieldDescriptor(32, "include", TYPE.FIELDDESC, loadNetStateRule()),
        new FieldDescriptor(33, "exclude", TYPE.FIELDDESC, loadNetStateRule()),
        new FieldDescriptor(5, "type", TYPE.INT),
        new FieldDescriptor(6, "blobs", TYPE.FIELDDESC, loadBlobProto()),
        new FieldDescriptor(1001, "param", TYPE.FIELDDESC, loadParamSpec()),
        new FieldDescriptor(1002, "blob_share_mode", TYPE.INT),
        new FieldDescriptor(7, "blobs_lr", TYPE.FLOAT),
        new FieldDescriptor(8, "weight_decay", TYPE.FLOAT),
        new FieldDescriptor(35, "loss_weight", TYPE.FLOAT),

        new FieldDescriptor(27, LayerParameter.LayerType.ACCURACY.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(23, LayerParameter.LayerType.ARGMAX.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(9, LayerParameter.LayerType.CONCAT.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(40, LayerParameter.LayerType.CONTRASTIVE_LOSS.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(10, LayerParameter.LayerType.CONVOLUTION.ToString() + strSuffix, TYPE.FIELDDESC, loadConvolutionParam()),
        new FieldDescriptor(11, LayerParameter.LayerType.DATA.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(12, LayerParameter.LayerType.DROPOUT.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(26, LayerParameter.LayerType.DUMMYDATA.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(24, LayerParameter.LayerType.ELTWISE.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(41, LayerParameter.LayerType.EXP.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(13, "hdf5_input_param", TYPE.FIELDDESC),
        new FieldDescriptor(14, "hdf5_output_param", TYPE.FIELDDESC),
        new FieldDescriptor(29, LayerParameter.LayerType.HINGE_LOSS.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(15, "image_data_param", TYPE.FIELDDESC),
        new FieldDescriptor(16, LayerParameter.LayerType.INFOGAIN_LOSS.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(17, LayerParameter.LayerType.INNERPRODUCT.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(18, LayerParameter.LayerType.LRN.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(22, LayerParameter.LayerType.MEMORYDATA.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(34, LayerParameter.LayerType.MVN.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(19, LayerParameter.LayerType.POOLING.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(21, LayerParameter.LayerType.POWER.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(30, LayerParameter.LayerType.RELU.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(38, LayerParameter.LayerType.SIGMOID.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(39, LayerParameter.LayerType.SOFTMAX.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(31, LayerParameter.LayerType.SLICE.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(37, LayerParameter.LayerType.TANH.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(25, LayerParameter.LayerType.THRESHOLD.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(20, "window_data_param", TYPE.FIELDDESC),
        new FieldDescriptor(36, LayerParameter.LayerType.TRANSFORM.ToString() + strSuffix, TYPE.FIELDDESC),
        new FieldDescriptor(42, LayerParameter.LayerType.LOSS.ToString() + strSuffix, TYPE.FIELDDESC)
    };
}
2108
/// <summary>
/// Lists the described fields of a ParamSpec message.
/// </summary>
private static List<FieldDescriptor> loadParamSpec()
{
    return new List<FieldDescriptor>
    {
        new FieldDescriptor(1, "name", TYPE.STRING),
        new FieldDescriptor(2, "share_mode", TYPE.INT),
        new FieldDescriptor(3, "lr_mult", TYPE.FLOAT),
        new FieldDescriptor(4, "decay_mult", TYPE.FLOAT)
    };
}
2118
/// <summary>
/// Lists the described fields of a BlobShape message: the single 'dim' field.
/// </summary>
private static List<FieldDescriptor> loadBlobShape()
{
    return new List<FieldDescriptor>
    {
        new FieldDescriptor(1, "dim", TYPE.LONG)
    };
}
2125
/// <summary>
/// Creates the descriptor list for a blob proto: a nested 'shape' descriptor (built by
/// loadBlobShape), the float and double data/diff entries, and the 'num', 'channels',
/// 'height' and 'width' integer entries.
/// </summary>
/// <returns>A newly allocated list of the nine descriptors, in the order they are declared below.</returns>
private static List<FieldDescriptor> loadBlobProto()
{
    return new List<FieldDescriptor>
    {
        // Nested descriptor whose children describe the blob shape.
        new FieldDescriptor(7, "shape", FieldDescriptor.TYPE.FIELDDESC, loadBlobShape()),

        // Single precision payload.
        new FieldDescriptor(5, "data", FieldDescriptor.TYPE.FLOAT),
        new FieldDescriptor(6, "diff", FieldDescriptor.TYPE.FLOAT),

        // Double precision payload.
        new FieldDescriptor(8, "double_data", FieldDescriptor.TYPE.DOUBLE),
        new FieldDescriptor(9, "double_diff", FieldDescriptor.TYPE.DOUBLE),

        // Individual dimension entries.
        new FieldDescriptor(1, "num", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(2, "channels", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(3, "height", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(4, "width", FieldDescriptor.TYPE.INT)
    };
}
2140
/// <summary>
/// Creates the descriptor list for a net state rule, containing the 'phase', 'min_level',
/// 'max_level', 'stage' and 'not_stage' entries.
/// </summary>
/// <returns>A newly allocated list holding the five descriptors in the order named above.</returns>
private static List<FieldDescriptor> loadNetStateRule()
{
    return new List<FieldDescriptor>
    {
        new FieldDescriptor(1, "phase", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(2, "min_level", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(3, "max_level", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(4, "stage", FieldDescriptor.TYPE.STRING),
        new FieldDescriptor(5, "not_stage", FieldDescriptor.TYPE.STRING)
    };
}
2151
/// <summary>
/// Creates the descriptor list for a filler parameter, containing the 'type', 'value', 'min',
/// 'max', 'mean', 'std', 'sparse' and 'variance_norm' entries.
/// </summary>
/// <returns>A newly allocated list holding the eight descriptors in the order named above.</returns>
private static List<FieldDescriptor> loadFillerParam()
{
    return new List<FieldDescriptor>
    {
        new FieldDescriptor(1, "type", FieldDescriptor.TYPE.STRING),
        new FieldDescriptor(2, "value", FieldDescriptor.TYPE.FLOAT),
        new FieldDescriptor(3, "min", FieldDescriptor.TYPE.FLOAT),
        new FieldDescriptor(4, "max", FieldDescriptor.TYPE.FLOAT),
        new FieldDescriptor(5, "mean", FieldDescriptor.TYPE.FLOAT),
        new FieldDescriptor(6, "std", FieldDescriptor.TYPE.FLOAT),
        new FieldDescriptor(7, "sparse", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(8, "variance_norm", FieldDescriptor.TYPE.INT)
    };
}
2165
/// <summary>
/// Creates the descriptor list for a convolution parameter: the output/bias settings, the
/// pad/kernel/stride/dilation entries (including the per-axis _h and _w variants), the group,
/// two nested filler descriptors, and the 'engine', 'axis' and 'force_nd' entries.
/// </summary>
/// <remarks>
/// Each of 'weight_filler' and 'bias_filler' receives its own child list from a separate
/// loadFillerParam call, so the two descriptors do not share a children list instance.
/// </remarks>
/// <returns>A newly allocated list of the eighteen descriptors, in the order they are declared below.</returns>
private static List<FieldDescriptor> loadConvolutionParam()
{
    return new List<FieldDescriptor>
    {
        new FieldDescriptor(1, "num_output", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(2, "bias_term", FieldDescriptor.TYPE.BOOL),

        // Combined spatial settings.
        new FieldDescriptor(3, "pad", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(4, "kernel_size", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(6, "stride", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(18, "dilation", FieldDescriptor.TYPE.UINT),

        // Per-axis spatial settings.
        new FieldDescriptor(9, "pad_h", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(10, "pad_w", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(11, "kernel_h", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(12, "kernel_w", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(13, "stride_h", FieldDescriptor.TYPE.UINT),
        new FieldDescriptor(14, "stride_w", FieldDescriptor.TYPE.UINT),

        new FieldDescriptor(5, "group", FieldDescriptor.TYPE.UINT),

        // Nested filler descriptors, each with a freshly built child list.
        new FieldDescriptor(7, "weight_filler", FieldDescriptor.TYPE.FIELDDESC, loadFillerParam()),
        new FieldDescriptor(8, "bias_filler", FieldDescriptor.TYPE.FIELDDESC, loadFillerParam()),

        new FieldDescriptor(15, "engine", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(16, "axis", FieldDescriptor.TYPE.INT),
        new FieldDescriptor(17, "force_nd", FieldDescriptor.TYPE.BOOL)
    };
}
2189 }
2190
2191#pragma warning restore 1591
2192}
The Log class provides general output in text form.
Definition: Log.cs:13
The Utility class provides general utility functions.
Definition: Utility.cs:35
static double[] ConvertVec(float[] rgf)
Convert an array of float to an array of generics.
Definition: Utility.cs:550
The BlobCollection contains a list of Blobs.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
object Tag
Returns a user defined object associated with the Blob.
Definition: Blob.cs:2770
bool reshape_when_sharing
When true, this Blob is reshaped to the source when sharing the source data (default = false).
Definition: Blob.cs:1803
string Name
Get/set the name of the Blob.
Definition: Blob.cs:2184
The PersistCaffe class is used to load and save weight files in the .caffemodel format.
Definition: PersistCaffe.cs:20
PersistCaffe(Log log, bool bFailOnFirstTry)
The PersistCaffe constructor.
Definition: PersistCaffe.cs:30
BlobProto LoadBlobProto(string strFile, int nFieldId)
The LoadBlobProto function loads a BlobProto from a file.
SolverState LoadSolverState(byte[] rgState, SolverParameter.SolverType type=SolverParameter.SolverType.SGD)
Load the solver state from a byte array.
BlobCollection< T > LoadWeights(byte[] rgWeights, List< string > rgExpectedShapes, BlobCollection< T > colBlobs, bool bSizeToFit, out bool bLoadedDiffs, List< string > inputWtInfo=null, List< string > targetWtInfo=null, string strSkipBlobType=null)
Loads new weights into a BlobCollection
WeightInfo< T > LoadWeightInfo(BlobCollection< T > colBlobs)
Returns the weight information describing the weights contained within the Blob collection.
string MyCaffeTag
This tag is used to mark the ending section of each weighting file with 'MyCaffe' specific informatio...
Definition: PersistCaffe.cs:41
BlobProto LoadBlobProto(byte[] rg, int nFieldId)
The LoadBlobProto function loads a BlobProto from a proto buffer.
byte[] SaveWeights(BlobCollection< T > colBlobs, bool bSaveDiffs=false)
Save the weights to a byte array.
bool IsMyCaffe(byte[] rgWeights, out string strVer)
This method returns whether or not the weights have been marked as 'mycaffe.ai'.
Definition: PersistCaffe.cs:51
byte[] SaveSolverState(SolverState state, SolverParameter.SolverType type=SolverParameter.SolverType.SGD)
Save the solver state to a byte array.
Definition: PersistCaffe.cs:75
WeightInfo< T > LoadWeightInfo(byte[] rgWeights)
Returns the weight information describing the weights contained within the weight bytes.
The WeightInfo class describes the weights of a given weight set including the blob names and sizes o...
Definition: WeightInfo.cs:15
WeightInfo()
The constructor.
Definition: WeightInfo.cs:22
The BlobProto contains the description of a blob.
Definition: BlobProto.cs:15
List< float > data
Get/set the data as a List of float.
Definition: BlobProto.cs:180
BlobShape shape
Specifies the shape of the Blob.
Definition: BlobProto.cs:117
List< double > double_data
Get/set the data as a List of double.
Definition: BlobProto.cs:162
BlobProto()
Constructor for the BlobProto.
Definition: BlobProto.cs:31
List< int > dim
The blob shape dimensions.
Definition: BlobShape.cs:93
Specifies the base parameter for all layers.
LayerType
Specifies the layer type.
The SolverParameter is a parameter for the solver, specifying the train and test networks.
SolverType
Defines the type of solver.
The SolverState specifies the state of a given solver.
Definition: SolverState.cs:17
int end
Specifies the end used by L-BFGS
Definition: SolverState.cs:58
BlobProto gradients
Gradients used with L-BFGS state.
Definition: SolverState.cs:85
int iter
The current iteration.
Definition: SolverState.cs:40
List< BlobProto > history
The history for SGD solvers.
Definition: SolverState.cs:67
int start
Specifies the start used by L-BFGS
Definition: SolverState.cs:49
int current_step
The current step for learning rate.
Definition: SolverState.cs:76
List< BlobProto > s_history
S history used with L-BFGS state.
Definition: SolverState.cs:103
BlobProto direction
Direction used with L-BFGS state.
Definition: SolverState.cs:94
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:187
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
@ FLOAT
Specifies the single type.
@ DOUBLE
Specifies the double type.
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
@ UNKNOWN
The blob is an unknown type.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12