00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #ifndef __itkRBFNetwork_h
00019 #define __itkRBFNetwork_h
00020
00021
00022 #include "itkMultilayerNeuralNetworkBase.h"
00023 #include "itkBackPropagationLayer.h"
00024 #include "itkRBFLayer.h"
00025 #include "itkCompletelyConnectedWeightSet.h"
00026 #include "itkSigmoidTransferFunction.h"
00027 #include "itkLogSigmoidTransferFunction.h"
00028 #include "itkSymmetricSigmoidTransferFunction.h"
00029 #include "itkTanSigmoidTransferFunction.h"
00030 #include "itkHardLimitTransferFunction.h"
00031 #include "itkSignedHardLimitTransferFunction.h"
00032 #include "itkGaussianTransferFunction.h"
00033 #include "itkTanHTransferFunction.h"
00034 #include "itkIdentityTransferFunction.h"
00035 #include "itkSumInputFunction.h"
00036
00037 namespace itk
00038 {
00039 namespace Statistics
00040 {
/** \class RBFNetwork
 * \brief Radial-basis-function (RBF) neural network derived from
 *        MultilayerNeuralNetworkBase<TVector, TOutput>.
 *
 * \tparam TVector  Type of one input sample. It is the parameter type of
 *                  GenerateOutput() and SetCenter(), and the element type of
 *                  the stored centers (m_Centers).
 * \tparam TOutput  Forwarded unchanged to MultilayerNeuralNetworkBase.
 *
 * The network is configured through:
 *  - the node-count, bias and class-count Set/Get accessors;
 *  - the Set...Function()/SetDistanceMetric() setters;
 *  - SetCenter()/SetRadius().
 *
 * Initialize(), InitializeWeights() and all other non-inline members are only
 * declared here. Their definitions live in itkRBFNetwork.txx (included at the
 * bottom of this header), so construction order and layer wiring cannot be
 * seen from this file. Confirm them there before relying on them.
 *
 * The constructor and destructor are protected. Instances are expected to be
 * created through the New() factory generated by itkNewMacro and held by
 * Pointer (SmartPointer<Self>).
 */
00041 template<class TVector, class TOutput>
00042 class RBFNetwork : public MultilayerNeuralNetworkBase<TVector, TOutput>
00043 {
00044 public:
00045
/** Standard ITK self/superclass/smart-pointer typedefs. */
00046 typedef RBFNetwork Self;
00047 typedef MultilayerNeuralNetworkBase<TVector, TOutput> Superclass;
00048 typedef SmartPointer<Self> Pointer;
00049 typedef SmartPointer<const Self> ConstPointer;
/** Scalar type used by the network, inherited from the superclass, and an
 *  Array of it (used as the argument type of the distance metric). */
00050 typedef typename Superclass::ValueType ValueType;
00051 typedef Array<ValueType> ArrayType;
/** Pluggable component types: generic transfer function, radial basis
 *  function, input (combination) function and Euclidean distance metric. */
00052 typedef TransferFunctionBase<ValueType> TransferFunctionType;
00053 typedef RadialBasisFunctionBase<ValueType> RBFType;
00054 typedef InputFunctionBase<ValueType*, ValueType> InputFunctionType;
00055 typedef EuclideanDistance<ArrayType> DistanceMetricType;
00056
/** Component objects used by the network.
 *  These are PUBLIC data members (no m_ prefix, no accessor macros), so
 *  callers can read or replace them directly. The Set...() methods declared
 *  below presumably store into them -- NOTE(review): confirm in
 *  itkRBFNetwork.txx. */
00057 typename InputFunctionType::Pointer InputFunction;
00058 typename DistanceMetricType::Pointer DistanceMetric;
00059
00060 typename TransferFunctionType::Pointer InputTransferFunction;
/** NOTE(review): this member is an RBFType::Pointer, but
 *  SetHiddenTransferFunction() takes a TransferFunctionType*. Confirm how
 *  itkRBFNetwork.txx converts between the two (and what happens when the
 *  argument is not an RBF). */
00061 typename RBFType::Pointer HiddenTransferFunction;
00062 typename TransferFunctionType::Pointer OutputTransferFunction;
00063
/** Output type produced by GenerateOutput(), taken from the superclass. */
00064 typedef typename Superclass::NetworkOutputType NetworkOutputType;
00065
00066
/** ITK run-time type information (class RBFNetwork, parent
 *  MultilayerNeuralNetworkBase) and the New() factory method. */
00067 itkTypeMacro(RBFNetwork,
00068 MultilayerNeuralNetworkBase);
00069 itkNewMacro(Self) ;
00070
00071
00072
/** Set up the network from the configured values (defined in
 *  itkRBFNetwork.txx; presumably builds the layers -- TODO confirm). */
00073 void Initialize();
00074
/** Set/Get the node counts of the input, hidden and output layers
 *  (stored in m_NumOfInputNodes, m_NumOfHiddenNodes, m_NumOfOutputNodes).
 *  Presumably read by Initialize() -- TODO confirm. */
00075 itkSetMacro(NumOfInputNodes, int);
00076 itkGetConstReferenceMacro(NumOfInputNodes, int);
00077
00078 itkSetMacro(NumOfHiddenNodes, int);
00079 itkGetConstReferenceMacro(NumOfHiddenNodes, int);
00080
00081 itkSetMacro(NumOfOutputNodes, int);
00082 itkGetConstReferenceMacro(NumOfOutputNodes, int);
00083
/** Set/Get the bias values for the hidden and output layers
 *  (m_HiddenLayerBias, m_OutputLayerBias). How they are applied is defined
 *  in itkRBFNetwork.txx, not here. */
00084 itkSetMacro(HiddenLayerBias, ValueType);
00085 itkGetConstReferenceMacro(HiddenLayerBias, ValueType);
00086
00087 itkSetMacro(OutputLayerBias, ValueType);
00088 itkGetConstReferenceMacro(OutputLayerBias, ValueType);
00089
/** Set/Get m_Classes. It presumably holds the number of output classes;
 *  its use is not visible from this header -- TODO confirm. */
00090 itkSetMacro(Classes, int);
00091 itkGetConstReferenceMacro(Classes,int);
00092
00093
/** Compute the network output for one input sample. The sample is taken by
 *  value, presumably to match the virtual signature in the superclass --
 *  NOTE(review): confirm before changing it to a const reference. */
00094 virtual NetworkOutputType GenerateOutput(TVector samplevector);
00095
/** Install the per-layer transfer functions and the distance metric. */
00096 void SetInputTransferFunction(TransferFunctionType* f);
00097 void SetDistanceMetric(DistanceMetricType* f);
00098 void SetHiddenTransferFunction(TransferFunctionType* f);
00099 void SetOutputTransferFunction(TransferFunctionType* f);
00100
/** Install the input function, and (re)initialize the connection weights
 *  (both defined in itkRBFNetwork.txx). */
00101 void SetInputFunction(InputFunctionType* f);
00102 void InitializeWeights();
00103
/** Record an RBF center / radius. The private containers m_Centers and
 *  m_Radii suggest one entry is stored per call (presumably one per hidden
 *  node) -- TODO confirm in itkRBFNetwork.txx.
 *  NOTE(review): SetRadius() takes ValueType while m_Radii stores double. */
00104 void SetCenter(TVector c);
00105 void SetRadius(ValueType r);
00106
00107 protected:
00108
/** Protected: create instances with New(). Defined in itkRBFNetwork.txx. */
00109 RBFNetwork();
/** Empty: every member is a smart pointer, a standard container or a plain
 *  value, so nothing needs explicit cleanup. */
00110 ~RBFNetwork(){};
00111
/** Print the object's state to os (defined in itkRBFNetwork.txx). */
00113 virtual void PrintSelf( std::ostream& os, Indent indent ) const;
00114
00115 private:
00116
/** Backing storage for the Set/Get macros above. */
00117 int m_NumOfInputNodes;
00118 int m_NumOfHiddenNodes;
00119 int m_NumOfOutputNodes;
00120 int m_Classes;
00121 ValueType m_HiddenLayerBias;
00122 ValueType m_OutputLayerBias;
/** Centers and radii supplied through SetCenter()/SetRadius().
 *  NOTE(review): this header uses std::vector but does not include <vector>
 *  directly; it relies on a transitive include from the ITK headers above. */
00123 std::vector<TVector> m_Centers;
00124 std::vector<double> m_Radii;
00125 };
00126
00127 } // namespace Statistics
00128 } // namespace itk
00129
00130 #ifndef ITK_MANUAL_INSTANTIATION
00131 #include "itkRBFNetwork.txx"
00132 #endif
00133
00134 #endif
00135