k-nn classification & Statistical Pattern Recognition
Andreas C. Kapourani (Credit: Hiroshi Shimodaira)
February 27

1 k-nn classification

In classification, the data consist of a training set and a test set. The training set is a set of N feature vectors and their class labels, and a learning algorithm is used to train a classifier on the training set. The test set is a set of feature vectors to which the classifier must assign labels.

An intuitive way to classify an unlabelled test item is to look at the training data points nearby and classify it according to the classes of those nearby labelled points. This intuition is formalised in a classification approach called K-nearest neighbour (k-nn) classification. The k-nn approach looks at the K points in the training set that are closest to the test point; the test point is then assigned to the class to which the majority of these K nearest neighbours belong.

Training a k-nn classifier is simple: we just need to store the whole training set! However, testing is much slower, since it involves measuring the distance between each test point and every training point.

We can write the k-nn algorithm precisely as follows, where X is the training data set with class labels, so that $X = \{(x, c)\}$, Z is the test set, there are C possible classes, and r is the distance metric (typically the Euclidean distance):

For each test example $z \in Z$:
1. Compute the distance $r(z, x)$ between $z$ and each training example $(x, c) \in X$.
2. Select $U_k(z) \subseteq X$, the set of the k nearest training examples to $z$.
3. Assign $z$ to the class c to which the majority of the k nearest neighbours belong.

To illustrate how the k-nn algorithm works, we will use the Lemon-Orange dataset described in the lecture slides.
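The three steps above can be sketched in a few lines of MATLAB before we move to the real data. This is only a minimal illustration, assuming a tiny hypothetical training set `X` (one row per point), a label vector `c`, squared Euclidean distance as r, and a majority vote via `mode`; all names and numbers here are our own and are not part of the lab files.

```matlab
% Hypothetical 2-D training set: two well-separated groups of points
X = [1 1; 1 2; 2 1;      % points labelled class 1
     8 8; 8 9; 9 8];     % points labelled class 2
c = [1; 1; 1; 2; 2; 2];  % class label of each training point
Z = [1.5 1.5; 8.5 8.0; 5 5.5];  % hypothetical test points
k = 3;                   % number of nearest neighbours

pred = zeros(size(Z, 1), 1);
for j = 1:size(Z, 1)
    z = Z(j, :);
    % Step 1: squared Euclidean distance from z to every training point
    r = sum(bsxfun(@minus, X, z).^2, 2);
    % Step 2: indices of the k nearest training points
    [~, idx] = sort(r, 'ascend');
    Uk = idx(1:k);
    % Step 3: majority vote among the labels of those k points
    pred(j) = mode(c(Uk));
end
disp(pred')
```

Squared distances give the same neighbour ordering as Euclidean distances, so the square root can be skipped.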
Since we do not have class labels for the Lemon-Orange data set, we will use the output of the K-means algorithm as ground-truth labels for the training data; then, based on this training set, we will run the k-nn algorithm to classify new test points.

Exercise
- Download the lemon-orange.txt file from the course website and load the data in MATLAB. Save the actual data in matrix A.
- Run MATLAB's K-means algorithm and plot the data together with the cluster means. The result should look like Figure 1.

Note: before running k-means, type rng(2). Check what the rng command does.
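As a hint for the note above: rng(2) seeds MATLAB's random number generator. kmeans picks its initial centroids at random, so fixing the seed makes the clustering, and hence the labels used below, repeatable. The short check below is a sketch that uses rand instead of kmeans to show that re-seeding replays the same pseudo-random sequence; the seed values are arbitrary.

```matlab
% Seed the generator and draw three numbers
rng(2);
first_draw = rand(1, 3);
% Re-seed with the same value and draw again
rng(2);
second_draw = rand(1, 3);
% Re-seeding restarts the same pseudo-random sequence
same_sequence = isequal(first_draw, second_draw);
% A different seed gives a different sequence
rng(3);
third_draw = rand(1, 3);
```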
[Figure 1: Lemon-Orange dataset after running the k-means algorithm; axes: width (horizontal) and height (vertical).]

Based on the k-means output, we assign each data point to a class (its cluster) using the following code:

    % Orange-lemons data
    training_data = A;
    % set seed for repeatable results
    rng(2);
    % Number of clusters
    K = ;
    % Corresponding class labels for each training point
    C = kmeans(training_data, K);  % we use MATLAB's built-in function
    % Concatenate labels to the training data (type: help cat)
    training_data = cat(2, training_data, C);

We can create a few test points using the following code:

    % Test data
    test_data = [6. 7.6; 7. 8.7; 8. 9];
    % show test data on the same plot with the training data
    plot(training_data(:,1), training_data(:,2), '+');
    hold on;
    xlabel(col_headers{1});
    ylabel(col_headers{2});
    % show test data as red circles
    plot(test_data(:,1), test_data(:,2), 'ro');
    axis tight;   % fit the axes to the data
    hold off;
[Figure 2: Lemon-Orange training dataset together with some test data shown as red circles; axes: width and height.]

Now we can implement the k-nn algorithm for different values of K and observe which class each test point is assigned to. The following code illustrates the k-nn algorithm for the first test point:

    % distance between first test point and each training observation
    % NOTE: we need the square_dist function from previous labs
    r_zx = square_dist(A, test_data(1,:));
    % Sort the distances in ascending order
    [r_zx, idx] = sort(r_zx, 2, 'ascend');
    Knn = 3;              % K nearest neighbours, e.g. Knn = 3
    r_zx = r_zx(1:Knn);   % keep the first Knn distances
    idx = idx(1:Knn);     % keep the first Knn indexes
    % majority vote only on those Knn indexes
    prediction = mode(C(idx))   % class label 2, belongs to the green colour

    prediction =
         2

Exercise
- For the same test point, use different numbers of nearest neighbours and check whether the class label changes.
- Write a function simpleknn that implements a simple k-nn algorithm, similar to what was shown in the lab session. Classify the other two test samples using your k-nn algorithm for several values of K, including K = 6.
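The code above relies on square_dist from a previous lab, which is not reproduced here. The sketch below uses a hypothetical anonymous stand-in that we assume returns a 1-by-N row vector of squared distances (so that sort(..., 2, 'ascend') works as above), together with made-up data containing one class-1 point sitting close to class 2. It shows how the label of a single test point can change as Knn grows.

```matlab
% Hypothetical stand-in for the lab's square_dist (assumed behaviour):
% squared Euclidean distance from each row of U to the row vector v,
% returned as a 1-by-N row vector
square_dist = @(U, v) sum(bsxfun(@minus, U, v).^2, 2)';

% Hypothetical training data: a class-1 group, one class-1 point near
% the class-2 group, and the class-2 group itself
A = [1 1; 1 2; 2 1; 7 7; 8 8; 8 9; 9 8];
C = [1; 1; 1; 1; 2; 2; 2];
z = [7.2 7.2];            % a single test point

r_zx = square_dist(A, z);
[~, idx] = sort(r_zx, 2, 'ascend');

% Classify the same test point with different numbers of neighbours
Knn_values = [1 3 7];
pred = zeros(size(Knn_values));
for j = 1:numel(Knn_values)
    pred(j) = mode(C(idx(1:Knn_values(j))));
end
disp([Knn_values; pred])
```

With Knn = 1 the nearby class-1 point decides, Knn = 3 lets the class-2 group outvote it, and Knn = 7 uses every training point, so the answer collapses to the overall majority class. Note also that MATLAB's mode returns the smallest of tied values, so with an even Knn a tied vote silently goes to the lower class label; in a two-class problem an odd Knn avoids ties.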
1.1 Plot decision boundary

We can draw a line such that one side of it corresponds to one class and the other side to the other. Such a line is called a decision boundary. When there are more than two classes, the decision boundary becomes more complex.

For the Lemon-Orange dataset, we will create two class labels using the K-means algorithm. Then we compute the decision boundaries obtained with the 1-nearest-neighbour rule and with a larger number of nearest neighbours.

    % Colormap we will use to colour each class
    cmap = [.836989,.68689,.66737;.8766,.8272,.9962;.83393,.6277,.7933779;.83293,.83,.7798;.7793273,.69836,.82;.727833,.8783,.33927;.969888,.8627,.968399;.9388233,.8686,.2968;.83622,.77723,.68336;.7968,.7968,.7968];
    rng(2);                  % Set seed
    Knn = 1;                 % K nearest neighbours
    K = 2;                   % Number of clusters
    C = kmeans(A, K);        % Class labels for each training point
    nGrid = 100;             % grid points per axis (illustrative resolution)
    Xplot = linspace(min(A(:,1)), max(A(:,1)), nGrid);
    Yplot = linspace(min(A(:,2)), max(A(:,2)), nGrid);
    % Obtain the grid vectors for the two dimensions
    [Xv, Yv] = meshgrid(Xplot, Yplot);
    gridx = [Xv(:), Yv(:)];  % Concatenate to get 2-D points
    classes = zeros(size(gridx, 1), 1);
    for i = 1:size(gridx, 1)
        % Apply k-nn for each grid point
        dists = square_dist(A, gridx(i, :));   % Compute distances
        [d, I] = sort(dists, 'ascend');
        classes(i) = mode(C(I(1:Knn)));
    end
    figure;
    % This function will draw the decision boundaries
    [CC, h] = contourf(Xplot(:), Yplot(:), reshape(classes, size(Xv)));
    set(h, 'LineColor', 'none');
    colormap(cmap);
    hold on;
    % Plot the scatter plots grouped by their classes (black marker edges)
    scatters = gscatter(A(:,1), A(:,2), C, 'k', 'o');
    % Fill in the colour of each point according to the class labels
    for n = 1:length(scatters)
        set(scatters(n), 'MarkerFaceColor', cmap(n,:));
    end
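The loop above classifies one grid point at a time. A vectorised alternative is sketched below under stated assumptions: squared Euclidean distance, a small hypothetical training set in place of the lab's A and kmeans labels, and a 50-by-50 grid chosen arbitrarily. It builds the full grid-to-training distance matrix in one step using the identity ||g - a||^2 = ||g||^2 + ||a||^2 - 2 g.a, then votes row by row.

```matlab
% Hypothetical training data with two classes
A = [1 1; 2 1; 1 2; 6 6; 7 6; 6 7];
C = [1; 1; 1; 2; 2; 2];
Knn = 3;

% Grid covering the data range (resolution chosen arbitrarily)
Xplot = linspace(min(A(:,1)), max(A(:,1)), 50);
Yplot = linspace(min(A(:,2)), max(A(:,2)), 50);
[Xv, Yv] = meshgrid(Xplot, Yplot);
gridx = [Xv(:), Yv(:)];

% Squared distances: one row per grid point, one column per training point
D = bsxfun(@plus, sum(gridx.^2, 2), sum(A.^2, 2)') - 2 * (gridx * A');
[~, I] = sort(D, 2, 'ascend');        % neighbours of each grid point
classes = mode(C(I(:, 1:Knn)), 2);    % majority vote along each row
classes = reshape(classes, size(Xv)); % back to the grid layout
```

The resulting matrix can be drawn with contourf(Xplot, Yplot, classes) just as in the lab code. The trade-off is memory: D has one entry per grid point and training point, so for a fine grid over a large training set the original loop may be the safer choice.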
Running the decision-boundary code for two clusters and for a larger number of clusters, with the same value of knn, we obtain Figure 3.

[Figure 3: Decision boundaries for (a) C = 2 and (b) a larger number of clusters C, using the same knn.]

Exercise
- Show the decision boundaries using several values of knn (including knn = 2) when we have two clusters.

2 Statistical Pattern Recognition

In many real-life problems, we have to make decisions under uncertainty, e.g. due to inaccurate or incomplete information about a problem. The mathematics of probability provides a way to deal with uncertainty, and tells us how to update our knowledge and beliefs when new information becomes available. In this lab session we introduce the use of probability and statistics for pattern recognition and learning.

Consider a pattern classification problem in which there are K classes. Let C denote the class, taking values 1, ..., K, and let $P(C_k)$ be the prior probability of class k. The observed input data, which is a D-dimensional feature vector, is denoted by $\mathbf{x}$. Once the training set has been used to train the classifier, a new, unlabelled data point $\mathbf{x}$ is observed. Let $P(\mathbf{x} \mid C_k)$ be the likelihood of class k for the data $\mathbf{x}$.

To perform classification we can use Bayes' theorem to compute the posterior probabilities $P(C_k \mid \mathbf{x})$ for every class k = 1, ..., K; we then classify $\mathbf{x}$ by assigning it to the class with the highest posterior probability. That is, we need to compute

$$P(C_k \mid \mathbf{x}) = \frac{P(\mathbf{x} \mid C_k)\, P(C_k)}{P(\mathbf{x})} = \frac{P(\mathbf{x} \mid C_k)\, P(C_k)}{\sum_{j=1}^{K} P(\mathbf{x} \mid C_j)\, P(C_j)} \qquad (1)$$

for every class k, and then assign $\mathbf{x}$ to the class with the highest posterior probability:

$$\hat{k} = \arg\max_{k} P(C_k \mid \mathbf{x}) \qquad (2)$$

This procedure is sometimes called the MAP (maximum a posteriori) decision rule. Thus, for each class k we need to provide an estimate of the likelihood $P(\mathbf{x} \mid C_k)$ and of the prior probability $P(C_k)$.
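Bayes' theorem and the MAP rule above can be checked numerically. The sketch below uses three classes with made-up priors and made-up likelihoods evaluated at a single observed x; none of these numbers come from the lab data.

```matlab
% Hypothetical priors P(C_k) and likelihoods P(x | C_k) for K = 3 classes,
% evaluated at one observed x (values chosen only for illustration)
prior = [0.5 0.3 0.2];
lik   = [0.1 0.4 0.3];

joint     = lik .* prior;      % P(x | C_k) P(C_k) for each class
evidence  = sum(joint);        % P(x): sum over classes of the joint terms
posterior = joint / evidence;  % Bayes' theorem
[~, k_map] = max(posterior);   % MAP decision: most probable class
```

Here the evidence is 0.05 + 0.12 + 0.06 = 0.23 and class 2 wins even though class 1 has the largest prior, because its likelihood is four times larger.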
2.1 Fish example data

To illustrate the use of posterior probabilities for classification, we will use a dataset containing measurements of fish lengths. The dataset comprises observations of male and female fish, each representing the length of a fish. The objective is to classify a fish as male or female based on its length.

You can download the Fish dataset from the course website: http://www.inf.ed.ac.uk/teaching/courses/inf2b/learnlabschedule.html

You will find a file named fish.txt; download it and save it in your current folder. Note that this file is already pre-processed: each line has three columns, where the first column is the fish length x, and the second and third columns are the numbers of male, $n_M(x)$, and female, $n_F(x)$, observations of that length, respectively.

Exercise
- Read the file and load the data in MATLAB. Store the fish data in a matrix A.

2.2 Compute prior and likelihood

Let class C = M represent male and C = F represent female fish. The prior probability expresses our belief about the sex of a fish before any evidence is taken into account. We can assume that male and female fish have different prior probabilities (e.g. $P(C_M) = 0.6$, $P(C_F) = 0.4$), or we can estimate them from the data as the proportions of male and female fish among all observations:

    % Total number of male fish
    N_M = sum(A(:,2));
    % Total number of female fish
    N_F = sum(A(:,3));
    % Total number of observations
    N_total = N_M + N_F;
    % Prior probability of male fish
    prior_m = N_M / N_total
    % Prior probability of female fish is 1 - P(M), since P(M) + P(F) = 1
    prior_f = 1 - prior_m
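To see what these estimates look like, here is the same computation on a small hypothetical count table in the fish.txt layout (these are not the real fish.txt values), with 20 male and 30 female fish in total; the estimated priors are then 20/50 = 0.4 and 0.6.

```matlab
% Hypothetical count table in the fish.txt layout:
% column 1 = length x, column 2 = n_M(x), column 3 = n_F(x)
A = [ 8  1  9;
      9  3 12;
     10  7  6;
     11  6  3;
     12  3  0];
N_M = sum(A(:,2));            % total number of male fish (20 here)
N_F = sum(A(:,3));            % total number of female fish (30 here)
prior_m = N_M / (N_M + N_F);  % proportion of males among all fish
prior_f = 1 - prior_m;        % the two priors must sum to one
```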
We can now estimate the likelihoods $P(x \mid C_M)$ and $P(x \mid C_F)$ as the counts in each class for length x divided by the total number of examples in that class:

$$P(x \mid C_M) \approx \frac{n_M(x)}{N_M} \qquad (3)$$

$$P(x \mid C_F) \approx \frac{n_F(x)}{N_F} \qquad (4)$$

Thus we estimate the likelihood of each fish length given each class using relative frequencies (i.e. using the training examples from each class). Note that these are only estimates of $P(x \mid C_M)$ and $P(x \mid C_F)$, since $N_M$ and $N_F$ are finite. We can compute the likelihood for each fish length x simply by computing the relative frequencies as follows:
    % Likelihood vector for each length x for male fish
    lik_m = A(:,2) / N_M;
    % Likelihood vector for each length x for female fish
    lik_f = A(:,3) / N_F;

Let's look at the length distribution for each class, which we can do easily by plotting histograms. Figure 4 shows the length distribution for male and female fish; for each class, the cumulative distribution function (CDF) is also shown. The CDF gives the probability that a real-valued random variable X takes a value less than or equal to x, that is, $\mathrm{CDF}(x) = P(X \le x)$.

[Figure 4: (a) Relative frequency of lengths of male fish. (b) Relative frequency of lengths of female fish. (c) CDF for male fish. (d) CDF for female fish. Panels (a) and (b) are titled "Lengths of male fish" and "Lengths of female fish" with y-axis "Rel. Freq."; panels (c) and (d) have y-axis "cdf".]

The code for creating plots (a) and (c) in Figure 4 is the following:

    % Create a histogram bar and return a vector of handles to this object
    hh = bar(A(:,1), A(:,2)/N_M);
    % Modify the initial plot
    set(hh, 'FaceColor', 'white');
    set(hh, 'EdgeColor', 'red');
    % Define x and y labels
    ylabel('Rel. Freq.');
    xlabel('Length');
    % Create title
    title('Lengths of male fish');
    % Define only x-axis limits
    xlim([min(A(:,1)) max(A(:,1))]);

    % Create CDF plot. Check what the cumsum function does in MATLAB
    hh = plot(A(:,1), cumsum(A(:,2))/N_M, '-r');
    % Define x and y labels
    ylabel('cdf');
    xlabel('Length');
    % Define only x-axis limits
    xlim([min(A(:,1)) max(A(:,1))]);

We can also plot the likelihood $P(x \mid C_k)$ for each class, as shown in Figure 5; note that the shape for each class is similar to Figure 4, since we computed the likelihood from the relative frequencies. We observe that fish with lengths near the peak of the male likelihood are most likely to be male, since there $P(x \mid C_M)$ is much larger than $P(x \mid C_F)$.

[Figure 5: Likelihood function for male and female fish lengths; y-axis "Likelihood P(x|C)", curves labelled P(x|M) and P(x|F).]

Exercises
- Plot the histogram of female fish lengths as shown in Figure 4(b).
- Figures 4(a) and 4(b) show relative frequencies. Modify your code so that it shows the actual frequencies.
- Show both the male and female histograms in the same bar plot.
- Plot the likelihood functions for male and female fish lengths as shown in Figure 5.
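The relative-frequency likelihoods and the cumsum-based CDF used above can be sanity-checked on a small hypothetical count table in the fish.txt layout (made-up counts, not the real data). Each likelihood vector should sum to one, and the CDF should be non-decreasing and end at one.

```matlab
% Hypothetical count table: [length x, n_M(x), n_F(x)]
A = [ 8  1  9;
      9  3 12;
     10  7  6;
     11  6  3;
     12  3  0];
N_M = sum(A(:,2));
N_F = sum(A(:,3));
% Relative-frequency estimates of P(x | C_M) and P(x | C_F)
lik_m = A(:,2) / N_M;
lik_f = A(:,3) / N_F;
% Empirical CDF of male lengths: running total of counts divided by N_M
cdf_m = cumsum(A(:,2)) / N_M;
% P(X <= 10 | male), read off the CDF at x = 10
p_le_10 = cdf_m(A(:,1) == 10);
```

One consequence of relative frequencies is visible here: no female fish of length 12 appears in the table, so the estimated $P(x = 12 \mid C_F)$ is exactly zero.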
2.3 Compute posterior probabilities

Having computed the prior and likelihood, we can now compute the posterior probabilities using Eq. (1). First we need to compute the evidence P(x), which can be thought of as a normalisation constant ensuring that we obtain proper (posterior) probabilities, i.e. $0 \le P(C_k \mid x) \le 1$ and $\sum_k P(C_k \mid x) = 1$.

    % Compute evidence vector for each fish length
    Px = prior_m * lik_m + prior_f * lik_f;
    % Compute vector of posterior probabilities for male fish lengths
    post_m = lik_m * prior_m ./ Px;
    % Compute vector of posterior probabilities for female fish lengths
    post_f = lik_f * prior_f ./ Px;

We can now plot the posterior probabilities using the following code:

    % Posterior probabilities for male fish
    hh = plot(A(:,1), post_m, '-r');
    ylabel('Posterior P(C|x)');
    xlabel('Length');
    xlim([min(A(:,1)) max(A(:,1))]);
    hold on
    % Posterior probabilities for female fish
    hh = plot(A(:,1), post_f, '-b');
    legend('P(M|x)', 'P(F|x)', 'Location', 'northwest');
    % Show decision boundary: where post_m - post_f changes sign,
    % linearly interpolated between the two neighbouring lengths
    d = post_m - post_f;
    i_b = find(d > 0, 1);          % first length where male is more probable
    x1 = A(i_b-1, 1);
    x2 = A(i_b, 1);
    dec_bound = x1 - d(i_b-1) * (x2 - x1) / (d(i_b) - d(i_b-1));
    plot([dec_bound dec_bound], get(gca, 'ylim'), '--k');
    hold off

[Figure 6: Posterior probabilities for male and female fish lengths; y-axis "Posterior P(C|x)", curves P(M|x) and P(F|x). The vertical black line denotes the decision boundary.]
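Two properties of these posteriors are worth checking, sketched here on a hypothetical count table (made-up values, with one length, 13, never observed in either class). First, where P(x) = 0 the division gives 0/0 = NaN, so such lengths must be skipped. Second, when the priors are estimated from the same counts, $P(C_M)\,P(x \mid C_M) = (N_M/N)(n_M(x)/N_M) = n_M(x)/N$, so the male posterior reduces to the fraction of fish of length x that are male.

```matlab
% Hypothetical count table: [length x, n_M(x), n_F(x)]
A = [ 8  1  9;
      9  3 12;
     10  7  6;
     11  6  3;
     12  3  0;
     13  0  0];
N_M = sum(A(:,2));
N_F = sum(A(:,3));
prior_m = N_M / (N_M + N_F);
prior_f = 1 - prior_m;
lik_m = A(:,2) / N_M;
lik_f = A(:,3) / N_F;

% Evidence and posteriors, computed as in the lab code
Px = prior_m * lik_m + prior_f * lik_f;
post_m = lik_m * prior_m ./ Px;
post_f = lik_f * prior_f ./ Px;

% Lengths with Px == 0 give NaN posteriors; keep only the usable rows
seen = Px > 0;
% Count-based priors: posterior equals n_M(x) / (n_M(x) + n_F(x))
frac_m = A(seen,2) ./ (A(seen,2) + A(seen,3));
```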
Figure 6 depicts how the posterior probability changes as a function of fish length for both male (red) and female (blue) fish. The vertical black line denotes the decision boundary (i.e. where the posterior probabilities of male and female fish are equal). If we used these probabilities to classify a new (unlabelled) fish, we would classify it as female if its length were to the left of the decision boundary, and as male otherwise.

Assume that we observe a new fish of length x = 8. Should we classify it as male or female?

    % Find the row of A that corresponds to fish of length x = 8
    >> i8 = find(A(:,1) == 8);
    % Read the test point likelihoods directly from that element
    >> test_lik_m = lik_m(i8);
    >> test_lik_f = lik_f(i8);
    % Compute posterior probabilities for each class
    >> test_post_m = test_lik_m * prior_m / Px(i8)
    test_post_m =
        0.1250
    >> test_post_f = test_lik_f * prior_f / Px(i8)
    test_post_f =
        0.8750

Hence the fish would be classified as female, which could also be seen directly from Figure 6.

2.4 Bayes decision rule

In the previous section, for the sake of illustration, we computed the actual posterior probabilities for each class and then assigned each example to the class with the maximum posterior probability. However, computing posterior probabilities in real-life problems is often impractical, mainly due to the denominator in Bayes' theorem (i.e. the evidence). Since our goal is only to assign a test example to the most probable class, we can instead compute the ratio of the two posteriors:

$$\frac{P(C_M \mid x)}{P(C_F \mid x)} = \frac{P(x \mid C_M)\, P(C_M) \,/\, P(x)}{P(x \mid C_F)\, P(C_F) \,/\, P(x)} = \frac{P(x \mid C_M)\, P(C_M)}{P(x \mid C_F)\, P(C_F)} \qquad (5)$$

If the ratio is greater than 1, x is classified as M; if it is less than 1, x is classified as F. As you can observe, the denominator term P(x) cancels, so there is no need to compute it at all. Let's compute the ratio for the example above, a test fish of length x = 8. We would expect the ratio to be less than 1, since the fish should be classified as female.
    % Compute ratio of posterior probabilities for test example x = 8
    >> test_ratio = (test_lik_m * prior_m) / (test_lik_f * prior_f)
    test_ratio =
        0.1429

Exercises
- Compute the posterior probabilities for each class using the prior distribution $P(C_M) = 0.9$ and $P(C_F) = 0.1$. Create the likelihood and posterior probability plots as shown in the previous sections. What do you observe? Does the likelihood depend on the prior?
- Classify the test example x = 8 using the updated posterior probabilities.
- Assuming equal prior probabilities, classify the following test examples: x = 2, 9, 2, 6.
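The Bayes decision rule via the ratio should always agree with comparing the full posteriors, since both posteriors share the same denominator P(x). The sketch below checks this on a hypothetical count table (made-up values, not fish.txt); note that at x = 12 no female fish were seen, so the ratio divides a positive number by zero and MATLAB returns Inf, which still compares as greater than 1.

```matlab
% Hypothetical count table: [length x, n_M(x), n_F(x)]
A = [ 8  1  9;
      9  3 12;
     10  7  6;
     11  6  3;
     12  3  0];
N_M = sum(A(:,2));
N_F = sum(A(:,3));
prior_m = N_M / (N_M + N_F);
prior_f = 1 - prior_m;
lik_m = A(:,2) / N_M;
lik_f = A(:,3) / N_F;

% Bayes decision rule via the ratio: P(x) cancels, so it is never computed
ratio = (lik_m * prior_m) ./ (lik_f * prior_f);
is_male_ratio = ratio > 1;   % positive / 0 gives Inf, i.e. male

% Same decisions from the full posteriors, for comparison
Px = prior_m * lik_m + prior_f * lik_f;
is_male_post = (lik_m * prior_m ./ Px) > (lik_f * prior_f ./ Px);
```

Comparing the two products directly, lik_m * prior_m > lik_f * prior_f, gives the same decisions without any division at all.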