// Deep learning software for Windows C# programmers.
// Copyright (c) 2018-2020 SignalPop LLC and contributors. All rights reserved.
// License: Apache 2.0
// License:
// Original Source:
5using System;
6using System.Collections.Generic;
7using System.Diagnostics;
8using System.IO;
9using System.IO.Compression;
10using System.Linq;
11using System.Text;
12using System.Threading.Tasks;
13using MyCaffe.basecode;
14using MyCaffe.db.image;
16using System.Threading;
17using System.Drawing;
/// <summary>
/// Loads the MNIST training and testing images and either stores them in the image database
/// (through a <c>DatasetFactory</c>) or, when <c>MnistDataParameters.ExportToFile</c> is set,
/// writes them as image files plus a <c>file_list.txt</c> index into <c>ExportPath</c>.
/// </summary>
public class MnistDataLoader
{
    MnistDataLoaderLite m_extractor;
    MnistDataParameters m_param;
    Log m_log;
    CancelEvent m_evtCancel;

    /// <summary>
    /// Raised to report progress; progress raised by the internal extractor is forwarded here as well.
    /// </summary>
    public event EventHandler<ProgressArgs> OnProgress;
    /// <summary>
    /// Raised to report errors; errors raised by the internal extractor are forwarded here.
    /// </summary>
    public event EventHandler<ProgressArgs> OnError;
    /// <summary>
    /// Raised when LoadDatabase finishes, whether it succeeds, is cancelled or throws.
    /// </summary>
    public event EventHandler OnCompleted;

    /// <summary>
    /// Constructor.
    /// </summary>
    /// <param name="param">Loader settings; the directory of <c>TrainImagesFile</c> is handed to the extractor.</param>
    /// <param name="log">Log passed to the database when opening a source and when calculating the image mean.</param>
    /// <param name="evtCancel">Cancel event polled while loading; it is reset here.</param>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="param"/> or <paramref name="evtCancel"/> is null.</exception>
    public MnistDataLoader(MnistDataParameters param, Log log, CancelEvent evtCancel)
    {
        if (param == null)
            throw new ArgumentNullException("param");
        if (evtCancel == null)
            throw new ArgumentNullException("evtCancel");

        m_extractor = new MnistDataLoaderLite(Path.GetDirectoryName(param.TrainImagesFile));
        m_extractor.OnProgress += m_extractor_OnProgress;
        m_extractor.OnError += m_extractor_OnError;

        m_param = param;
        m_log = log;
        m_evtCancel = evtCancel;
        m_evtCancel.Reset();
    }

    private void m_extractor_OnError(object sender, ProgressArgs e)
    {
        // Copy to a local so an unsubscribe on another thread cannot null the delegate between check and call.
        EventHandler<ProgressArgs> handler = OnError;
        if (handler != null)
            handler(sender, e);
    }

    private void m_extractor_OnProgress(object sender, ProgressArgs e)
    {
        EventHandler<ProgressArgs> handler = OnProgress;
        if (handler != null)
            handler(sender, e);
    }

    private string dataset_name
    {
        get { return "MNIST"; }
    }

    /// <summary>
    /// Extracts the MNIST images and loads (or exports) the training and testing sets.
    /// </summary>
    /// <param name="nCreatorID">Creator ID recorded on the dataset descriptor when loading into the database.</param>
    /// <returns>True on success; false if loading was cancelled through the cancel event.</returns>
    /// <remarks>
    /// OnCompleted is always raised. Exceptions propagate to the caller with their original stack trace.
    /// </remarks>
    public bool LoadDatabase(int nCreatorID = 0)
    {
        try
        {
            List<Tuple<byte[], int>> rgTrainImg;
            List<Tuple<byte[], int>> rgTestImg;

            m_extractor.ExtractImages(out rgTrainImg, out rgTestImg);

            reportProgress(0, 0, "Loading " + dataset_name + " database...");

            DatasetFactory factory = null;
            string strExportFolder = null;

            if (m_param.ExportToFile)
            {
                strExportFolder = m_param.ExportPath.TrimEnd('\\') + "\\";
                if (!Directory.Exists(strExportFolder))
                    Directory.CreateDirectory(strExportFolder);
            }
            else
            {
                factory = new DatasetFactory();
            }

            string strTrainSrc = prepareSource(factory, "training");
            if (!loadFile(factory, rgTrainImg, m_extractor.Channels, m_extractor.Height, m_extractor.Width, strTrainSrc, strExportFolder))
                return false;

            string strTestSrc = prepareSource(factory, "testing");
            if (!loadFile(factory, rgTestImg, m_extractor.Channels, m_extractor.Height, m_extractor.Width, strTestSrc, strExportFolder))
                return false;

            if (factory != null)
            {
                SourceDescriptor srcTrain = factory.LoadSource(strTrainSrc);
                SourceDescriptor srcTest = factory.LoadSource(strTestSrc);
                DatasetDescriptor ds = new DatasetDescriptor(nCreatorID, dataset_name, null, null, srcTrain, srcTest, dataset_name, dataset_name + " Character Dataset");
                factory.AddDataset(ds);
                factory.UpdateDatasetCounts(ds.ID);
            }

            return true;
        }
        finally
        {
            // No catch block: exceptions propagate unchanged (a 'throw ex;' would reset the stack trace).
            EventHandler handler = OnCompleted;
            if (handler != null)
                handler(this, new EventArgs());
        }
    }

    /// <summary>
    /// Returns the source name for a data split. When loading into the database the name is
    /// prefixed with the dataset name, and any data already stored under that source is deleted.
    /// </summary>
    /// <param name="factory">Database factory, or null when exporting to files.</param>
    /// <param name="strSplit">Split name, e.g. "training" or "testing".</param>
    private string prepareSource(DatasetFactory factory, string strSplit)
    {
        if (factory == null)
            return strSplit;

        string strSrc = dataset_name + "." + strSplit;
        int nSrcId = factory.GetSourceID(strSrc);
        if (nSrcId != 0)
            factory.DeleteSourceData(nSrcId);

        return strSrc;
    }

    /// <summary>
    /// Stores one split of images either into the database source <paramref name="strSourceName"/>
    /// or as image files under <paramref name="strExportPath"/> + <paramref name="strSourceName"/>.
    /// </summary>
    /// <param name="factory">Database factory, or null when exporting to files.</param>
    /// <param name="rgData">Items of (pixel bytes, label).</param>
    /// <param name="nC">Channel count.</param>
    /// <param name="nH">Image height.</param>
    /// <param name="nW">Image width.</param>
    /// <param name="strSourceName">Source (database) or sub-folder (export) name.</param>
    /// <param name="strExportPath">Export root folder ending in a separator, or null when loading into the database.</param>
    /// <returns>True when all items were processed; false when cancelled.</returns>
    private bool loadFile(DatasetFactory factory, List<Tuple<byte[], int>> rgData, int nC, int nH, int nW, string strSourceName, string strExportPath)
    {
        if (strExportPath != null)
        {
            strExportPath += strSourceName;

            if (!Directory.Exists(strExportPath))
                Directory.CreateDirectory(strExportPath);
        }

        reportProgress(0, 0, " Source: " + strSourceName);

        if (factory != null)
        {
            int nSrcId = factory.AddSource(strSourceName, nC, nW, nH, false, 0, true);

            factory.Open(nSrcId, 500, Database.FORCE_LOAD.NONE, m_log);
            factory.DeleteSourceData();
        }

        Datum datum = new Datum(false, nC, nW, nH, -1, DateTime.MinValue, new List<byte>(), 0, false, -1);
        string strAction = (m_param.ExportToFile) ? "exporting" : "loading";

        reportProgress(0, rgData.Count, " " + strAction + " a total of " + rgData.Count.ToString() + " items.");
        reportProgress(0, rgData.Count, " (with rows: " + nH.ToString() + ", cols: " + nW.ToString() + ")");

        // The image copies are only needed to compute the database image mean, so skip them when exporting.
        List<SimpleDatum> rgImg = (factory != null) ? new List<SimpleDatum>(rgData.Count) : null;

        StreamWriter swFileDesc = null;
        try
        {
            if (m_param.ExportToFile)
            {
                // 'false' = overwrite: File.OpenWrite would leave stale trailing lines from a longer previous list.
                swFileDesc = new StreamWriter(strExportPath + "\\file_list.txt", false);
            }

            Stopwatch sw = Stopwatch.StartNew();

            for (int i = 0; i < rgData.Count; i++)
            {
                byte[] rgPixels = rgData[i].Item1;
                int nLabel = rgData[i].Item2;

                // Throttle progress reports to roughly one per second.
                if (sw.Elapsed.TotalMilliseconds > 1000)
                {
                    reportProgress(i, rgData.Count, " " + strAction + " data...");
                    sw.Restart();
                }

                datum.SetData(rgPixels, nLabel);

                if (factory != null)
                {
                    factory.PutRawImageCache(i, datum, 5);
                    rgImg.Add(new SimpleDatum(datum));
                }
                else if (strExportPath != null)
                {
                    saveToFile(strExportPath, i, datum, swFileDesc);
                }

                if (m_evtCancel.WaitOne(0))
                    return false;
            }
        }
        finally
        {
            // Closes the file list on success, cancellation and exceptions alike (Dispose also flushes).
            if (swFileDesc != null)
                swFileDesc.Dispose();
        }

        if (factory != null)
        {
            factory.ClearImageCache(true);
            factory.UpdateSourceCounts();

            using (ManualResetEvent evtAbort = new ManualResetEvent(false))
            {
                factory.SaveImageMean(SimpleDatum.CalculateMean(m_log, rgImg.ToArray(), new WaitHandle[] { evtAbort }), true);
            }
        }

        reportProgress(rgData.Count, rgData.Count, " " + strAction + " completed.");

        return true;
    }

    /// <summary>
    /// Writes one datum as an image file into <paramref name="strPath"/> and, when a writer is
    /// given, appends "full-file-path label" to it.
    /// </summary>
    private void saveToFile(string strPath, int nIdx, Datum d, StreamWriter sw)
    {
        string strFile = strPath.TrimEnd('\\') + "\\" + getImageFileName(nIdx, d);

        using (Bitmap bmp = ImageData.GetImage(d))
        {
            bmp.Save(strFile);
        }

        if (sw != null)
            sw.WriteLine(strFile + " " + d.Label.ToString());
    }

    /// <summary>
    /// Builds the export file name "img_{index}-{label}.png".
    /// </summary>
    private string getImageFileName(int nIdx, SimpleDatum sd)
    {
        return "img_" + nIdx.ToString() + "-" + sd.Label.ToString() + ".png";
    }

    /// <summary>
    /// Forwards a log line as a progress report, scaling e.Progress onto a 0..1000 range.
    /// </summary>
    private void Log_OnWriteLine(object sender, LogArg e)
    {
        reportProgress((int)(e.Progress * 1000), 1000, e.Message);
    }

    /// <summary>
    /// GZip-decompresses <paramref name="strFile"/> into a sibling file with the same name and a
    /// ".bin" extension, unless that file already exists.
    /// </summary>
    /// <returns>The path of the ".bin" file.</returns>
    private string expandFile(string strFile)
    {
        FileInfo fi = new FileInfo(strFile);
        string strNewFile = Path.Combine(fi.DirectoryName, Path.GetFileNameWithoutExtension(fi.Name) + ".bin");

        if (!File.Exists(strNewFile))
        {
            bool bCreated = false;

            try
            {
                using (FileStream fs = fi.OpenRead())
                using (GZipStream decompStrm = new GZipStream(fs, CompressionMode.Decompress))
                using (FileStream fsBin = File.Create(strNewFile))
                {
                    bCreated = true;
                    decompStrm.CopyTo(fsBin);
                }
            }
            catch
            {
                // A partially written .bin would be treated as complete by the File.Exists check on
                // the next call, so remove it (best effort) before rethrowing the original error.
                if (bCreated)
                {
                    try
                    {
                        File.Delete(strNewFile);
                    }
                    catch (IOException)
                    {
                    }
                    catch (UnauthorizedAccessException)
                    {
                    }
                }

                throw;
            }
        }

        return strNewFile;
    }

    private void reportProgress(int nIdx, int nTotal, string strMsg)
    {
        EventHandler<ProgressArgs> handler = OnProgress;
        if (handler != null)
            handler(this, new ProgressArgs(new ProgressInfo(nIdx, nTotal, strMsg)));
    }

    private void reportError(int nIdx, int nTotal, Exception err)
    {
        EventHandler<ProgressArgs> handler = OnError;
        if (handler != null)
            handler(this, new ProgressArgs(new ProgressInfo(nIdx, nTotal, "ERROR", err)));
    }
}
