K-Means++ Algorithm: Classifying NBA Guards

2012-12-26 qizheguang

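    CoolShell has a blog post on the K-Means algorithm that explains it very well. I studied this algorithm in grad school; I also remember learning ISODATA there, and implementing it was one of our homework assignments. K-Means is fairly simple: put briefly, points that are close to each other should end up in the same cluster.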

   
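    The K-Means algorithm goes like this (as summarized by the great Chen Hao):
  1. Randomly pick K seed points in the figure (here K=2).
  2. Compute the distance from every point in the figure to each of the K seed points. If point Pi is closest to seed Si, then Pi belongs to Si's cluster. (In the figure above, A and B belong to the upper seed, and C, D and E belong to the seed in the lower middle.)
  3. Next, move each seed point to the center of its cluster. (See step 3 in the figure.)
  4. Repeat steps 2) and 3) until the seed points no longer move. (In step 4 of the figure, the upper seed has gathered A, B and C, and the lower one D and E.)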
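To make the four steps concrete, here is a minimal C sketch of the K-Means loop on 2-D points. It is an illustration written for this explanation, not the program used later in the post: the type point2, the functions nearest_seed() and kmeans_lloyd(), and the cap MAX_K are hypothetical names, and step 1 (choosing the seeds) is left to the caller.

```c
#include <assert.h>
#include <float.h>

#define MAX_K 16   /* hypothetical upper bound on k for this sketch */

typedef struct { double x, y; } point2;

/* Squared Euclidean distance; enough for deciding which seed is nearest. */
static double sq_dist(point2 a, point2 b)
{
    double dx = a.x - b.x, dy = a.y - b.y;
    return dx * dx + dy * dy;
}

/* Step 2 for one point: index of the nearest seed. */
static int nearest_seed(point2 p, const point2 *seeds, int k)
{
    int best = 0;
    double best_d = DBL_MAX;
    for (int i = 0; i < k; i++) {
        double d = sq_dist(p, seeds[i]);
        if (d < best_d) {
            best_d = d;
            best = i;
        }
    }
    return best;
}

/* Steps 2-4: seeds[] holds the k initial seeds (step 1 is up to the caller)
 * and is updated in place; label[i] receives the cluster of pts[i]. */
void kmeans_lloyd(const point2 *pts, int n, point2 *seeds, int k,
                  int *label, int max_iter)
{
    assert(n > 0 && k > 0 && k <= MAX_K);
    for (int i = 0; i < n; i++)
        label[i] = -1;

    for (int iter = 0; iter < max_iter; iter++) {
        /* Step 2: attach every point to its nearest seed. */
        int changed = 0;
        for (int i = 0; i < n; i++) {
            int c = nearest_seed(pts[i], seeds, k);
            if (c != label[i]) {
                label[i] = c;
                changed++;
            }
        }
        /* Step 4: nobody switched cluster, so the seeds would not move. */
        if (changed == 0)
            break;

        /* Step 3: move each seed to the centroid of its cluster. */
        point2 sum[MAX_K] = {{0.0, 0.0}};
        int count[MAX_K] = {0};
        for (int i = 0; i < n; i++) {
            sum[label[i]].x += pts[i].x;
            sum[label[i]].y += pts[i].y;
            count[label[i]]++;
        }
        for (int c = 0; c < k; c++) {
            if (count[c] > 0) {   /* an empty cluster keeps its old seed */
                seeds[c].x = sum[c].x / count[c];
                seeds[c].y = sum[c].y / count[c];
            }
        }
    }
}
```

For example, given the points (0,0), (0,1), (10,10), (10,11) with both seeds deliberately starting on the left pair, the loop still ends with the two left points in one cluster and the two right points in the other, the seeds sitting at (0, 0.5) and (10, 10.5).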

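    K-Means has two drawbacks:

    1. K must be specified up front. Given a pile of samples, it is actually hard to know how many clusters there should be. In the figure above, for example, there are obviously four clusters, but if K were set to 5 the result would surely be a mess.

    2. The seed points are chosen at random, so different initial seeds can lead to different final results.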

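    For the second drawback there is K-Means++, which picks the initial seeds much more effectively. It does not remove the randomness entirely, but it makes a bad start much less likely. (Strictly speaking, the name "Lloyd's algorithm" refers to the standard K-Means iteration above, not to K-Means++.) Its steps, again as summarized by Chen Hao, are:

  1. Randomly pick one point from the dataset as the first "seed".
  2. For every point, compute the distance D(x) to its nearest "seed", store these values in an array, and add them up to get Sum(D(x)).
  3. Then draw another random value and use it to pick the next "seed" with distance-based weights. Concretely: take a random value Random in [0, Sum(D(x))), then go through the points doing Random -= D(x) until it becomes <= 0; the point at which that happens is the next "seed".
  4. Repeat steps (2) and (3) until all K seeds have been chosen.
  5. Run the standard K-Means algorithm.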
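Steps 1 to 4 can be sketched in C roughly as follows. Again this is a hypothetical illustration (point2, kmeanspp_seed(), pick_weighted() and rand_below() are names invented here). One detail worth knowing: the original K-Means++ paper weights each point by the squared distance D(x)², and this sketch does the same, as does the program further down, whose distance() returns the squared Euclidean distance. The walk in pick_weighted() stops where the remainder goes strictly below zero rather than at <= 0, so a point that is already a seed (weight 0) is never picked again.

```c
#include <assert.h>
#include <float.h>
#include <stdlib.h>

typedef struct { double x, y; } point2;

static double sq_dist(point2 a, point2 b)
{
    double dx = a.x - b.x, dy = a.y - b.y;
    return dx * dx + dy * dy;
}

/* Uniform random double in [0, m). */
static double rand_below(double m)
{
    return m * (rand() / ((double)RAND_MAX + 1.0));
}

/* Step 3: subtract the weights one by one from r (drawn from [0, sum of w))
 * and return the index at which r goes negative. */
static int pick_weighted(const double *w, int n, double r)
{
    for (int i = 0; i < n; i++) {
        r -= w[i];
        if (r < 0)
            return i;
    }
    /* Round-off fallback: the last point with a positive weight. */
    for (int i = n - 1; i > 0; i--)
        if (w[i] > 0)
            return i;
    return 0;
}

/* Steps 1-4: fills seeds[0..k-1]; d2 is caller-provided scratch of length n. */
void kmeanspp_seed(const point2 *pts, int n, point2 *seeds, int k, double *d2)
{
    assert(n > 0 && k > 0 && k <= n);
    seeds[0] = pts[rand() % n];                          /* step 1 */
    for (int c = 1; c < k; c++) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {                    /* step 2 */
            double best = DBL_MAX;
            for (int j = 0; j < c; j++) {
                double d = sq_dist(pts[i], seeds[j]);
                if (d < best)
                    best = d;
            }
            d2[i] = best;                                /* D(x)^2 */
            sum += best;
        }
        seeds[c] = pts[pick_weighted(d2, n, rand_below(sum))];  /* step 3 */
    }
}
```

Once kmeanspp_seed() returns, the seeds go straight into an ordinary K-Means loop (step 5).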

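    I'm an NBA fan, and hoopchina is a forum I visit all the time. Lately, just how good Jeremy Lin (林书豪) really is has been a favorite topic of debate there. My first idea was to run a cluster analysis on all NBA guards and see which tier of player Lin belongs to. I downloaded the statistics of 100 guards from the NBA website, in this format: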

    Player FG% 3P% FT% REB AST STL BLK TOV PF PTS
    科比-布莱恩特 0.551 0.441 0.92 5.5 4.6 1.4 0 3.8 2.5 26.4
    詹姆斯-哈登 0.441 0.24 0.823 4.6 4.5 1.4 0.8 4.1 2 26.4
    拉塞尔-威斯布鲁克 0.388 0.279 0.754 4.9 8.5 1.3 0.1 3 2.9 19.6
    克里斯-保罗 0.478 0.4 0.913 3.3 10.3 2.3 0 2.1 2.4 17

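    That is too much data, so I chose three dimensions, rebounds, assists and points, as the metrics for judging guards. I used awk to extract the columns I care about and saved them to a file: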

    科比-布莱恩特 5.5000 4.6000 26.4000
    詹姆斯-哈登 4.6000 4.5000 26.4000
    拉塞尔-威斯布鲁克 4.9000 8.5000 19.6000
    克里斯-保罗 3.3000 10.3000 17.0000
    尼古拉斯-巴通姆 6.4000 3.1000 19.0000
    Damian-Lillard 3.1000 6.6000 18.4000
    O.J.-梅奥 3.6000 2.9000 21.5000
    ......
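The post doesn't show the awk command itself. As a rough C equivalent, assuming the raw file is whitespace-separated with the column order of the header above, the sketch below keeps the name plus the 4th, 5th and 10th numbers (rebounds, assists, points). guard_stats, parse_raw_line() and print_selected() are hypothetical names.

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* 1-based positions of the wanted numbers after the name, following the
 * header order FG% 3P% FT% REB AST STL BLK TOV PF PTS. */
enum { COL_REB = 4, COL_AST = 5, COL_PTS = 10, N_STATS = 10 };

typedef struct {
    char name[64];
    double reb, ast, pts;
} guard_stats;

/* Parses "name + 10 numbers". Returns 0 and fills *g on success; returns -1
 * and leaves *g untouched for anything else, such as the header row. */
int parse_raw_line(const char *line, guard_stats *g)
{
    char name[64];
    double v[N_STATS + 1];   /* v[1..10], so v[COL_x] matches the column */

    assert(line != NULL && g != NULL);
    int got = sscanf(line, "%63s %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf",
                     name, &v[1], &v[2], &v[3], &v[4], &v[5],
                     &v[6], &v[7], &v[8], &v[9], &v[10]);
    if (got != N_STATS + 1)
        return -1;

    strcpy(g->name, name);   /* fits: both buffers are 64 bytes */
    g->reb = v[COL_REB];
    g->ast = v[COL_AST];
    g->pts = v[COL_PTS];
    return 0;
}

/* Writes one "name<TAB>REB<TAB>AST<TAB>PTS" line with four decimals. */
void print_selected(FILE *out, const guard_stats *g)
{
    fprintf(out, "%s\t%.4f\t%.4f\t%.4f\n", g->name, g->reb, g->ast, g->pts);
}
```

Calling parse_raw_line() on each line read with fgets() and passing the successful results to print_selected() yields lines in the layout shown above; the header row fails to parse and is skipped automatically.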
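    The format is: name, rebounds, assists, points. One problem is that these metrics are not on the same scale: 10 assists in a game is at least as hard as 20 points. If the metrics are not normalized, the clustering will be unfair: points will become the dominant metric, and rebounds and assists will become secondary. But how should they be normalized?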

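    My approach: compute the average of each metric over all players. If, say, the averages were 2 rebounds, 4 assists and 8 points, I would multiply rebounds by 4, assists by 2 and points by 1, which raises the weight of rebounds and assists. (calc_ratio() in the code below does the equivalent: it multiplies each metric by 10/average, so every metric ends up with an average of 10 and the relative weights are the same 4:2:1.)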
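The weighting rule amounts to scale_j = C / mean_j for each metric j: with averages 2, 4 and 8 and C = 8 this gives exactly 4, 2 and 1. Only the ratios matter for clustering, so the constant is arbitrary; calc_ratio() in the program uses C = 10 (scale factors 5, 2.5 and 1.25 for the same averages). A minimal stand-alone sketch, where scale_to_common_mean() and TARGET_MEAN are made-up names:

```c
#include <assert.h>
#include <stddef.h>

#define DIM 3              /* rebounds, assists, points */
#define TARGET_MEAN 10.0   /* any constant works; only the ratios matter */

/* Computes scale[j] = TARGET_MEAN / mean_j over the n rows and multiplies
 * every row by it, so each metric ends up with the same average. */
void scale_to_common_mean(double rows[][DIM], size_t n, double scale[DIM])
{
    double mean[DIM] = {0.0};

    assert(n > 0);
    for (size_t i = 0; i < n; i++)
        for (int j = 0; j < DIM; j++)
            mean[j] += rows[i][j];

    for (int j = 0; j < DIM; j++) {
        mean[j] /= (double)n;
        assert(mean[j] != 0.0);   /* an all-zero metric cannot be rescaled */
        scale[j] = TARGET_MEAN / mean[j];
    }

    for (size_t i = 0; i < n; i++)
        for (int j = 0; j < DIM; j++)
            rows[i][j] *= scale[j];
}
```

Two rows such as (1, 3, 6) and (3, 5, 10) average to (2, 4, 8), and the function returns scale factors 5, 2.5 and 1.25, i.e. the 4:2:1 ratio from the example.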
    Now let's look at the code. The flow is: load the data ---> normalize the data ---> K-Means++ clustering.


   
    #define _POSIX_C_SOURCE 200809L   /* for strtok_r() */
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <math.h>                 /* HUGE_VAL */

    #define NORMALIZE                 /* comment out to cluster on the raw numbers */

    #define FILEPATH "./data_main"
    #define BUFSIZE 4096
    #define PLAYER_NUM 100
    #define NAME_LEN 256
    #define MEASURE_DIMENSION 3       /* rebounds, assists, points */
    #define CLUSTER_NUM 8

    /* Column order of the raw NBA.com table (kept for reference, not used below). */
    enum measure_column
    {
        HIT_RATE = 1,
        HIT_RATE_3P,
        HIT_RATE_FREE,
        BOARD,
        ASSIST,
        STEAL,
        BLOCK,
        TURNOVER,
        FAULT,
        SCORE
    };

    typedef struct player
    {
        char name[NAME_LEN];
        double measure[MEASURE_DIMENSION];
        int group;                    /* index of the cluster this player belongs to */
    } player;

    /* Per-dimension scale factors (10 / average), used to print results in original units. */
    double ratio_sqrt[MEASURE_DIMENSION] = {0.0};

    int output_result(struct player* players, int player_num, int cluster_num);

    /* Uniform random double in [0, m). */
    double randf(double m)
    {
        return m * (rand() / ((double)RAND_MAX + 1.0));
    }

    /* Reads up to player_num lines of "name rebounds assists points" from path.
     * Names contain no spaces (e.g. 科比-布莱恩特), so spaces and tabs both work
     * as separators. Stores the number of records read in *loaded. */
    struct player* load_data(const char* path, int player_num, int* loaded)
    {
        char buf[BUFSIZE];
        const char* delimit = " \t\r\n";
        int player_index = 0;
        struct player* players = malloc(sizeof(player) * player_num);
        if (players == NULL)
        {
            fprintf(stderr, "malloc failed for players\n");
            return NULL;
        }

        FILE* fp = fopen(path, "r");
        if (fp == NULL)
        {
            fprintf(stderr, "open file(%s) failed\n", path);
            free(players);
            return NULL;
        }

        while (player_index < player_num && fgets(buf, BUFSIZE, fp))
        {
            struct player* p = &players[player_index];
            char* save_ptr;
            char* token = strtok_r(buf, delimit, &save_ptr);
            int field_index = 0;
            while (token != NULL)
            {
                if (field_index == 0)
                {
                    strncpy(p->name, token, NAME_LEN - 1);
                    p->name[NAME_LEN - 1] = '\0';
                }
                else if (field_index <= MEASURE_DIMENSION)   /* never write past measure[] */
                {
                    p->measure[field_index - 1] = atof(token);
                }
                token = strtok_r(NULL, delimit, &save_ptr);
                field_index++;
            }
            if (field_index != MEASURE_DIMENSION + 1)
            {
                fprintf(stderr, "data file has a bad format at record %d, exit\n", player_index + 1);
                fclose(fp);
                free(players);
                return NULL;
            }
            p->group = 0;
            player_index++;
        }

        fclose(fp);
        fprintf(stderr, "%d player records read\n", player_index);
        *loaded = player_index;
        return players;
    }

    /* Multiplies every metric by 10 / (its average over all players), so each
     * metric ends up with an average of 10. The factors are stored in ratio[]. */
    int calc_ratio(struct player* players, int player_num, double* ratio)
    {
        double average[MEASURE_DIMENSION] = {0.0};
        int i, j;

        for (i = 0; i < player_num; i++)
            for (j = 0; j < MEASURE_DIMENSION; j++)
                average[j] += players[i].measure[j];

        for (j = 0; j < MEASURE_DIMENSION; j++)
        {
            average[j] /= player_num;
            if (average[j] == 0.0)
                return -1;            /* an all-zero column cannot be scaled */
            ratio[j] = 10 / average[j];
        }

        for (i = 0; i < player_num; i++)
            for (j = 0; j < MEASURE_DIMENSION; j++)
                players[i].measure[j] *= ratio[j];
        return 0;
    }

    /* Squared Euclidean distance: enough to find the nearest center, and it is
     * also the D(x)^2 weight used by the K-Means++ seeding. */
    double distance(const struct player* a, const struct player* b)
    {
        double d = 0.0;
        for (int i = 0; i < MEASURE_DIMENSION; i++)
        {
            double diff = a->measure[i] - b->measure[i];
            d += diff * diff;
        }
        return d;
    }

    /* Returns the index of the center closest to p; if d2 is not NULL, the
     * (squared) distance to that center is stored there. */
    int nearest(const struct player* p, const struct player* cent, int n_cluster, double* d2)
    {
        int i, min_i = 0;
        double d, min_d = HUGE_VAL;

        for (i = 0; i < n_cluster; i++)
        {
            d = distance(&cent[i], p);
            if (d < min_d)
            {
                min_d = d;
                min_i = i;
            }
        }
        if (d2)
            *d2 = min_d;
        return min_i;
    }

    /* K-Means++ seeding: the first center is a random player, each following
     * one is drawn with probability proportional to its squared distance to
     * the nearest center so far. Finally every player joins its nearest center. */
    int K_findseed(struct player* players, int player_num, struct player* cent, int n_cent)
    {
        int j, n_cluster;
        double sum;
        double* d = malloc(sizeof(double) * player_num);
        if (d == NULL)
            return -1;

        cent[0] = players[rand() % player_num];
        for (n_cluster = 1; n_cluster < n_cent; n_cluster++)
        {
            sum = 0;
            for (j = 0; j < player_num; j++)
            {
                nearest(&players[j], cent, n_cluster, &d[j]);
                sum += d[j];
            }
            sum = randf(sum);
            /* Random -= D(x) until it drops to <= 0; if rounding keeps it
               slightly positive, the loop ends on the last player. */
            for (j = 0; j < player_num - 1; j++)
            {
                if ((sum -= d[j]) <= 0)
                    break;
            }
            cent[n_cluster] = players[j];
        }
        for (j = 0; j < player_num; j++)
            players[j].group = nearest(&players[j], cent, n_cent, NULL);
        free(d);
        return 0;
    }

    /* K-Means++ clustering: seed with K_findseed(), then alternate between
     * moving the centers and reassigning the players. */
    int K_mean_plus(struct player* players, int player_num, int cluster_num)
    {
        struct player* center = malloc(sizeof(player) * cluster_num);
        struct player *p, *c;
        int i, j, min_i, changed;

        if (center == NULL || K_findseed(players, player_num, center, cluster_num) < 0)
        {
            fprintf(stderr, "K-Means++ seeding failed\n");
            free(center);
            return -1;
        }
        output_result(players, player_num, cluster_num);   /* grouping right after seeding */

        do {
            /* Reset the centers; c->group is reused as the member count. */
            for (c = center, i = 0; i < cluster_num; i++, c++)
            {
                c->group = 0;
                for (j = 0; j < MEASURE_DIMENSION; j++)
                    c->measure[j] = 0.0;
            }
            /* Sum up the members of every cluster. */
            for (j = 0, p = players; j < player_num; j++, p++)
            {
                c = center + p->group;
                c->group++;
                for (i = 0; i < MEASURE_DIMENSION; i++)
                    c->measure[i] += p->measure[i];
            }
            /* Move each center to the mean of its members. An empty cluster
               would divide by zero, so it is simply left at the origin. */
            for (c = center, i = 0; i < cluster_num; i++, c++)
            {
                if (c->group == 0)
                    continue;
                for (j = 0; j < MEASURE_DIMENSION; j++)
                    c->measure[j] /= c->group;
            }
            /* Reassign every player to its nearest center. */
            changed = 0;
            for (j = 0, p = players; j < player_num; j++, p++)
            {
                min_i = nearest(p, center, cluster_num, NULL);
                if (min_i != p->group)
                {
                    changed++;
                    p->group = min_i;
                }
            }
            fprintf(stderr, "%d changed \n", changed);
        } while (changed > 2);   /* stops once at most 2 players move; use > 0 for full convergence */

        for (c = center, i = 0; i < cluster_num; i++, c++)
        {
            fprintf(stderr, "\ncenter %d\n", i);
            for (j = 0; j < MEASURE_DIMENSION; j++)
    #ifdef NORMALIZE
                fprintf(stderr, " %lf\t", c->measure[j] / ratio_sqrt[j]);   /* back to original units */
    #else
                fprintf(stderr, " %lf\t", c->measure[j]);
    #endif
        }
        free(center);
        return 0;
    }

    /* Prints the members of every cluster with their stats in original units. */
    int output_result(struct player* players, int player_num, int cluster_num)
    {
        int i, j, k;
        struct player* p;
        for (i = 0; i < cluster_num; i++)
        {
            fprintf(stderr, "\nthe group %d\n", i);
            for (j = 0, p = players; j < player_num; j++, p++)
            {
                if (p->group != i)
                    continue;
                fprintf(stderr, "%s", p->name);
                for (k = 0; k < MEASURE_DIMENSION; k++)
    #ifdef NORMALIZE
                    fprintf(stderr, "\t%.4f", p->measure[k] / ratio_sqrt[k]);
    #else
                    fprintf(stderr, "\t%.4f", p->measure[k]);
    #endif
                fprintf(stderr, "\n");
            }
        }
        fprintf(stderr, "\n");
        return 0;
    }

    int main(void)
    {
        int player_num = 0;
        int ret = 0;
        struct player* players = load_data(FILEPATH, PLAYER_NUM, &player_num);
        if (players == NULL || player_num == 0)
        {
            fprintf(stderr, "load data failed\n");
            free(players);
            return 1;
        }
    #ifdef NORMALIZE
        ret = calc_ratio(players, player_num, ratio_sqrt);
        if (ret < 0)
        {
            fprintf(stderr, "calc ratio failed \n");
            free(players);
            return 2;
        }
    #endif
        /* rand() is never seeded, so every run picks the same initial seeds. */
        ret = K_mean_plus(players, player_num, CLUSTER_NUM);
        if (ret == 0)
            output_result(players, player_num, CLUSTER_NUM);
        free(players);
        return ret == 0 ? 0 : 3;
    }
Heh, here is the output:

    ......
    center 0
     4.187500     1.137500     8.687500
    center 1
     2.982353     4.223529     8.600000
    center 2
     6.900000     3.420000     13.860000
    center 3
     4.360000     4.413333     19.300000
    center 4
     1.650133     1.893333     8.166667
    center 5
     4.075000     9.037500     14.775000
    center 6
     2.794444     2.122222     13.694444
    center 7
     2.857143     6.850000     14.557143
    ......
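    Looking at the final cluster centers: center 3 averages 19.3 points and 4.4 assists. These are high scorers with limited playmaking, in other words guards who play more like shooting guards.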

    the group 3
    科比-布莱恩特    5.5000    4.6000    26.4000
    詹姆斯-哈登      4.6000    4.5000    26.4000
    O.J.-梅奥        3.6000    2.9000    21.5000
    斯蒂芬-库里      4.4000    5.4000    17.0000
    凯里-欧文        4.1000    6.5000    24.3000
    凯尔-洛瑞        5.8000    6.3000    18.3000
    蒙塔-艾利斯      3.4000    5.4000    20.0000
    肯巴-沃克        3.7000    5.1000    19.0000
    DeMar-DeRozan    5.1000    2.1000    20.0000
    德怀恩-韦德      4.1000    4.9000    16.9000
    J.R.-史密斯      5.1000    3.0000    16.7000
    雷蒙-塞申斯      3.4000    4.6000    16.3000
    乔-约翰逊        3.9000    3.6000    16.0000
    阿隆-阿夫拉罗    4.5000    2.4000    16.5000
    乔治-希尔        4.2000    4.9000    14.2000
    Now for center 5: its center sits at about 9.0 assists and 14.8 points per game. These guards are assist machines who can also score well.
    center 5
    4.075000 9.037500 14.775000

    the group 5
    拉塞尔-威斯布鲁克     4.9000 8.5000  19.6000
    克里斯-保罗           3.3000 10.3000 17.0000
    托尼-帕克             3.3000 8.1000  14.3000
    格雷维斯-瓦斯奎兹     3.9000 8.6000  11.6000
    拉简-朗多             4.9000 12.6000 14.3000
    朱-霍利迪             4.0000 8.6000  19.1000
    贾米尔-尼尔森         4.0000 8.5000  11.0000
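    Compared with the guard assist leaderboard on the NBA website, everyone is here except Brandon Jennings (布兰登-詹宁斯): his assists are not quite extreme enough and his rebounds are too low, so he was put in group 7.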




    布兰登-詹宁斯    3.1000     7.9000       16.9000
    center 5
                     4.075000   9.037500     14.775000
    center 7
                     2.857143   6.850000     14.557143
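To see why Jennings ends up in group 7, one can compare his squared distances to center 5 and center 7. The sketch below is only an approximation: the real program measures distances in the normalized space, whose scale factors depend on the averages of all 100 players (not listed in the post), so here the printed centers in original units are used together with a per-dimension weight vector w as a stand-in. weighted_sq_dist(), closest_center() and any choice of w are hypothetical. With w = {1, 1, 1}, his squared distance is about 6.76 to center 5 and about 6.65 to center 7. His rebound gap (0.24 vs 0.98) and assist gap (1.05 vs 1.14) both favor center 7, while only points favor center 5, so weighting rebounds and assists more heavily relative to points, which is what the normalization does, only widens the margin.

```c
#include <assert.h>

#define DIM 3   /* rebounds, assists, points */

/* Squared Euclidean distance after multiplying dimension j by w[j];
 * w = {1, 1, 1} is the plain distance in raw units. */
double weighted_sq_dist(const double a[DIM], const double b[DIM],
                        const double w[DIM])
{
    double s = 0.0;
    for (int j = 0; j < DIM; j++) {
        double d = w[j] * (a[j] - b[j]);
        s += d * d;
    }
    return s;
}

/* Index of the nearest of k centers, ties going to the lower index. */
int closest_center(const double p[DIM], const double centers[][DIM], int k,
                   const double w[DIM])
{
    assert(k > 0);
    int best = 0;
    double best_d = weighted_sq_dist(p, centers[0], w);
    for (int i = 1; i < k; i++) {
        double d = weighted_sq_dist(p, centers[i], w);
        if (d < best_d) {
            best_d = d;
            best = i;
        }
    }
    return best;
}
```

Feeding Rondo's line (4.9, 12.6, 14.3) against centers 2 and 5 in the same way picks center 5 in raw units, mostly because his assist gap to center 2 is more than 9 per game.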

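    Center 2 is also interesting: 6.9 rebounds per game, so these guards practically rebound like big men. Rajon Rondo (拉简-朗多) came to mind right away, but unfortunately his rebounds are not as high as I imagined, only 4.9, and as an assist machine he was placed in group 5.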

    center 2
                    6.900000 3.420000 13.860000
    the group 2
    尼古拉斯-巴通姆 6.4000   3.1000   19.0000
    安德鲁-伊格达拉 7.3000   4.0000   14.8000
    泰瑞克-埃文斯   5.6000   3.4000   11.3000
    保罗-乔治       7.6000   3.2000   13.5000
    埃文-特纳       7.6000   3.4000   10.7000
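    This post doesn't really settle what tier of point guard Jeremy Lin is. The main reason is that the raw data is too heterogeneous: a player like Kobe shouldn't be part of this classification at all, since his numbers are not those of a typical point guard. If I could get data on just the league's starting point guards, we could split them into 3 or 5 clusters and see which cluster Lin falls into.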

