首页 > 代码库 > songwenxin

songwenxin

#!/usr/bin/env python
# coding=utf-8
import sys
import math


class Segment:
    """Bigram maximum-probability word segmenter.

    Unigram and bigram counts are loaded with ``initial_dict``. ``mp_seg``
    then chooses, by dynamic programming over the positions of the input
    string, the split whose summed log-probabilities are highest.

    Attributes:
        word1_dict: word -> log(count / all_freq).
        word1_dict_count: word -> raw unigram count.
        word2_dict: "first second" -> log transition probability.
        word2_dict_count: "first second" -> raw bigram count.
        gmax_word_length: longest candidate segment tried (in ``len()``
            units of the input string).
        all_freq: denominator used to turn unigram counts into
            probabilities.
    """

    def __init__(self):
        self.word1_dict = {}
        self.word1_dict_count = {}
        # Count for the sentence-start marker; hard-coded constant,
        # presumably taken from the training corpus -- confirm if the
        # dictionaries are ever regenerated.
        self.word1_dict_count["<S>"] = 87297

        self.word2_dict = {}
        self.word2_dict_count = {}

        self.gmax_word_length = 0
        self.all_freq = 0

    def initial_dict(self, gram1_file, gram2_file):
        """Load unigram and bigram counts and derive log-probabilities.

        Args:
            gram1_file: path to a text file with one ``word count`` entry
                per line.
            gram2_file: path to a text file with one
                ``first second count`` entry per line.

        Fields may be separated by any run of whitespace; blank lines are
        skipped. A non-blank line with too few fields raises IndexError
        and a non-numeric count raises ValueError.
        """
        # Read the 1-gram file.
        with open(gram1_file, "r") as dict_file:
            for line in dict_file:
                fields = line.split()
                if not fields:
                    continue
                self.word1_dict_count[fields[0]] = float(fields[1])

        # Fixed values are used here instead of ones derived from the
        # loaded counts (sum of counts / longest key).
        self.gmax_word_length = 30
        self.all_freq = 1579053.0

        for key, count in self.word1_dict_count.items():
            self.word1_dict[key] = math.log(count / self.all_freq)

        # Read the 2-gram file.
        with open(gram2_file, "r") as dict_file:
            for line in dict_file:
                fields = line.split()
                if not fields:
                    continue
                first_word, second_word = fields[0], fields[1]
                value = float(fields[2])
                key = first_word + " " + second_word
                self.word2_dict_count[key] = value
                if first_word in self.word1_dict_count:
                    self.word2_dict[key] = math.log(
                        value / self.word1_dict_count[first_word])
                else:
                    # No unigram count for the history word: back off to
                    # the unigram probability of the second word, which
                    # itself falls back to the unknown-word estimate.
                    self.word2_dict[key] = self.get_word_prob(second_word)

    def get_unkonw_word_prob(self, word):
        """Return the log-probability assigned to an unseen ``word``.

        The estimate shrinks by a factor of ten for every additional
        element of ``word``. (The misspelled method name is kept so
        existing callers keep working.)
        """
        return math.log(10. / (self.all_freq * 10 ** len(word)))

    def get_word_prob(self, word):
        """Return the unigram log-probability of ``word``.

        Words missing from the unigram table get the unknown-word
        estimate.
        """
        if word in self.word1_dict:
            return self.word1_dict[word]
        return self.get_unkonw_word_prob(word)

    def get_word_trans_prob(self, first_word, second_word):
        """Return log P(second_word | first_word).

        Uses the bigram count divided by the unigram count of
        ``first_word`` when both are known; otherwise backs off to the
        unigram log-probability of ``second_word``.
        """
        trans_word = first_word + " " + second_word
        if (trans_word in self.word2_dict_count
                and first_word in self.word1_dict_count):
            return math.log(self.word2_dict_count[trans_word]
                            / self.word1_dict_count[first_word])
        return self.get_word_prob(second_word)

    def get_best_pre_node(self, sequence, node, node_state_list):
        """Find the best predecessor boundary for position ``node``.

        Args:
            sequence: the string being segmented.
            node: boundary index (1..len(sequence)) whose best incoming
                segment is wanted.
            node_state_list: per-boundary dicts with keys ``"pre_node"``
                and ``"prob_sum"`` for every boundary before ``node``.

        Returns:
            ``(best_pre_node, best_prob_sum)``: the start boundary of the
            best last segment and the accumulated log-probability of the
            path through it. Ties go to the shortest segment.
        """
        max_seg_length = min([node, self.gmax_word_length])
        pre_node_list = []

        for segment_length in range(1, max_seg_length + 1):
            pre_node = node - segment_length
            segment = sequence[pre_node:node]

            if pre_node == 0:
                # Segment starts the sentence: condition on the start marker.
                previous_word = "<S>"
            else:
                # Condition on the word that ends at pre_node on its best path.
                pre_pre_node = node_state_list[pre_node]["pre_node"]
                previous_word = sequence[pre_pre_node:pre_node]

            segment_prob = self.get_word_trans_prob(previous_word, segment)
            candidate_prob_sum = (
                node_state_list[pre_node]["prob_sum"] + segment_prob)
            pre_node_list.append((pre_node, candidate_prob_sum))

        # Pick the candidate with the highest accumulated log-probability.
        (best_pre_node, best_prob_sum) = max(pre_node_list, key=lambda d: d[1])
        return (best_pre_node, best_prob_sum)

    def mp_seg(self, sequence):
        """Segment ``sequence`` and return the words joined by spaces.

        Leading and trailing whitespace is stripped first; an empty input
        yields an empty string.
        """
        sequence = sequence.strip()

        # Boundary 0 is the start of the sentence and has no predecessor.
        node_state_list = [{"pre_node": -1, "prob_sum": 0}]

        for node in range(1, len(sequence) + 1):
            # Record the best predecessor and accumulated score per boundary.
            (best_pre_node, best_prob_sum) = self.get_best_pre_node(
                sequence, node, node_state_list)
            node_state_list.append(
                {"pre_node": best_pre_node, "prob_sum": best_prob_sum})

        # Walk back from the final boundary to the start.
        best_path = [len(sequence)]
        while True:
            pre_node = node_state_list[best_path[-1]]["pre_node"]
            if pre_node == -1:
                break
            best_path.append(pre_node)
        best_path.reverse()

        word_list = [sequence[left:right]
                     for left, right in zip(best_path, best_path[1:])]
        return " ".join(word_list)





# Script entry point: segment every line of the test query file, write the
# results to "query.test" (one segmented line per input line) and echo them
# to stdout.
if __name__ == "__main__":
    segmenter = Segment()
    segResult = []
    segmenter.initial_dict("data.uni", "data.bi")

    with open("query.10w.seg.random.test.newn", "rt") as f:
        for line in f:
            # mp_seg strips the trailing newline itself.
            segResult.append(segmenter.mp_seg(line))

    with open("query.test", "w") as out_file:
        for seg_sequence in segResult:
            out_file.write(seg_sequence + "\n")
            # Parenthesised single-argument form prints the same on
            # Python 2 and Python 3.
            print(seg_sequence)

 

songwenxin