MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
SliceLayer.cs
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using MyCaffe.basecode;
6using MyCaffe.common;
7using MyCaffe.param;
8
9namespace MyCaffe.layers
10{
17 public class SliceLayer<T> : Layer<T>
18 {
// Bottom count over the axes before the slice axis; recomputed in Reshape.
private int m_nNumSlices;
// Bottom count over the axes after the slice axis; recomputed in Reshape.
private int m_nSliceSize;
// Index of the axis being sliced; resolved in Reshape from slice_dim or axis.
private int m_nSliceAxis;
// Private copy of slice_param.slice_point, taken in LayerSetUp and consumed by Reshape.
private List<uint> m_rgSlicePoints = new List<uint>();
23
36 : base(cuda, log, p)
37 {
39 }
40
/// <summary>
/// Returns the exact number of required bottom (input) Blobs: input.
/// </summary>
public override int ExactNumBottomBlobs => 1;
48
/// <summary>
/// Returns the minimum number of required top (output) Blobs: slice.
/// </summary>
public override int MinTopBlobs => 1;
56
/// <summary>
/// Setup the layer by taking a private copy of the configured slice points
/// (slice_param.slice_point), which Reshape later validates and uses.
/// </summary>
/// <param name="colBottom">Specifies the collection of bottom (input) Blobs; not read here.</param>
/// <param name="colTop">Specifies the collection of top (output) Blobs; not read here.</param>
public override void LayerSetUp(BlobCollection<T> colBottom, BlobCollection<T> colTop)
{
    List<uint> rgConfiguredPoints = Utility.Clone<uint>(m_param.slice_param.slice_point);
    m_rgSlicePoints = rgConfiguredPoints;
}
66
/// <summary>
/// Reshape the top (output) blobs from the bottom (input) blob: resolves the slice axis,
/// computes each top's extent along that axis (from the slice points when any were given,
/// otherwise an even split), verifies the tops together hold exactly the bottom's element
/// count, and shares data/diff with the bottom when there is only one top.
/// </summary>
/// <param name="colBottom">Specifies the collection of bottom (input) Blobs; only colBottom[0] is used.</param>
/// <param name="colTop">Specifies the collection of top (output) Blobs, one per slice.</param>
72 public override void Reshape(BlobCollection<T> colBottom, BlobCollection<T> colTop)
73 {
74 int nNumAxes = colBottom[0].num_true_axes;
75
// NOTE(review): source line 76 -- the 'if' that opens this branch, presumably testing
// whether slice_param.slice_dim was set -- is missing from this listing. The exact
// condition cannot be confirmed from here; restore it from the original SliceLayer.cs.
// Branch taken: use the deprecated, non-negative slice_dim as the slice axis.
77 {
78 m_nSliceAxis = (int)m_param.slice_param.slice_dim;
79 // Don't allow negative indexing for slice_dim, a uint -- almost
80 // certainly unintended.
81 m_log.CHECK_GE(m_nSliceAxis, 0, "casting slice_dim from uint to int produced a negative result; slice_dim must satisfy 0 <= slice_dim < " + Blob<T>.MAX_BLOB_AXES.ToString());
82 m_log.CHECK_LT(m_nSliceAxis, nNumAxes, "slice_dim is out of range.");
83 }
84 else
85 {
// 'axis' may be negative; CanonicalAxisIndex maps it onto the bottom's axes.
86 m_nSliceAxis = colBottom[0].CanonicalAxisIndex(m_param.slice_param.axis);
87 }
88
// Each top starts from the bottom's shape; only the slice-axis extent is replaced below.
89 List<int> rgTopShape = Utility.Clone<int>(colBottom[0].shape());
90 int bottom_slice_axis = colBottom[0].shape(m_nSliceAxis);
91
// Counts before and after the slice axis; forward/backward pass these to slice_fwd/slice_bwd.
92 m_nNumSlices = colBottom[0].count(0, m_nSliceAxis);
93 m_nSliceSize = colBottom[0].count(m_nSliceAxis + 1);
94
// Running total of the top element counts, checked against the bottom count at the end.
95 int nCount = 0;
96
97 if (m_rgSlicePoints.Count != 0)
98 {
// Explicit slice points: exactly colTop.Count - 1 of them are required, and there can be
// no more tops than positions along the slice axis.
99 m_log.CHECK_EQ(m_rgSlicePoints.Count, colTop.Count - 1, "The slice point count is incorrect.");
100 m_log.CHECK_LE(colTop.Count, bottom_slice_axis, "slice axis: " + bottom_slice_axis.ToString() + ", bottom[0] shape: '" + colBottom[0].shape_string + "'");
101
102 int nPrev = 0;
103 List<int> rgSlices = new List<int>();
104
// Slice points must be positive and strictly increasing; each top's extent is the
// distance from the previous point.
105 for (int i = 0; i < m_rgSlicePoints.Count; i++)
106 {
107 m_log.CHECK_GT((int)m_rgSlicePoints[i], nPrev, "The slice point at " + i.ToString() + " should be greater than the previous slice point of " + nPrev.ToString());
108 rgSlices.Add((int)m_rgSlicePoints[i] - nPrev);
109 nPrev = (int)m_rgSlicePoints[i];
110 }
111
// The last top receives whatever remains after the final slice point.
// NOTE(review): nothing here checks that the final slice point is < bottom_slice_axis, so
// this remainder can be 0 or negative. Whether colTop[i].Reshape rejects that is not visible
// here; consider an explicit CHECK_LT(nPrev, bottom_slice_axis) before this line.
112 rgSlices.Add(bottom_slice_axis - nPrev);
113
114 for (int i = 0; i < colTop.Count; i++)
115 {
116 rgTopShape[m_nSliceAxis] = rgSlices[i];
117 colTop[i].Reshape(rgTopShape);
118 nCount += colTop[i].count();
119 }
120 }
121 else
122 {
// No slice points: split the slice axis evenly across all tops.
123 m_log.CHECK_EQ(bottom_slice_axis % colTop.Count, 0, "Number of top blobs (" + colTop.Count.ToString() + ") should evenly divide input slice axis (" + bottom_slice_axis.ToString() + ")");
124 rgTopShape[m_nSliceAxis] = bottom_slice_axis / colTop.Count;
125
126 for (int i = 0; i < colTop.Count; i++)
127 {
128 colTop[i].Reshape(rgTopShape);
129 nCount += colTop[i].count();
130 }
131 }
132
// The tops together must hold exactly as many elements as the bottom.
133 m_log.CHECK_EQ(nCount, colBottom[0].count(), "The count (" + nCount.ToString() + ") should be the same as the bottom count (" + colBottom[0].count().ToString() + ")");
134
// A single top is the whole bottom: share its data and diff instead of copying.
135 if (colTop.Count == 1)
136 {
137 colTop[0].ShareData(colBottom[0]);
138 colTop[0].ShareDiff(colBottom[0]);
139 }
140 }
141
/// <summary>
/// Computes the forward calculation. For each top blob in order, m_cuda.slice_fwd is
/// called with the bottom data, that top's extent along the slice axis, and the running
/// offset of that extent within the bottom's slice axis.
/// </summary>
/// <param name="colBottom">Specifies the bottom (input) Blob collection; only colBottom[0] is read.</param>
/// <param name="colTop">Specifies the top (output) Blobs, filled in order along the slice axis.</param>
protected override void forward(BlobCollection<T> colBottom, BlobCollection<T> colTop)
{
    // A single top already shares its data with the bottom (see Reshape); nothing to do.
    if (colTop.Count == 1)
        return;

    long hBottom = colBottom[0].gpu_data;
    int nBottomAxisLen = colBottom[0].shape(m_nSliceAxis);
    int nAxisOffset = 0;
    int nTopIdx = 0;

    while (nTopIdx < colTop.Count)
    {
        long hTop = colTop[nTopIdx].mutable_gpu_data;
        int nTopAxisLen = colTop[nTopIdx].shape(m_nSliceAxis);

        // Element count of this top: its slice-axis extent x m_nSliceSize x m_nNumSlices.
        int nTopCount = nTopAxisLen * m_nSliceSize * m_nNumSlices;

        m_cuda.slice_fwd(nTopCount, hBottom, m_nNumSlices, m_nSliceSize, nBottomAxisLen, nTopAxisLen, nAxisOffset, hTop);

        // The next top begins where this one ended along the slice axis.
        nAxisOffset += nTopAxisLen;
        nTopIdx++;
    }
}
171
/// <summary>
/// Computes the error gradient w.r.t. the input. Each top's diff is passed to
/// m_cuda.slice_bwd together with that top's extent along the slice axis and its running
/// offset within the bottom's slice axis.
/// </summary>
/// <param name="colTop">Specifies the top (output) Blobs whose diffs are read.</param>
/// <param name="rgbPropagateDown">Specifies whether to propagate down; only element 0 is consulted.</param>
/// <param name="colBottom">Specifies the bottom (input) Blobs; colBottom[0]'s mutable diff receives the gradient.</param>
protected override void backward(BlobCollection<T> colTop, List<bool> rgbPropagateDown, BlobCollection<T> colBottom)
{
    // No gradient requested for the bottom.
    if (!rgbPropagateDown[0])
        return;

    // With one top the diff is shared with the bottom (see Reshape), so it is already in place.
    if (colTop.Count == 1)
        return;

    long hBottomDiff = colBottom[0].mutable_gpu_diff;
    int nBottomAxisLen = colBottom[0].shape(m_nSliceAxis);

    for (int nTopIdx = 0, nAxisOffset = 0; nTopIdx < colTop.Count; nTopIdx++)
    {
        long hTopDiff = colTop[nTopIdx].gpu_diff;
        int nTopAxisLen = colTop[nTopIdx].shape(m_nSliceAxis);
        int nTopCount = nTopAxisLen * m_nSliceSize * m_nNumSlices;

        m_cuda.slice_bwd(nTopCount, hTopDiff, m_nNumSlices, m_nSliceSize, nBottomAxisLen, nTopAxisLen, nAxisOffset, hBottomDiff);

        // Advance to where the next top's range starts on the bottom's slice axis.
        nAxisOffset += nTopAxisLen;
    }
}
200 }
201}
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Definition: Log.cs:239
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
Definition: Log.cs:299
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
Definition: Log.cs:263
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
Definition: Log.cs:287
void CHECK_LT(double df1, double df2, string str)
Test whether one number is less than another.
Definition: Log.cs:275
The Utility class provides general utility functions.
Definition: Utility.cs:35
The BlobCollection contains a list of Blobs.
int Count
Returns the number of items in the collection.
void Reshape(int[] rgShape)
Reshapes all blobs in the collection to the given shape.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
const int MAX_BLOB_AXES
Defines the maximum number of Axes supported by the Blob.
Definition: Blob.cs:55
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:969
An interface for the units of computation which can be composed into a Net.
Definition: Layer.cs:31
Log m_log
Specifies the Log for output.
Definition: Layer.cs:43
LayerParameter m_param
Specifies the LayerParameter describing the Layer.
Definition: Layer.cs:47
CudaDnn< T > m_cuda
Specifies the CudaDnn connection to Cuda.
Definition: Layer.cs:39
LayerParameter.LayerType m_type
Specifies the Layer type.
Definition: Layer.cs:35
The SliceLayer takes a blob and slices it along a chosen axis (for example, the num or channel dimension), outputting multiple sliced blobs.
Definition: SliceLayer.cs:18
override int ExactNumBottomBlobs
Returns the exact number of required bottom (input) Blobs: input.
Definition: SliceLayer.cs:45
SliceLayer(CudaDnn< T > cuda, Log log, LayerParameter p)
The SliceLayer constructor.
Definition: SliceLayer.cs:35
override int MinTopBlobs
Returns the minimum number of required top (output) Blobs: slice
Definition: SliceLayer.cs:53
override void backward(BlobCollection< T > colTop, List< bool > rgbPropagateDown, BlobCollection< T > colBottom)
Computes the error gradient w.r.t the inputs.
Definition: SliceLayer.cs:180
override void Reshape(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Reshape the bottom (input) and top (output) blobs.
Definition: SliceLayer.cs:72
override void LayerSetUp(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Setup the layer.
Definition: SliceLayer.cs:62
override void forward(BlobCollection< T > colBottom, BlobCollection< T > colTop)
Computes the forward calculation.
Definition: SliceLayer.cs:151
Specifies the base parameter for all layers.
SliceParameter slice_param
Returns the parameter set when initialized with LayerType.SLICE
LayerType
Specifies the layer type.
uint slice_dim
DEPRECATED: alias for 'axis' – does not support negative indexing.
List< uint > slice_point
Specifies optional slice points which indicate the indexes in the selected dimensions (the number of ...
int axis
Specifies the axis along which to slice – may be negative to index from the end (e.g., -1 for the last axis).
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-source project.
Definition: Annotation.cs:12