1
2
3 """
4 This module provides code for doing k-nearest-neighbors classification.
5
6 k Nearest Neighbors is a supervised learning algorithm that classifies
7 a new observation based on the classes in its surrounding neighborhood.
8
9 Glossary:
10 distance The distance between two points in the feature space.
11 weight The importance given to each point for classification.
12
13
14 Classes:
15 kNN Holds information for a nearest neighbors classifier.
16
17
18 Functions:
19 train Train a new kNN classifier.
20 calculate Calculate the probabilities of each class, given an observation.
21 classify Classify an observation into a class.
22
23 Weighting Functions:
24 equal_weight Every example is given a weight of 1.
25
26 """
27 try:
28 from Numeric import *
29 except ImportError, x:
30 raise ImportError, "This module requires Numeric (precursor to NumPy)"
31
32 from Bio import listfns
33 from Bio import distance
34
36 """Holds information necessary to do nearest neighbors classification.
37
38 Members:
39 classes List of the possible classes.
40 xs List of the neighbors.
41 ys List of the classes that the neighbors belong to.
42 k Number of neighbors to look at.
43
44 """
46 """kNN()"""
47 self.classes = []
48 self.xs = []
49 self.ys = []
50 self.k = None
51
53 """equal_weight(x, y) -> 1"""
54
55 return 1
56
def train(xs, ys, k, typecode=None):
    """train(xs, ys, k[, typecode]) -> kNN

    Build a k nearest neighbors classifier from a training set.  xs holds
    the observations and ys holds the class assigned to each observation,
    so the two are expected to have the same number of elements.  k is the
    number of neighbors to examine when an observation is classified.
    typecode is optional and is passed straight through to asarray when
    xs is converted to an array.

    """
    model = kNN()
    model.k = k
    model.ys = ys
    # Derive the class list before converting xs, so that any failure
    # surfaces in the same order as it always has.
    model.classes = listfns.items(ys)
    model.xs = asarray(xs, typecode)
    return model
73
75 """calculate(knn, x[, weight_fn][, distance_fn]) -> weight dict
76
77 Calculate the probability for each class. knn is a kNN object. x
78 is the observed data. weight_fn is an optional function that
79 takes x and a training example, and returns a weight. distance_fn
80 is an optional function that takes two points and returns the
81 distance between them. Returns a dictionary of the class to the
82 weight given to the class.
83
84 """
85 x = asarray(x)
86
87 order = []
88 for i in range(len(knn.xs)):
89 dist = distance_fn(x, knn.xs[i])
90 order.append((dist, i))
91 order.sort()
92
93
94 weights = {}
95 for k in knn.classes:
96 weights[k] = 0.0
97 for dist, i in order[:knn.k]:
98 klass = knn.ys[i]
99 weights[klass] = weights[klass] + weight_fn(x, knn.xs[i])
100
101 return weights
102
104 """classify(knn, x[, weight_fn][, distance_fn]) -> class
105
106 Classify an observation into a class. If not specified, weight_fn will
107 give all neighbors equal weight and distance_fn will be the euclidean
108 distance.
109
110 """
111 weights = calculate(
112 knn, x, weight_fn=weight_fn, distance_fn=distance_fn)
113
114 most_class = None
115 most_weight = None
116 for klass, weight in weights.items():
117 if most_class is None or weight > most_weight:
118 most_class = klass
119 most_weight = weight
120 return most_class
121