MyCaffe  1.12.2.41
Deep learning software for Windows C# programmers.
BeamSearch.cs
1using MyCaffe.basecode;
2using MyCaffe.layers;
3using System;
4using System.Collections.Generic;
5using System.Diagnostics;
6using System.Linq;
7using System.Text;
8using System.Threading.Tasks;
9
namespace MyCaffe.common
{
    /// <summary>
    /// The BeamSearch uses the softmax output from the network and continually runs the net on each
    /// output, keeping the <i>nK</i> lowest-scoring (most probable) sequences at every step.
    /// </summary>
    /// <typeparam name="T">Specifies the base data type used by the net.</typeparam>
    public class BeamSearch<T>
    {
        private readonly Net<T> m_net;
        // The layer used both to build the bottom blobs (PreProcessInput) and to decode the
        // net output (PostProcessOutput).
        private readonly Layer<T> m_layer = null;

        /// <summary>
        /// The constructor.
        /// </summary>
        /// <param name="net">Specifies the net to search.  At least one of its layers must support
        /// both pre-processing and post-processing.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="net"/> is null.</exception>
        /// <exception cref="Exception">Thrown when no layer supports both pre- and post-processing.</exception>
        public BeamSearch(Net<T> net)
        {
            if (net == null)
                throw new ArgumentNullException("net");

            m_net = net;

            // Select the first layer able to convert the custom input into bottom blobs AND
            // convert the output back into results - the search needs both directions.
            foreach (Layer<T> layer1 in m_net.layers)
            {
                if (layer1.SupportsPreProcessing && layer1.SupportsPostProcessing)
                {
                    m_layer = layer1;
                    break;
                }
            }

            if (m_layer == null)
                throw new Exception("At least one layer in the net must support pre and post processing!");
        }

        /// <summary>
        /// Perform the beam-search.
        /// </summary>
        /// <param name="input">Specifies the custom input passed to the pre-processing layer.  Its
        /// 'InputData' property is re-used when replaying each sequence.</param>
        /// <param name="nK">Specifies the beam width: the number of results requested from the
        /// post-processing at each step and the number of sequences kept after each step.</param>
        /// <param name="dfThreshold">Specifies the minimum value of a result's third item
        /// (used as a probability via <c>Math.Log</c>) for it to be considered.</param>
        /// <param name="nMax">Specifies the maximum number of expansion steps.</param>
        /// <returns>A list of (score, done, sequence) tuples.  The score is the sum of the negative
        /// log values of the sequence's items, so lower is better.  'done' is true when the
        /// sequence had no further (non end-of-sequence) candidates.  Each sequence item is the
        /// (string, index, value) tuple produced by the post-processing layer.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="input"/> is null.</exception>
        /// <remarks>NOTE(review): <c>input.GetProperty("InputData")</c> is called with its default
        /// exception behavior, so a missing 'InputData' property presumably throws - confirm.</remarks>
        public List<Tuple<double, bool, List<Tuple<string, int, double>>>> Search(PropertySet input, int nK, double dfThreshold = 0.01, int nMax = 80)
        {
            if (input == null)
                throw new ArgumentNullException("input");

            List<Tuple<double, bool, List<Tuple<string, int, double>>>> rgSequences = new List<Tuple<double, bool, List<Tuple<string, int, double>>>>();
            rgSequences.Add(new Tuple<double, bool, List<Tuple<string, int, double>>>(0, false, new List<Tuple<string, int, double>>()));

            int nSeqLen;
            double dfLoss;

            BlobCollection<T> colBottom = m_layer.PreProcessInput(input, out nSeqLen, null);
            string strInput = input.GetProperty("InputData");
            BlobCollection<T> colTop = m_net.Forward(colBottom, out dfLoss);

            // Invariant: rgrgRes[i] holds the candidate next items for rgSequences[i].
            List<List<Tuple<string, int, double>>> rgrgRes = new List<List<Tuple<string, int, double>>>();
            rgrgRes.Add(getNextItems(colTop, nK, dfThreshold));

            for (int nStep = 0; nStep < nMax; nStep++)
            {
                List<Tuple<double, bool, List<Tuple<string, int, double>>>> rgCandidates = new List<Tuple<double, bool, List<Tuple<string, int, double>>>>();
                int nProcessedCount = 0;

                for (int i = 0; i < rgSequences.Count; i++)
                {
                    Tuple<double, bool, List<Tuple<string, int, double>>> seq = rgSequences[i];
                    List<Tuple<string, int, double>> rgNext = rgrgRes[i];

                    // No candidates left: the sequence is finished, carry it forward unchanged.
                    if (rgNext.Count == 0)
                    {
                        rgCandidates.Add(new Tuple<double, bool, List<Tuple<string, int, double>>>(seq.Item1, true, seq.Item3));
                        continue;
                    }

                    foreach (Tuple<string, int, double> next in rgNext)
                    {
                        // Scores accumulate negative log values, so lower scores rank first below.
                        double dfScore = seq.Item1 - Math.Log(next.Item3);

                        List<Tuple<string, int, double>> rgSequence1 = new List<Tuple<string, int, double>>(seq.Item3);
                        rgSequence1.Add(next);

                        rgCandidates.Add(new Tuple<double, bool, List<Tuple<string, int, double>>>(dfScore, false, rgSequence1));
                        nProcessedCount++;
                    }
                }

                // Nothing was extended - every sequence is finished.
                if (nProcessedCount == 0)
                    break;

                rgSequences = rgCandidates.OrderBy(p => p.Item1).Take(nK).ToList();
                rgrgRes = new List<List<Tuple<string, int, double>>>();

                // Exactly one entry is added per sequence so the index invariant always holds.
                foreach (Tuple<double, bool, List<Tuple<string, int, double>>> seq in rgSequences)
                {
                    if (seq.Item2)
                        rgrgRes.Add(new List<Tuple<string, int, double>>());
                    else
                        rgrgRes.Add(getNextItems(replaySequence(strInput, colBottom, seq.Item3), nK, dfThreshold));
                }
            }

            return rgSequences;
        }

        /// <summary>
        /// Resets the net state, then re-runs each item of the sequence through the net so the
        /// returned top blobs correspond to the end (leaf) of that sequence.
        /// </summary>
        /// <param name="strInput">Specifies the 'InputData' text of the original input.</param>
        /// <param name="colBottom">Specifies the bottom blobs re-filled by the layer at each step.</param>
        /// <param name="rgSequence">Specifies the sequence whose item indexes are replayed.</param>
        /// <returns>The top blobs of the last forward pass.</returns>
        private BlobCollection<T> replaySequence(string strInput, BlobCollection<T> colBottom, List<Tuple<string, int, double>> rgSequence)
        {
            double dfLoss;

            // Reset state.  NOTE(review): index 1 presumably denotes a start-of-sequence item - confirm
            // against the layer's PreProcessInput implementation.
            m_layer.PreProcessInput(strInput, 1, colBottom);
            BlobCollection<T> colTop = m_net.Forward(colBottom, out dfLoss, true);

            foreach (Tuple<string, int, double> item in rgSequence)
            {
                m_layer.PreProcessInput(strInput, item.Item2, colBottom);
                colTop = m_net.Forward(colBottom, out dfLoss, true);
            }

            return colTop;
        }

        /// <summary>
        /// Post-processes the first top blob into at most <paramref name="nK"/> results, keeping only
        /// those whose value reaches <paramref name="dfThreshold"/> and that are not end-of-sequence
        /// markers (empty strings).
        /// </summary>
        /// <param name="colTop">Specifies the top blobs of a forward pass.</param>
        /// <param name="nK">Specifies the number of results requested from the post-processing.</param>
        /// <param name="dfThreshold">Specifies the minimum value a result must have.</param>
        /// <returns>The filtered candidate next items.</returns>
        private List<Tuple<string, int, double>> getNextItems(BlobCollection<T> colTop, int nK, double dfThreshold)
        {
            List<Tuple<string, int, double>> rgRes = new List<Tuple<string, int, double>>();

            foreach (Tuple<string, int, double> res in m_layer.PostProcessOutput(colTop[0], nK))
            {
                // Written as !(>=) so NaN values are rejected, matching a 'value >= threshold' filter.
                if (!(res.Item3 >= dfThreshold))
                    continue;

                if (string.IsNullOrEmpty(res.Item1))
                    Trace.WriteLine("EOS");
                else
                    rgRes.Add(res);
            }

            return rgRes;
        }
    }
}
Specifies a key-value pair of properties.
Definition: PropertySet.cs:16
string GetProperty(string strName, bool bThrowExceptions=true)
Returns a property as a string value.
Definition: PropertySet.cs:146
The BeamSearch uses the softmax output from the network and continually runs the net on each output (...
Definition: BeamSearch.cs:19
List< Tuple< double, bool, List< Tuple< string, int, double > > > > Search(PropertySet input, int nK, double dfThreshold=0.01, int nMax=80)
Perform the beam-search.
Definition: BeamSearch.cs:56
BeamSearch(Net< T > net)
The constructor.
Definition: BeamSearch.cs:27
The BlobCollection contains a list of Blobs.
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
An interface for the units of computation which can be composed into a Net.
Definition: Layer.cs:31
virtual bool SupportsPostProcessing
Should return true when the PostProcessing methods are overridden.
Definition: Layer.cs:264
virtual List< Tuple< string, int, double > > PostProcessOutput(Blob< T > blobSofmtax, int nK=1)
The PostProcessOutput allows derived data layers to post-process the results, converting them back...
Definition: Layer.cs:328
virtual bool SupportsPreProcessing
Should return true when the PreProcessing methods are overridden.
Definition: Layer.cs:256
virtual BlobCollection< T > PreProcessInput(PropertySet customInput, out int nSeqLen, BlobCollection< T > colBottom=null)
The PreProcessInput allows derived data layers to convert a property set of input data into the bo...
Definition: Layer.cs:294
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
The MyCaffe.layers namespace contains all layers that have a solidified code base,...
Definition: LayerFactory.cs:15
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12