MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
MemoryCollection.cs
1using MyCaffe.basecode;
2using System;
3using System.Collections.Generic;
4using System.IO;
5using System.Linq;
6using System.Text;
7using System.Threading.Tasks;
8
10{
/// <summary>
/// The memory collection stores a set of memory items in a fixed-capacity, rolling array.
/// </summary>
/// <remarks>
/// Items are written at <see cref="NextIndex"/>, which wraps back to 0 once the capacity is reached, so
/// when the collection is full each new item overwrites the slot written longest ago.  Until the
/// collection has filled, only the first <see cref="Count"/> slots hold items - the remaining slots are null.
/// </remarks>
public class MemoryCollection
{
    /// <summary>
    /// Number of state values stored per datum in the files written by Save and read by Load.
    /// </summary>
    const int c_nDataItemCount = 4;

    int m_nCount = 0;
    int[] m_rgIdx = null;
    double[] m_rgfPriorities = null;
    /// <summary>
    /// Specifies the memory item list.
    /// </summary>
    protected MemoryItem[] m_rgItems;
    /// <summary>
    /// Specifies the next available index in the rolling list.
    /// </summary>
    protected int m_nNextIdx = 0;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="nMax">Specifies the capacity (maximum number of items) of the collection.</param>
    public MemoryCollection(int nMax)
    {
        m_rgItems = new MemoryItem[nMax];
    }

    /// <summary>
    /// Get/set the indexes associated with the collection (if any).
    /// </summary>
    public int[] Indexes
    {
        get { return m_rgIdx; }
        set { m_rgIdx = value; }
    }

    /// <summary>
    /// Get/set the priorities associated with the collection (if any).
    /// </summary>
    public double[] Priorities
    {
        get { return m_rgfPriorities; }
        set { m_rgfPriorities = value; }
    }

    /// <summary>
    /// Returns the next index (the slot that the next call to Add will write).
    /// </summary>
    public int NextIndex
    {
        get { return m_nNextIdx; }
    }

    /// <summary>
    /// Returns the current count of items, which never exceeds the capacity.
    /// </summary>
    public int Count
    {
        get { return m_nCount; }
    }

    /// <summary>
    /// Returns the item stored in a given slot of the backing array.
    /// </summary>
    /// <param name="nIdx">Specifies the slot index.</param>
    /// <returns>The item in the slot, which is null for a slot that has not been written yet.</returns>
    public MemoryItem this[int nIdx]
    {
        get { return m_rgItems[nIdx]; }
    }

    /// <summary>
    /// Adds a new memory item to the array of items and, if at capacity, overwrites an existing item.
    /// </summary>
    /// <param name="item">Specifies the memory item to add.</param>
    public virtual void Add(MemoryItem item)
    {
        m_rgItems[m_nNextIdx] = item;
        m_nNextIdx++;

        // Wrap around so that the next write re-uses the first slot.
        if (m_nNextIdx == m_rgItems.Length)
            m_nNextIdx = 0;

        if (m_nCount < m_rgItems.Length)
            m_nCount++;
    }

    /// <summary>
    /// Retrieves a random sample of items from the list.
    /// </summary>
    /// <remarks>
    /// Each sample index is drawn independently, so the same item may be selected more than once.
    /// NOTE(review): assumes random.Next(Count) returns a value in [0, Count) - confirm against CryptoRandom.
    /// </remarks>
    /// <param name="random">Specifies the random number generator to use.</param>
    /// <param name="nCount">Specifies the number of items to sample.</param>
    /// <returns>This collection itself when nCount is greater than or equal to Count, otherwise a new collection holding nCount sampled items.</returns>
    public MemoryCollection GetRandomSamples(CryptoRandom random, int nCount)
    {
        if (nCount >= Count)
            return this;

        MemoryCollection col = new MemoryCollection(nCount);

        while (col.Count < nCount)
        {
            int nIdx = random.Next(Count);
            col.Add(m_rgItems[nIdx]);
        }

        return col;
    }

    /// <summary>
    /// Returns the list of Next State items.
    /// </summary>
    /// <returns>The next state of each stored item.</returns>
    public List<StateBase> GetNextState()
    {
        return storedItems().Select(p => p.NextState).ToList();
    }

    /// <summary>
    /// Returns the list of data items associated with the next state.
    /// </summary>
    /// <returns>The next data of each stored item.</returns>
    public List<SimpleDatum> GetNextStateData()
    {
        return storedItems().Select(p => p.NextData).ToList();
    }

    /// <summary>
    /// Returns the list of clip items associated with the next state.
    /// </summary>
    /// <returns>The clip of each stored item's next state, or null when the collection is empty or the first item has no next state clip.</returns>
    public List<SimpleDatum> GetNextStateClip()
    {
        if (m_nCount == 0)
            return null;

        // The first item decides whether clip data is available for the collection.
        if (m_rgItems[0].NextState != null && m_rgItems[0].NextState.Clip != null)
            return storedItems().Select(p => p.NextState.Clip).ToList();

        return null;
    }

    /// <summary>
    /// Returns the list of data items associated with the current state.
    /// </summary>
    /// <returns>The current data of each stored item.</returns>
    public List<SimpleDatum> GetCurrentStateData()
    {
        return storedItems().Select(p => p.CurrentData).ToList();
    }

    /// <summary>
    /// Returns the list of clip items associated with the current state.
    /// </summary>
    /// <returns>The clip of each stored item's current state, or null when the collection is empty or the first item has no current state clip.</returns>
    public List<SimpleDatum> GetCurrentStateClip()
    {
        if (m_nCount == 0)
            return null;

        // The first item decides whether clip data is available for the collection.
        if (m_rgItems[0].CurrentState != null && m_rgItems[0].CurrentState.Clip != null)
            return storedItems().Select(p => p.CurrentState.Clip).ToList();

        return null;
    }

    /// <summary>
    /// Returns the action items as a set of one-hot vectors.
    /// </summary>
    /// <param name="nActionCount">Specifies the number of actions, which is the length of each one-hot vector.</param>
    /// <returns>A flat array of Count * nActionCount values where, for item i, the element at (i * nActionCount) + action is 1 and all others are 0.</returns>
    public float[] GetActionsAsOneHotVector(int nActionCount)
    {
        float[] rg = new float[m_nCount * nActionCount];

        for (int i = 0; i < m_nCount; i++)
        {
            int nAction = m_rgItems[i].Action;

            for (int j = 0; j < nActionCount; j++)
            {
                rg[(i * nActionCount) + j] = (j == nAction) ? 1 : 0;
            }
        }

        return rg;
    }

    /// <summary>
    /// Returns the inverted done (1 - done) values as a one-hot vector.
    /// </summary>
    /// <returns>An array with one value per stored item: 0 when the item is terminated, otherwise 1.</returns>
    public float[] GetInvertedDoneAsOneHotVector()
    {
        float[] rgDoneInv = new float[m_nCount];

        for (int i = 0; i < m_nCount; i++)
        {
            if (m_rgItems[i].IsTerminated)
                rgDoneInv[i] = 0;
            else
                rgDoneInv[i] = 1;
        }

        return rgDoneInv;
    }

    /// <summary>
    /// Returns the rewards as a vector.
    /// </summary>
    /// <returns>The reward of each stored item, cast to float.</returns>
    public float[] GetRewards()
    {
        return storedItems().Select(p => (float)p.Reward).ToArray();
    }

    /// <summary>
    /// Save the memory items to file.
    /// </summary>
    /// <remarks>
    /// One comma separated line is written per item: the current data values, the action, the reward,
    /// the terminated flag (1 or 0) and the next data values.  Any existing file is replaced.
    /// </remarks>
    /// <param name="strFile">Specifies the file to write.</param>
    public void Save(string strFile)
    {
        if (File.Exists(strFile))
            File.Delete(strFile);

        using (StreamWriter sw = new StreamWriter(strFile))
        {
            for (int i = 0; i < Count; i++)
            {
                MemoryItem mi = m_rgItems[i];
                // The reward is written with the invariant culture so that a culture using ',' as the
                // decimal separator cannot split the value into two comma separated fields.
                string strReward = mi.Reward.ToString(System.Globalization.CultureInfo.InvariantCulture);
                string strLine = mi.CurrentData.ToArrayAsString(c_nDataItemCount) + "," + mi.Action.ToString() + "," + strReward + "," + (mi.IsTerminated ? 1 : 0).ToString() + "," + mi.NextData.ToArrayAsString(c_nDataItemCount);
                sw.WriteLine(strLine);
            }
        }
    }

    /// <summary>
    /// Load all memory items from file.
    /// </summary>
    /// <remarks>
    /// The file is parsed completely before the collection is changed, so a file that fails to parse
    /// leaves the existing contents untouched.  On success the existing contents are discarded and the
    /// loaded items are added in file order; when the file holds more items than the capacity, the
    /// earliest ones are overwritten by the rolling Add.  Blank lines are skipped.
    /// </remarks>
    /// <param name="strFile">Specifies the file to read, in the format written by Save.</param>
    /// <exception cref="InvalidDataException">Thrown when a line contains fewer fields than expected.</exception>
    public void Load(string strFile)
    {
        // current data + action, reward, terminated + next data.
        const int nFieldCount = (2 * c_nDataItemCount) + 3;
        List<MemoryItem> rg = new List<MemoryItem>();

        using (StreamReader sr = new StreamReader(strFile))
        {
            string strLine;
            int nLine = 0;

            while ((strLine = sr.ReadLine()) != null)
            {
                nLine++;

                if (strLine.Trim().Length == 0)
                    continue;

                string[] rgstr = strLine.Split(',');

                if (rgstr.Length < nFieldCount)
                    throw new InvalidDataException("Line " + nLine.ToString() + " of '" + strFile + "' has " + rgstr.Length.ToString() + " fields, but " + nFieldCount.ToString() + " are expected.");

                int nIdx = 0;
                SimpleDatum sdCurrent = parseDatum(rgstr, ref nIdx);

                int nAction = int.Parse(rgstr[nIdx]); nIdx++;
                double dfReward = BaseParameter.ParseDouble(rgstr[nIdx]); nIdx++;
                bool bTerminated = (int.Parse(rgstr[nIdx]) == 1); nIdx++;

                SimpleDatum sdNext = parseDatum(rgstr, ref nIdx);

                rg.Add(new MemoryItem(null, sdCurrent, nAction, null, sdNext, dfReward, bTerminated, 0, 0));
            }
        }

        // Only reset once parsing has succeeded, and drop the old items so none remain past the new count.
        Array.Clear(m_rgItems, 0, m_rgItems.Length);
        m_nNextIdx = 0;
        m_nCount = 0;

        foreach (MemoryItem m in rg)
        {
            Add(m);
        }
    }

    /// <summary>
    /// Returns the items that have actually been added, skipping the unwritten (null) slots.
    /// </summary>
    private IEnumerable<MemoryItem> storedItems()
    {
        return m_rgItems.Take(m_nCount);
    }

    /// <summary>
    /// Parses the next group of state values from a split line into a SimpleDatum, advancing the field index.
    /// </summary>
    private static SimpleDatum parseDatum(string[] rgstr, ref int nIdx)
    {
        List<double> rgdfData = new List<double>(c_nDataItemCount);

        for (int i = 0; i < c_nDataItemCount; i++)
        {
            rgdfData.Add(BaseParameter.ParseDouble(rgstr[nIdx]));
            nIdx++;
        }

        return new SimpleDatum(true, c_nDataItemCount, 1, 1, -1, DateTime.MinValue, rgdfData, 0, false, -1);
    }
}
289
/// <summary>
/// The MemoryItem stores the information about a given cycle: the current state and data, the action
/// taken, the resulting next state and data, the reward, and whether the cycle terminated.
/// </summary>
public class MemoryItem
{
    StateBase m_state0;
    StateBase m_state1;
    SimpleDatum m_x0;
    SimpleDatum m_x1;
    int m_nAction;
    int m_nIteration;
    int m_nEpisode;
    bool m_bTerminated;
    double m_dfReward;

    /// <summary>
    /// The constructor.
    /// </summary>
    /// <param name="currentState">Specifies the current state.</param>
    /// <param name="currentData">Specifies the data associated with the current state.</param>
    /// <param name="nAction">Specifies the action.</param>
    /// <param name="nextState">Specifies the next state.</param>
    /// <param name="nextData">Specifies the data associated with the next state.</param>
    /// <param name="dfReward">Specifies the reward of the state transition.</param>
    /// <param name="bTerminated">Specifies the termination status of the next state.</param>
    /// <param name="nIteration">Specifies the iteration of the state transition.</param>
    /// <param name="nEpisode">Specifies the episode of the state transition.</param>
    public MemoryItem(StateBase currentState, SimpleDatum currentData, int nAction, StateBase nextState, SimpleDatum nextData, double dfReward, bool bTerminated, int nIteration, int nEpisode)
    {
        m_state0 = currentState;
        m_state1 = nextState;
        m_x0 = currentData;
        m_x1 = nextData;
        m_nAction = nAction;
        m_bTerminated = bTerminated;
        m_dfReward = dfReward;
        m_nIteration = nIteration;
        m_nEpisode = nEpisode;
    }

    /// <summary>
    /// Returns the termination status of the next state.
    /// </summary>
    public bool IsTerminated
    {
        get { return m_bTerminated; }
    }

    /// <summary>
    /// Get/set the reward of the state transition.
    /// </summary>
    public double Reward
    {
        get { return m_dfReward; }
        set { m_dfReward = value; }
    }

    /// <summary>
    /// Returns the current state.
    /// </summary>
    public StateBase CurrentState
    {
        get { return m_state0; }
    }

    /// <summary>
    /// Returns the next state.
    /// </summary>
    public StateBase NextState
    {
        get { return m_state1; }
    }

    /// <summary>
    /// Returns the data associated with the current state.
    /// </summary>
    public SimpleDatum CurrentData
    {
        get { return m_x0; }
    }

    /// <summary>
    /// Returns the data associated with the next state.
    /// </summary>
    public SimpleDatum NextData
    {
        get { return m_x1; }
    }

    /// <summary>
    /// Returns the action.
    /// </summary>
    public int Action
    {
        get { return m_nAction; }
    }

    /// <summary>
    /// Returns the iteration of the state transition.
    /// </summary>
    public int Iteration
    {
        get { return m_nIteration; }
    }

    /// <summary>
    /// Returns the episode of the state transition.
    /// </summary>
    public int Episode
    {
        get { return m_nEpisode; }
    }

    /// <summary>
    /// Returns a string representation of the state transition.
    /// </summary>
    /// <returns>The episode, action and reward (formatted with two decimals).</returns>
    public override string ToString()
    {
        return "episode = " + m_nEpisode.ToString() + " action = " + m_nAction.ToString() + " reward = " + m_dfReward.ToString("N2");
    }

    /// <summary>
    /// Formats an array of values as a brace enclosed, comma separated list with five decimals each.
    /// </summary>
    private string tostring(float[] rg)
    {
        // Joining once avoids re-allocating the growing string for every element.
        return "{" + string.Join(",", rg.Select(f => f.ToString("N5"))) + "}";
    }
}
428}
The BaseParameter class is the base class for all other parameter classes.
static double ParseDouble(string strVal)
Parse double values using the US culture if the decimal separator = '.', then using the native cultur...
The CryptoRandom is a random number generator that can use either the standard .Net Random object or t...
Definition: CryptoRandom.cs:14
int Next(int nMinVal, int nMaxVal, bool bMaxInclusive=true)
Returns a random int within the range
The SimpleDatum class holds a data input within host memory.
Definition: SimpleDatum.cs:161
string ToArrayAsString(int nMaxItems)
Returns a string containing the items of the SimpleDatum.
The StateBase is the base class for the state of each observation - this is defined by actual trainer...
Definition: StateBase.cs:16
The memory collection stores a set of memory items.
int m_nNextIdx
Specifies the next available index in the rolling list.
float[] GetInvertedDoneAsOneHotVector()
Returns the inverted done (1 - done) values as a one-hot vector.
void Load(string strFile)
Load all memory items from file.
List< SimpleDatum > GetNextStateClip()
Returns the list of clip items associated with the next state.
MemoryCollection GetRandomSamples(CryptoRandom random, int nCount)
Retrieves a random sample of items from the list.
double[] Priorities
Get/set the priorities associated with the collection (if any).
List< StateBase > GetNextState()
Returns the list of Next State items.
int Count
Returns the current count of items.
int NextIndex
Returns the next index.
int[] Indexes
Get/set the indexes associated with the collection (if any).
void Save(string strFile)
Save the memory items to file.
MemoryCollection(int nMax)
The constructor.
List< SimpleDatum > GetCurrentStateData()
Returns the list of data items associated with the current state.
float[] GetActionsAsOneHotVector(int nActionCount)
Returns the action items as a set of one-hot vectors.
List< SimpleDatum > GetCurrentStateClip()
Returns the list of clip items associated with the current state.
virtual void Add(MemoryItem item)
Adds a new memory item to the array of items and if at capacity, removes an item.
MemoryItem[] m_rgItems
Specifies the memory item list.
float[] GetRewards()
Returns the rewards as a vector.
List< SimpleDatum > GetNextStateData()
Returns the list of data items associated with the next state.
The MemoryItem stores the information about a given cycle.
override string ToString()
Returns a string representation of the state transition.
StateBase CurrentState
Returns the current state.
int Iteration
Returns the iteration of the state transition.
MemoryItem(StateBase currentState, SimpleDatum currentData, int nAction, StateBase nextState, SimpleDatum nextData, double dfReward, bool bTerminated, int nIteration, int nEpisode)
The constructor.
double Reward
Returns the reward of the state transition.
bool IsTerminated
Returns the termination status of the next state.
int Episode
Returns the episode of the state transition.
StateBase NextState
Returns the next state.
SimpleDatum NextData
Returns the data associated with the next state.
SimpleDatum CurrentData
Returns the data associated with the current state.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12