Mercurial > hg > tvii
comparison tests/test_kmeans.py @ 87:9d5a5e9f5c3b
add kmeans + dataset
| author | Jeff Hammel <k0scist@gmail.com> |
|---|---|
| date | Sun, 17 Dec 2017 14:05:57 -0800 |
| parents | |
| children | 596dac7f3e98 |
comparison
equal
deleted
inserted
replaced
| 86:b56d329c238d | 87:9d5a5e9f5c3b |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 | |
| 3 """ | |
| 4 tests K means algorithm | |
| 5 """ | |
| 6 | |
| 7 import unittest | |
| 8 from tvii import kmeans | |
| 9 from nettwerk.dataset.circle import CircularRandom | |
| 10 | |
| 11 | |
class TestKMeans(unittest.TestCase):
    """Tests for ``tvii.kmeans.kmeans`` and the ``tvii.kmeans`` CLI.

    Uses the ``unittest`` assertion methods rather than bare ``assert`` so the
    checks still run under ``python -O`` and report the offending values when
    they fail.
    """

    def test_dualing_gaussians(self):
        """Test two Gaussian distributions, first with their overlap cut.

        Not implemented yet; reported as skipped rather than silently passing.
        """
        # TODO: implement
        self.skipTest("not yet implemented")

    def test_circles(self):
        """Test clustering two non-overlapping circles of points."""

        # generate two non-overlapping circles of radius 1,
        # centred at (-1.5, 0) and (1.5, 0)
        # NOTE(review): CircularRandom is not seeded here, so the points differ
        # from run to run -- consider seeding if this test proves flaky.
        n_points = 10000  # per circle
        p1 = CircularRandom((-1.5, 0), 1)(n_points)
        p2 = CircularRandom((1.5, 0), 1)(n_points)

        # run kmeans with k=2
        classes, centroids = kmeans.kmeans(p1 + p2, 2)

        # sanity: one centroid and one class per requested cluster
        self.assertEqual(len(centroids), 2)
        self.assertEqual(len(classes), 2)

        # the centroids should have opposite x values, near -1.5 and +1.5,
        # so the product of their x coordinates should be close to -2.25
        xprod = centroids[0][0] * centroids[1][0]
        self.assertLess(xprod, 0.)
        self.assertAlmostEqual(xprod, -2.25, delta=0.1)

        # assert each centroid is kinda close to (+/-1.5, 0)
        for centroid in centroids:
            self.assertAlmostEqual(abs(centroid[0]), 1.5, delta=0.1)
            self.assertAlmostEqual(centroid[1], 0., delta=0.1)

        # it's a pretty clean break, so every point should most likely be
        # assigned to the cluster of the circle it was drawn from
        if centroids[0][0] < 0.:
            left, right = 0, 1
        else:
            left, right = 1, 0
        self.assertEqual(sorted(p1), sorted(classes[left]))
        self.assertEqual(sorted(p2), sorted(classes[right]))

    def test_help(self):
        """Smoke test for the CLI: ``--help`` should exit cleanly."""

        # --help is expected to end the program via SystemExit (as argparse
        # does); require that it actually happens, with a success status
        with self.assertRaises(SystemExit) as ctx:
            kmeans.main(['--help'])
        self.assertIn(ctx.exception.code, (0, None))
| 63 | |
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
