MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MnistDataLoader.cs
1// Copyright (c) 2018-2020 SignalPop LLC and contributors. All rights reserved.
2// License: Apache 2.0
3// License: https://github.com/MyCaffe/MyCaffe/blob/master/LICENSE
4// Original Source: https://github.com/MyCaffe/MyCaffe/blob/master/MyCaffe.data/MnistDataLoader.cs
5using System;
6using System.Collections.Generic;
7using System.Diagnostics;
8using System.IO;
9using System.IO.Compression;
10using System.Linq;
11using System.Text;
12using System.Threading.Tasks;
13using MyCaffe.basecode;
14using MyCaffe.db.image;
16using System.Threading;
17using System.Drawing;
18
19namespace MyCaffe.data
20{
/// <summary>
/// The MnistDataLoader creates the MNIST dataset from the images extracted by the MnistDataLoaderLite and
/// either loads it into the database through a DatasetFactory, or exports the images to disk when
/// MnistDataParameters.ExportToFile is set.
/// </summary>
public class MnistDataLoader
{
    MnistDataLoaderLite m_extractor;
    MnistDataParameters m_param;
    Log m_log;
    CancelEvent m_evtCancel;

    /// <summary>
    /// The OnProgress event fires during the creation process to show the progress.
    /// </summary>
    public event EventHandler<ProgressArgs> OnProgress;
    /// <summary>
    /// The OnError event fires when an error occurs.
    /// </summary>
    public event EventHandler<ProgressArgs> OnError;
    /// <summary>
    /// The OnCompleted event fires once LoadDatabase has finished, whether it succeeded, was cancelled or threw.
    /// </summary>
    public event EventHandler OnCompleted;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="param">Specifies the dataset parameters (file locations and export settings).</param>
    /// <param name="log">Specifies the log handed to the database and mean calculation calls.</param>
    /// <param name="evtCancel">Specifies the cancel event polled while loading; it is reset here.</param>
    /// <exception cref="ArgumentNullException">Thrown when 'param' or 'evtCancel' is null.</exception>
    public MnistDataLoader(MnistDataParameters param, Log log, CancelEvent evtCancel)
    {
        if (param == null)
            throw new ArgumentNullException("param");

        if (evtCancel == null)
            throw new ArgumentNullException("evtCancel");

        m_extractor = new MnistDataLoaderLite(Path.GetDirectoryName(param.TrainImagesFile));
        m_extractor.OnProgress += m_extractor_OnProgress;
        m_extractor.OnError += m_extractor_OnError;

        m_param = param;
        m_log = log;
        m_evtCancel = evtCancel;
        m_evtCancel.Reset();
    }

    /// <summary>
    /// Forwards errors raised by the extractor to the subscribers of this loader, keeping the original sender.
    /// </summary>
    private void m_extractor_OnError(object sender, ProgressArgs e)
    {
        // Copy to a local so that an unsubscribe between the check and the call cannot cause a null call.
        EventHandler<ProgressArgs> evtError = OnError;
        if (evtError != null)
            evtError(sender, e);
    }

    /// <summary>
    /// Forwards progress raised by the extractor to the subscribers of this loader, keeping the original sender.
    /// </summary>
    private void m_extractor_OnProgress(object sender, ProgressArgs e)
    {
        EventHandler<ProgressArgs> evtProgress = OnProgress;
        if (evtProgress != null)
            evtProgress(sender, e);
    }

    /// <summary>
    /// Returns the dataset name used as the prefix of the database source names.
    /// </summary>
    private string dataset_name
    {
        get { return "MNIST"; }
    }

    /// <summary>
    /// Create the dataset and load it into the database, or export it to disk when ExportToFile is set.
    /// </summary>
    /// <param name="nCreatorID">Specifies the creator ID stored in the DatasetDescriptor.</param>
    /// <returns>Returns true when both sources were processed, false when the cancel event was signaled.</returns>
    /// <remarks>
    /// Exceptions are not caught here so that they reach the caller with their original stack trace;
    /// OnCompleted is raised in every case.
    /// </remarks>
    public bool LoadDatabase(int nCreatorID = 0)
    {
        try
        {
            List<Tuple<byte[], int>> rgTrainImg;
            List<Tuple<byte[], int>> rgTestImg;

            m_extractor.ExtractImages(out rgTrainImg, out rgTestImg);

            reportProgress(0, 0, "Loading " + dataset_name + " database...");

            DatasetFactory factory = null;
            string strExportFolder = null;
            string strTrainSrc = "training";
            string strTestSrc = "testing";

            if (m_param.ExportToFile)
            {
                // CreateDirectory does nothing when the directory already exists.
                strExportFolder = m_param.ExportPath.TrimEnd('\\') + "\\";
                Directory.CreateDirectory(strExportFolder);
            }
            else
            {
                factory = new DatasetFactory();
                strTrainSrc = dataset_name + "." + strTrainSrc;
                strTestSrc = dataset_name + "." + strTestSrc;
            }

            if (factory != null)
                deleteExistingSourceData(factory, strTrainSrc);

            if (!loadFile(factory, rgTrainImg, m_extractor.Channels, m_extractor.Height, m_extractor.Width, strTrainSrc, strExportFolder))
                return false;

            if (factory != null)
                deleteExistingSourceData(factory, strTestSrc);

            if (!loadFile(factory, rgTestImg, m_extractor.Channels, m_extractor.Height, m_extractor.Width, strTestSrc, strExportFolder))
                return false;

            if (factory != null)
            {
                SourceDescriptor srcTrain = factory.LoadSource(strTrainSrc);
                SourceDescriptor srcTest = factory.LoadSource(strTestSrc);
                DatasetDescriptor ds = new DatasetDescriptor(nCreatorID, dataset_name, null, null, srcTrain, srcTest, dataset_name, dataset_name + " Character Dataset");
                factory.AddDataset(ds);
                factory.UpdateDatasetCounts(ds.ID);
            }

            return true;
        }
        finally
        {
            EventHandler evtCompleted = OnCompleted;
            if (evtCompleted != null)
                evtCompleted(this, EventArgs.Empty);
        }
    }

    /// <summary>
    /// Deletes the data of the named source when the source already exists (non-zero ID).
    /// </summary>
    private static void deleteExistingSourceData(DatasetFactory factory, string strSourceName)
    {
        int nSrcId = factory.GetSourceID(strSourceName);
        if (nSrcId != 0)
            factory.DeleteSourceData(nSrcId);
    }

    /// <summary>
    /// Loads one set of images into the database source, or exports it as image files plus a 'file_list.txt'.
    /// </summary>
    /// <param name="factory">Specifies the dataset factory, or null when exporting to file.</param>
    /// <param name="rgData">Specifies the (pixels, label) items to process.</param>
    /// <param name="nC">Specifies the channel count.</param>
    /// <param name="nH">Specifies the image height.</param>
    /// <param name="nW">Specifies the image width.</param>
    /// <param name="strSourceName">Specifies the source name (also used as the export sub-folder name).</param>
    /// <param name="strExportPath">Specifies the export root folder, or null when loading into the database.</param>
    /// <returns>Returns true on completion, false when the cancel event was signaled.</returns>
    private bool loadFile(DatasetFactory factory, List<Tuple<byte[], int>> rgData, int nC, int nH, int nW, string strSourceName, string strExportPath)
    {
        if (strExportPath != null)
        {
            strExportPath += strSourceName;
            Directory.CreateDirectory(strExportPath);
        }

        reportProgress(0, 0, " Source: " + strSourceName);

        if (factory != null)
        {
            int nSrcId = factory.AddSource(strSourceName, nC, nW, nH, false, 0, true);

            factory.Open(nSrcId, 500, Database.FORCE_LOAD.NONE, m_log);
            factory.DeleteSourceData();
        }

        // A single datum is re-used for every item; only its data and label change per iteration.
        Datum datum = new Datum(false, nC, nW, nH, -1, DateTime.MinValue, new List<byte>(), 0, false, -1);
        string strAction = (m_param.ExportToFile) ? "exporting" : "loading";

        reportProgress(0, rgData.Count, " " + strAction + " a total of " + rgData.Count.ToString() + " items.");
        reportProgress(0, rgData.Count, " (with rows: " + nH.ToString() + ", cols: " + nW.ToString() + ")");

        // The image copies are only needed for the mean calculation performed when loading into the database.
        List<SimpleDatum> rgImg = (factory != null) ? new List<SimpleDatum>(rgData.Count) : null;
        StreamWriter swFileDesc = null;

        try
        {
            if (m_param.ExportToFile)
            {
                // append = false truncates an existing list so no stale entries from a previous export remain.
                swFileDesc = new StreamWriter(strExportPath + "\\file_list.txt", false);
            }

            Stopwatch sw = new Stopwatch();
            sw.Start();

            for (int i = 0; i < rgData.Count; i++)
            {
                // Throttle progress notifications to roughly one per second.
                if (sw.Elapsed.TotalMilliseconds > 1000)
                {
                    reportProgress(i, rgData.Count, " " + strAction + " data...");
                    sw.Restart();
                }

                datum.SetData(rgData[i].Item1, rgData[i].Item2);

                if (factory != null)
                {
                    factory.PutRawImageCache(i, datum, 5);
                    rgImg.Add(new SimpleDatum(datum));
                }
                else if (strExportPath != null)
                {
                    saveToFile(strExportPath, i, datum, swFileDesc);
                }

                if (m_evtCancel.WaitOne(0))
                    return false;
            }

            if (factory != null)
            {
                factory.ClearImageCache(true);
                factory.UpdateSourceCounts();

                // The abort handle is never signaled; it only satisfies the CalculateMean signature.
                using (ManualResetEvent evtAbort = new ManualResetEvent(false))
                {
                    SimpleDatum sdMean = SimpleDatum.CalculateMean(m_log, rgImg.ToArray(), new WaitHandle[] { evtAbort });
                    factory.SaveImageMean(sdMean, true);
                }
            }
        }
        finally
        {
            // Runs on success, cancel and failure so the list file is always flushed and released.
            if (swFileDesc != null)
                swFileDesc.Dispose();
        }

        reportProgress(rgData.Count, rgData.Count, " " + strAction + " completed.");

        return true;
    }

    /// <summary>
    /// Saves the datum as an image file in 'strPath' and appends "file label" to the list writer (when given).
    /// </summary>
    private void saveToFile(string strPath, int nIdx, Datum d, StreamWriter sw)
    {
        string strFile = strPath.TrimEnd('\\') + "\\" + getImageFileName(nIdx, d);

        // 'using' releases the bitmap even when Save throws.
        using (Bitmap bmp = ImageData.GetImage(d))
        {
            bmp.Save(strFile);
        }

        if (sw != null)
            sw.WriteLine(strFile + " " + d.Label.ToString());
    }

    /// <summary>
    /// Builds the image file name 'img_[index]-[label].png'.
    /// </summary>
    private string getImageFileName(int nIdx, SimpleDatum sd)
    {
        return "img_" + nIdx.ToString() + "-" + sd.Label.ToString() + ".png";
    }

    /// <summary>
    /// Converts a log line into a progress notification, scaling the log progress to a 0..1000 range.
    /// </summary>
    private void Log_OnWriteLine(object sender, LogArg e)
    {
        reportProgress((int)(e.Progress * 1000), 1000, e.Message);
    }

    /// <summary>
    /// Decompresses the GZip file 'strFile' into a sibling '.bin' file unless that file already exists.
    /// </summary>
    /// <param name="strFile">Specifies the compressed file.</param>
    /// <returns>Returns the path of the '.bin' file.</returns>
    private string expandFile(string strFile)
    {
        FileInfo fi = new FileInfo(strFile);

        // Replaces the last extension with '.bin', or appends '.bin' when the name has no extension.
        string strNewFile = Path.ChangeExtension(fi.FullName, ".bin");

        if (!File.Exists(strNewFile))
        {
            try
            {
                using (FileStream fs = fi.OpenRead())
                using (GZipStream decompStrm = new GZipStream(fs, CompressionMode.Decompress))
                using (FileStream fsBin = File.Create(strNewFile))
                {
                    decompStrm.CopyTo(fsBin);
                }
            }
            catch
            {
                // A partially written file would be mistaken for a complete one on the next call
                // (the File.Exists check above), so remove it on a best-effort basis before re-throwing.
                try
                {
                    File.Delete(strNewFile);
                }
                catch (IOException)
                {
                }
                catch (UnauthorizedAccessException)
                {
                }

                throw;
            }
        }

        return strNewFile;
    }

    /// <summary>
    /// Raises the OnProgress event with the given index, total and message.
    /// </summary>
    private void reportProgress(int nIdx, int nTotal, string strMsg)
    {
        EventHandler<ProgressArgs> evtProgress = OnProgress;
        if (evtProgress != null)
            evtProgress(this, new ProgressArgs(new ProgressInfo(nIdx, nTotal, strMsg)));
    }

    /// <summary>
    /// Raises the OnError event with the given index, total and exception.
    /// </summary>
    private void reportError(int nIdx, int nTotal, Exception err)
    {
        EventHandler<ProgressArgs> evtError = OnError;
        if (evtError != null)
            evtError(this, new ProgressArgs(new ProgressInfo(nIdx, nTotal, "ERROR", err)));
    }
}
318}
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
void Reset()
Resets the event clearing any signaled state.
Definition: CancelEvent.cs:279
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
The Datum class is a simple wrapper to the SimpleDatum class to ensure compatibility with the origina...
Definition: Datum.cs:12
The ImageData class is a helper class used to convert between Datum, other raw data,...
Definition: ImageData.cs:14
static Bitmap GetImage(SimpleDatum d, ColorMapper clrMap=null, List< int > rgClrOrder=null)
Converts a SimpleDatum (or Datum) into an image, optionally using a ColorMapper.
Definition: ImageData.cs:506
The LogArg is passed as an argument to the Log::OnWriteLine event.
Definition: EventArgs.cs:53
string Message
Returns the message logged.
Definition: EventArgs.cs:101
The Log class provides general output in text form.
Definition: Log.cs:13
double Progress
Returns the progress value.
Definition: EventArgs.cs:44
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
static SimpleDatum CalculateMean(Log log, SimpleDatum[] rgImg, WaitHandle[] rgAbort)
Calculate the mean of an array of SimpleDatum and return the mean as a new SimpleDatum.
void SetData(SimpleDatum d)
Set the data of the current SimpleDatum by copying the data of another.
int Label
Return the known label of the data.
int ID
Get/set the database ID of the item.
The DatasetDescriptor class describes a dataset which contains both a training data source and testin...
The SourceDescriptor class contains all information describing a data source.
The MnistDataLoader is used to create the MNIST dataset and load it into the database managed by the ...
EventHandler OnCompleted
The OnCompleted event fires once the dataset creation has completed.
EventHandler< ProgressArgs > OnError
The OnError event fires when an error occurs.
MnistDataLoader(MnistDataParameters param, Log log, CancelEvent evtCancel)
The constructor.
EventHandler< ProgressArgs > OnProgress
The OnProgress event fires during the creation process to show the progress.
bool LoadDatabase(int nCreatorID=0)
Create the dataset and load it into the database.
The MnistDataLoader is used to extract the MNIST dataset to disk and load the data into the training p...
int Channels
Return the image channel count (should = 1 for black and white images).
void ExtractImages(out List< Tuple< byte[], int > > rgTrainingData, out List< Tuple< byte[], int > > rgTestingData)
Extract the images from the .bin files and save to disk
int Width
Return the image width.
EventHandler< ProgressArgs > OnProgress
The OnProgress event fires during the creation process to show the progress.
EventHandler< ProgressArgs > OnError
The OnError event fires when an error occurs.
int Height
Return the image height.
Contains the dataset parameters used to create the MNIST dataset.
string TrainImagesFile
Specifies the training image file 'train-images-idx3-ubyte.gz'.
string ExportPath
Specifies where to export the files when 'ExportToFile' = true.
bool ExportToFile
Specifies to export the images to a folder.
Defines the arguments sent to the OnProgress and OnError events.
Definition: ProgressArgs.cs:17
The Database class manages the actual connection to the physical database using Entity Framework from...
Definition: Database.cs:23
FORCE_LOAD
Defines the force load type.
Definition: Database.cs:57
The DatasetFactory manages the connection to the Database object.
void PutRawImageCache(int nIdx, SimpleDatum sd, int nBackgroundWritingThreadCount=0, string strDescription=null, bool bActive=true, params ParameterData[] rgParams)
Add a SimpleDatum to the RawImage cache.
int GetSourceID(string strName)
Returns the ID of a data source given its name.
bool SaveImageMean(SimpleDatum sd, bool bUpdate, int nSrcId=0)
Save the SimpleDatum as a RawImageMean in the database.
void DeleteSourceData(int nSrcId=0)
Delete the data source data (images, means, results and parameters) from the database.
int AddSource(SourceDescriptor src, ConnectInfo ci=null, bool? bSaveImagesToFileOverride=null)
Adds a new data source to the database.
void UpdateSourceCounts()
Saves the label cache, updates the label counts from the database and then updates the source counts ...
int AddDataset(DatasetDescriptor ds, ConnectInfo ci=null, bool? bSaveImagesToFileOverride=null)
Adds or updates the training source, testing source, dataset creator and dataset to the database.
void ClearImageCache(bool bSave)
Clear the RawImage cache and optionally save the images.
SourceDescriptor LoadSource(string strSource)
Load the source descriptor from a data source name.
void UpdateDatasetCounts(int nDsId, ConnectInfo ci=null)
Updates the dataset counts, and training/testing source counts.
void Open(SourceDescriptor src, int nCacheMax=500, ConnectInfo ci=null)
Open a given data source.
The descriptors namespace contains all descriptors used to describe various items stored within the da...
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe.data namespace contains dataset creators used to create common testing datasets such as M...
Definition: BinaryFile.cs:16
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12