MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
TokenizedDataParameter.cs
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
6using MyCaffe.basecode;
8
9namespace MyCaffe.param.gpt
10{
14 [Serializable]
15 [TypeConverter(typeof(ExpandableObjectConverter))]
17 {
22
// Backing fields for the public properties of this parameter.
// Each one is read and written by the property named in its trailing comment.
private uint m_nBatchSize;                                              // batch_size
private uint m_nBlockSize;                                              // block_size
private INPUT_TYPE m_inputType;                                         // input_type
private string m_strSource;                                             // source
private int? m_nSeed = null;                                            // seed (null = no seed specified)
private string m_strDbgIdxFile;                                         // debug_index_file
private VOCABULARY_TYPE m_vocabType = VOCABULARY_TYPE.CHARACTER;        // vocabulary_type (default = CHARACTER)
private SAMPLE_METHOD m_sampleMethod = SAMPLE_METHOD.ARGMAX;            // sample_method (default = ARGMAX)
31
/// <summary>
/// Defines the vocabulary type to use.
/// </summary>
/// <remarks>
/// The member names are what gets written to and parsed from the prototxt
/// (via ToString()), so they must not be renamed.
/// </remarks>
public enum VOCABULARY_TYPE
{
    /// <summary>
    /// Default vocabulary type used by the parameter.
    /// </summary>
    CHARACTER = 0,
    /// <summary>
    /// Word vocabulary (presumably word-level tokens - confirm against the layer implementation).
    /// </summary>
    WORD = 1,
    /// <summary>
    /// SentencePiece vocabulary (presumably sub-word tokens - confirm against the layer implementation).
    /// </summary>
    SENTENCEPIECE = 2,
    /// <summary>
    /// Custom vocabulary (semantics are defined outside of this file).
    /// </summary>
    CUSTOM = 3
}
54
/// <summary>
/// Defines the sampling method used.
/// </summary>
/// <remarks>
/// The member names are what gets written to and parsed from the prototxt
/// (via ToString()), so they must not be renamed.
/// </remarks>
public enum SAMPLE_METHOD
{
    /// <summary>
    /// Default sampling method used when post processing logits.
    /// </summary>
    ARGMAX = 0,
    /// <summary>
    /// Probability based sampling (presumably samples from the logit distribution -
    /// confirm against the layer implementation).
    /// </summary>
    PROBABILITY = 1
}
70
/// <summary>
/// Defines the data source input type.
/// </summary>
/// <remarks>
/// The member names are what gets written to and parsed from the prototxt
/// (via ToString()), so they must not be renamed.
/// </remarks>
public enum INPUT_TYPE
{
    /// <summary>
    /// Text file input. This member has the value 0, so it is the value of an
    /// INPUT_TYPE field that has not been explicitly assigned.
    /// </summary>
    TEXT_FILE = 0,
    /// <summary>
    /// Custom input (semantics are defined outside of this file).
    /// </summary>
    CUSTOM = 1
}
85
88 {
89 }
90
95 {
96 get { return m_pythonParam; }
97 set { m_pythonParam = value; }
98 }
99
/// <summary>
/// Specifies the seed used to initialize the random number generator (normally only for testing).
/// A null value means no seed has been specified; ToProto only writes the 'seed' entry when a value is set.
/// </summary>
[Description("Specifies the seed used to initialize the random number generator (normally only for testing).")]
public int? seed
{
    get { return m_nSeed; }
    set { m_nSeed = value; }
}
108
112 [Description("Specifies data source input type.")]
114 {
115 get { return m_inputType; }
116 set { m_inputType = value; }
117 }
118
122 [Description("Specifies the vocabulary type to use.")]
124 {
125 get { return m_vocabType; }
126 set { m_vocabType = value; }
127 }
128
132 [Description("Specifies the sampling method used when post processing logits (default = ARGMAX).")]
134 {
135 get { return m_sampleMethod; }
136 set { m_sampleMethod = value; }
137 }
138
/// <summary>
/// Specifies the data source based on the INPUT_TYPE used.
/// Each dataset has both a training and testing data source.
/// </summary>
[Description("Specifies the data source based on the INPUT_TYPE used. Each dataset has both a training and testing data source.")]
public string source
{
    get
    {
        return this.m_strSource;
    }
    set
    {
        this.m_strSource = value;
    }
}
148
/// <summary>
/// Specifies an optional data index file used for debugging only.
/// ToProto only writes the 'debug_index_file' entry when this value is not null or empty.
/// </summary>
[Description("Specifies an optional data index file used for debugging only.")]
public string debug_index_file
{
    get { return m_strDbgIdxFile; }
    set { m_strDbgIdxFile = value; }
}
158
/// <summary>
/// Specifies batch size.
/// </summary>
[Description("Specifies batch size.")]
public uint batch_size
{
    get
    {
        return this.m_nBatchSize;
    }
    set
    {
        this.m_nBatchSize = value;
    }
}
168
/// <summary>
/// Specifies size of the block.
/// </summary>
[Description("Specifies size of the block.")]
public uint block_size
{
    get { return m_nBlockSize; }
    set { m_nBlockSize = value; }
}
177
179 public override object Load(System.IO.BinaryReader br, bool bNewInstance = true)
180 {
181 RawProto proto = RawProto.Parse(br.ReadString());
183
184 if (!bNewInstance)
185 Copy(p);
186
187 return p;
188 }
189
191 public override void Copy(LayerParameterBase src)
192 {
194
196 m_inputType = p.input_type;
197 m_strSource = p.source;
198 m_nBatchSize = p.batch_size;
199 m_nBlockSize = p.block_size;
200 m_nSeed = p.seed;
201 m_strDbgIdxFile = p.debug_index_file;
202 m_vocabType = p.vocabulary_type;
203 m_sampleMethod = p.sample_method;
204 }
205
207 public override LayerParameterBase Clone()
208 {
210 p.Copy(this);
211 return p;
212 }
213
/// <summary>
/// Convert the parameter into a RawProto.
/// </summary>
/// <param name="strName">Specifies the name given to the returned RawProto.</param>
/// <returns>A new RawProto holding one child entry per setting is returned.</returns>
public override RawProto ToProto(string strName)
{
    RawProtoCollection col = new RawProtoCollection();

    // The optional python settings go first, as a nested 'python_param' node.
    if (m_pythonParam != null)
    {
        RawProto protoPython = m_pythonParam.ToProto("python_param");
        col.Add(protoPython);
    }

    // Enum settings are written by member name.
    col.Add("input_type", input_type.ToString());
    col.Add("vocabulary_type", vocabulary_type.ToString());
    col.Add("sample_method", sample_method.ToString());

    // The source is written wrapped in double quotes.
    string strQuotedSource = "\"" + source + "\"";
    col.Add("source", strQuotedSource);

    col.Add("batch_size", batch_size.ToString());
    col.Add("block_size", block_size.ToString());

    // Optional entries are only written when a value has been set.
    // NOTE(review): unlike 'source', the debug index file is written without
    // surrounding quotes - confirm that paths containing spaces survive a round trip.
    string strDebugIndexFile = debug_index_file;
    if (!string.IsNullOrEmpty(strDebugIndexFile))
    {
        col.Add("debug_index_file", strDebugIndexFile);
    }

    int? nSeed = seed;
    if (nSeed.HasValue)
    {
        col.Add("seed", nSeed.Value.ToString());
    }

    return new RawProto(strName, "", col);
}
241
248 {
249 string strVal;
251
252 RawProto rpPython = rp.FindChild("python_param");
253 if (rpPython != null)
255
256 if ((strVal = rp.FindValue("block_size")) != null)
257 p.block_size = uint.Parse(strVal);
258
259 if ((strVal = rp.FindValue("batch_size")) != null)
260 p.batch_size = uint.Parse(strVal);
261
262 if ((strVal = rp.FindValue("source")) != null)
263 p.source = strVal.Trim('\"');
264
265 if ((strVal = rp.FindValue("seed")) != null)
266 p.seed = int.Parse(strVal);
267
268 if ((strVal = rp.FindValue("debug_index_file")) != null)
269 p.debug_index_file = strVal;
270
271 if ((strVal = rp.FindValue("input_type")) != null)
272 {
273 if (strVal == INPUT_TYPE.TEXT_FILE.ToString())
274 p.input_type = INPUT_TYPE.TEXT_FILE;
275 else if (strVal == INPUT_TYPE.CUSTOM.ToString())
276 p.input_type = INPUT_TYPE.CUSTOM;
277 else
278 throw new Exception("Unknown input type '" + strVal + "'");
279 }
280
281 if ((strVal = rp.FindValue("vocabulary_type")) != null)
282 {
283 if (strVal == VOCABULARY_TYPE.CHARACTER.ToString())
284 p.vocabulary_type = VOCABULARY_TYPE.CHARACTER;
285 else if (strVal == VOCABULARY_TYPE.WORD.ToString())
287 else if (strVal == VOCABULARY_TYPE.SENTENCEPIECE.ToString())
288 p.vocabulary_type = VOCABULARY_TYPE.SENTENCEPIECE;
289 else if (strVal == VOCABULARY_TYPE.CUSTOM.ToString())
291 else
292 throw new Exception("Unknown vocabulary type '" + strVal + "'");
293 }
294
295 if ((strVal = rp.FindValue("sample_method")) != null)
296 {
297 if (strVal == SAMPLE_METHOD.ARGMAX.ToString())
298 p.sample_method = SAMPLE_METHOD.ARGMAX;
299 else if (strVal == SAMPLE_METHOD.PROBABILITY.ToString())
300 p.sample_method = SAMPLE_METHOD.PROBABILITY;
301 else
302 throw new Exception("Unknown sample method '" + strVal + "'");
303 }
304
305 return p;
306 }
307 }
308}
The RawProtoCollection class is a list of RawProto objects.
void Add(RawProto p)
Adds a RawProto to the collection.
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
RawProto FindChild(string strName)
Searches for a given node.
Definition: RawProto.cs:231
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
string FindValue(string strName)
Searches for a value of a node within this node's children.
Definition: RawProto.cs:105
The LayerParameterBase is the base class for all other layer specific parameters.
Specifies the parameters for the TokenizedDataLayer.
string debug_index_file
Specifies an optional data index file used for debugging only.
SAMPLE_METHOD sample_method
Specifies the sampling method used when post processing logits (default = ARGMAX).
override RawProto ToProto(string strName)
Convert the parameter into a RawProto.
INPUT_TYPE input_type
Specifies data source input type.
PythonParameter python_param
Specifies the PythonParameter used by the python implementation of the TokenizedDataPairsLayer,...
string source
Specifies the data source based on the INPUT_TYPE used. Each dataset has both a training and testing ...
uint block_size
Specifies size of the block.
static TokenizedDataParameter FromProto(RawProto rp)
Parses the parameter from a RawProto.
override LayerParameterBase Clone()
Creates a new copy of this instance of the parameter.
SAMPLE_METHOD
Defines the sampling method used.
VOCABULARY_TYPE vocabulary_type
Specifies the vocabulary type to use.
int? seed
Specifies the seed used to initialize the random number generator (normally only for testing).
PythonParameter m_pythonParam
Python layer implementations use this parameter for Python specific settings such as the location of ...
TokenizedDataParameter()
Constructor for the parameter.
VOCABULARY_TYPE
Defines the vocabulary type to use.
override object Load(System.IO.BinaryReader br, bool bNewInstance=true)
Load the parameter from a binary reader.
override void Copy(LayerParameterBase src)
Copy one parameter to another.
Specifies the parameters for the PythonLayer.
static PythonParameter FromProto(RawProto rp)
Parses the parameter from a RawProto.
override RawProto ToProto(string strName)
Convert the parameter into a RawProto.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
@ CUSTOM
Defines a purely custom training method.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12