修改记录:
2017/3/22:修复了代码中的部分 bug。
核心代码:https://github.com/zhusimaji/ml/blob/master/prank.py
def learn_to_rank(self):
    """Train the ranking weights and thresholds with perceptron-style updates.

    The update rule follows the PRank scheme: a sample is scored with the
    current weights, the score is compared against the thresholds in
    ``self.br``, and on a wrong prediction both the weights and the
    thresholds are corrected.

    Reads from ``self``:
        source_data: indexable rows; column 0 is the true rank label and
            the feature values start at column 2 (column 1 is not read here).
        rank_num: number of rows of ``source_data`` visited per pass.
        rank_cate: number of feature values per row.
        rank_label: number of thresholds in ``self.br``.
        rank_iter: number of passes over the data.
        br: list of thresholds, updated in place.

    Side effects:
        Prints a start message, resets ``self.weight`` to zeros and then
        updates it, and modifies ``self.br`` in place.

    Returns:
        None.
    """
    # Parenthesised single-argument form prints the same text on Python 2 and 3.
    print('start to learn rank')

    label_count = self.rank_label
    feature_count = self.rank_cate
    self.weight = [0.0] * feature_count

    for _ in range(self.rank_iter):
        for i in tqdm(range(self.rank_num)):
            row = self.source_data[i]
            true_rank = row[0]

            # Linear score w . x; features are offset by two columns in the row.
            score = sum(self.weight[x] * row[x + 2]
                        for x in range(feature_count))

            # Predict the rank: index of the first threshold above the score.
            # NOTE(review): when the score is not below any threshold the
            # prediction stays 0 -- confirm this is intended (it is harmless
            # if the last entry of self.br is kept at +infinity).
            predict_rank = 0
            for r in range(label_count):
                if score - self.br[r] < 0:
                    predict_rank = r
                    break

            if true_rank == predict_rank:
                continue

            # Per-threshold targets derived from the true label: -1 for
            # thresholds above the label, +1 otherwise.
            # NOTE(review): the comparison is strict (label < r); verify
            # against the label convention used by the rest of the class.
            targets = [-1 if true_rank - r < 0 else 1
                       for r in range(label_count)]

            # A threshold contributes its target only when it is on the
            # wrong side of the score; otherwise it contributes nothing.
            tao = [targets[r] if (score - self.br[r]) * targets[r] <= 0 else 0.0
                   for r in range(label_count)]
            tao_sum = sum(tao)

            self.weight = [self.weight[x] + tao_sum * row[x + 2]
                           for x in range(feature_count)]
            for r in range(label_count):
                self.br[r] = self.br[r] - tao[r]