00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkWeightSetBase_h
00019 #define __itkWeightSetBase_h
00020
00021 #include "itkLayerBase.h"
00022 #include "itkLightProcessObject.h"
00023 #include <vnl/vnl_matrix.h>
00024 #include <vnl/vnl_diag_matrix.h>
00025 #include "itkMacro.h"
00026 #include "itkVector.h"
00027 #include "itkMersenneTwisterRandomVariateGenerator.h"
00028 #include <math.h>
00029 #include <stdlib.h>
00030
00031 namespace itk
00032 {
00033 namespace Statistics
00034 {
00035
00036 template<class TVector, class TOutput>
00037 class WeightSetBase : public LightProcessObject
00038 {
00039 public:
00040
00041 typedef WeightSetBase Self;
00042 typedef LightProcessObject Superclass;
00043 typedef SmartPointer<Self> Pointer;
00044 typedef SmartPointer<const Self> ConstPointer;
00045 itkTypeMacro(WeightSetBase, LightProcessObject);
00046
00047 typedef MersenneTwisterRandomVariateGenerator RandomVariateGeneratorType;
00048
00049 typedef typename TVector::ValueType ValueType;
00050 typedef ValueType* ValuePointer;
00051
00052 void Initialize();
00053
00054 ValueType RandomWeightValue(ValueType low, ValueType high);
00055
00056 void ForwardPropagate(ValuePointer inputlayeroutputvalues);
00057
00058 void BackwardPropagate(ValuePointer inputerror);
00059
00060 void SetConnectivityMatrix(vnl_matrix < int>);
00061
00062 void SetNumberOfInputNodes(unsigned int n);
00063 unsigned int GetNumberOfInputNodes();
00064
00065 void SetNumberOfOutputNodes(unsigned int n);
00066 unsigned int GetNumberOfOutputNodes();
00067
00068 void SetRange(ValueType Range);
00069
00070 ValuePointer GetOutputValues();
00071
00072 ValuePointer GetInputValues();
00073
00074 ValuePointer GetTotalDeltaValues();
00075
00076 ValuePointer GetTotalDeltaBValues();
00077
00078 ValuePointer GetDeltaValues();
00079
00080 void SetDeltaValues(ValuePointer);
00081
00082 void SetDWValues(ValuePointer);
00083
00084 void SetDBValues(ValuePointer);
00085
00086 ValuePointer GetDeltaBValues();
00087
00088 void SetDeltaBValues(ValuePointer);
00089
00090 ValuePointer GetDWValues();
00091
00092 ValuePointer GetPrevDWValues();
00093
00094 ValuePointer GetPrevDBValues();
00095
00096 ValuePointer GetPrev_m_2DWValues();
00097
00098 ValuePointer GetPrevDeltaValues();
00099
00100 ValuePointer GetPrev_m_2DeltaValues();
00101
00102 ValuePointer GetPrevDeltaBValues();
00103
00104 ValuePointer GetWeightValues();
00105
00106
00107 void SetWeightValues(ValuePointer weights);
00108
00109 void UpdateWeights(ValueType LearningRate);
00110
00111 void SetMomentum(ValueType);
00112
00113 ValueType GetMomentum();
00114
00115 void SetBias(ValueType);
00116
00117 ValueType GetBias();
00118
00119 bool GetFirstPass();
00120
00121 void SetFirstPass(bool);
00122
00123 bool GetSecondPass();
00124
00125 void SetSecondPass(bool);
00126
00127 void InitializeWeights();
00128
00129 itkSetMacro(WeightSetId,int);
00130 itkGetMacro(WeightSetId,int);
00131
00132 itkSetMacro(InputLayerId,int);
00133 itkGetMacro(InputLayerId,int);
00134
00135 itkSetMacro(OutputLayerId,int);
00136 itkGetMacro(OutputLayerId,int);
00137
00138 protected:
00139
00140 WeightSetBase();
00141 ~WeightSetBase();
00142
00144 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00145
00146 typename RandomVariateGeneratorType::Pointer m_RandomGenerator;
00147 unsigned int m_NumberOfInputNodes;
00148 unsigned int m_NumberOfOutputNodes;
00149 vnl_matrix<ValueType> m_OutputValues;
00150 vnl_matrix<ValueType> m_InputErrorValues;
00151
00152
00153
00154
00155
00156
00157
00158 vnl_matrix<ValueType> m_DW;
00159 vnl_matrix<ValueType> m_DW_new;
00160 vnl_matrix<ValueType> m_DW_m_1;
00161 vnl_matrix<ValueType> m_DW_m_2;
00162 vnl_matrix<ValueType> m_DW_m;
00163
00164 vnl_vector<ValueType> m_DB;
00165 vnl_vector<ValueType> m_DB_new;
00166 vnl_vector<ValueType> m_DB_m_1;
00167 vnl_vector<ValueType> m_DB_m_2;
00168
00169 vnl_matrix<ValueType> m_del;
00170 vnl_matrix<ValueType> m_del_new;
00171 vnl_matrix<ValueType> m_del_m_1;
00172 vnl_matrix<ValueType> m_del_m_2;
00173
00174 vnl_vector<ValueType> m_delb;
00175 vnl_vector<ValueType> m_delb_new;
00176 vnl_vector<ValueType> m_delb_m_1;
00177 vnl_vector<ValueType> m_delb_m_2;
00178
00179 vnl_matrix<ValueType> m_InputLayerOutput;
00180 vnl_matrix<ValueType> m_WeightMatrix;
00181 vnl_matrix<int> m_ConnectivityMatrix;
00182
00183 ValueType m_Momentum;
00184 ValueType m_Bias;
00185 bool m_FirstPass;
00186 bool m_SecondPass;
00187 ValueType m_Range;
00188 int m_InputLayerId;
00189 int m_OutputLayerId;
00190 int m_WeightSetId;
00191
00192 };
00193
00194 }
00195 }
00196
00197 #ifndef ITK_MANUAL_INSTANTIATION
00198 #include "itkWeightSetBase.txx"
00199 #endif
00200
00201
00202 #endif
00203