百度的 Familia 提供了工业界主题向量的应用,现在应该很多的工业界项目中会应用到,也取得不错的效果。官方的文档还是写不够细致甚至还是有点小错误。
github 上给出的 Docker 使用方法如下所示:
Docker
docker run -d \
--name familia \
-e MODEL_NAME=news \
-p 5000:5000 \
orctom/familia
MODEL_NAME can be one of news
/novel
/webpage
/webo
如果你只是使用 news 那么你可以正常的运行,如果你使用其他模型可能需要做下修改,我在 github 上面的 issue 上也看到有人问相同的问题,下午研究了一下搞通了。上面的 webo,如果你要使用“微博”模型需要更名为 weibo,因为你在执行这部分代码时下载的模型路径已经是 weibo 了,官方的 readme 页面一直都没有修改,是放弃了?
模型下载
$ cd model
$ sh download_model.sh
ok,今天博主是想看看微博的主题向量模型,所以下面就以微博主题模型来说明,其他应该是一样的。
使用镜像启动容器
sudo docker run -it -d --name familia -e MODEL_NAME=news -p 5000:5000 orctom/familia sh
这个在启动容器时执行了 sh 的命令,而且你会发现我仍然使用了 news 这个 model 来启动,不要着急,咱们只是把这个容器成功启动,sh 作为启动容器时默认执行的命令执行,因为你看官方的 Dockfile 可以看到如下
CMD ["python", "python/app.py"]
如果你不加 sh 启动程序那么后面我们会把启动的 app 程序杀掉之后就会退出容器,那么我们想做一些修改就实现不了了。
所以在启动容器的时候加上 sh。
容器操作
容器启动之后我们查看以下当前容器的状态
sudo docker ps -a
上面执行完之后你可以看到容易的 Container id 然后我们着手进入容器
sudo docker exec -it ***** sh
上面的****替换为你看到的 Container id
到这一步你成功进入了容器内部,下一步就是我们把现在运行的服务杀掉,因为我们后面要修改源码然后重启服务
ps -ef | grep app.py |grep -v grep | awk {'print $1'} | xargs kill -9
使用上面的代码就是就是杀掉下面执行的 sanic web 程序。
接下来我们把修改的源码贴上来,这里使用的 lrzsz 这个软件实现代码的上传,这里 Docker 镜像是基于 Apache alpine 这个来打包镜像的。现在没有 lrzsz 这个软件包,所以我们需要安装一下。
很可惜你使用
apk add lrzsz
没有任何的反映,是因为现在官方的仓库里面没有,在网上一搜现在 lrzsz 还在测试仓库里面,所以我们需要先添加相应的源。
echo "http://dl-cdn.alpinelinux.org/alpine/edge/testing" >> /etc/apk/repositories;
此时在执行上面的命令就会执行安装成功
这个版本安装的 lrzsz 还是比较高的,如果你以前使用过注意以下操作间就好了
rz ---> lrz
这个版本使用 lrz 来上传,同理 sz 也是一样的需要加个 l。
华丽的分割线
下面就是重点了,也不算是重点,就是修改 app.py 这部分源码。
在 girhub issue 里面看到有人遇到这样的问题,其实我也遇到了,这部分就是要修改源码了
check fail: file open pretext file /familia/model/webo/lda.conf
下面贴出我自己修改的代码,其他模型可以参考,主要核心部分就是修改 modelname,你如果要使用其他模型,你先去看下 model 这个文件夹下面对应的模块子文件夹中有哪些文件,按需修改,改完这些,你可以使用以下网址去看了
# -*- coding: utf-8 -*- import multiprocessing import os import re import traceback from collections import defaultdict from sanic import Sanic from sanic.exceptions import NotFound from sanic.log import logger from sanic.response import json from sanic_openapi import swagger_blueprint, doc from familia_wrapper import InferenceEngineWrapper, TopicalWordEmbeddingsWrapper # import argparse # # parser = argparse.ArgumentParser() # parser.add_argument("--load_lda", type=bool, # help="decide to load lda conf") # parser.add_argument("--load_twe", type=bool, # help="decide to load lda conf") # args = parser.parse_args() app = Sanic("Familia", strict_slashes=True) app.blueprint(swagger_blueprint) app.config.API_TITLE = 'Familia API' app.config.API_DESCRIPTION = 'A Toolkit for Industrial Topic Modeling' app.config.API_PRODUCES_CONTENT_TYPES = ['application/json'] RE_BACKSPACES = re.compile("\b+") model_name = 'weibo' n_workers = int(os.environ.get('WORKERS', multiprocessing.cpu_count())) model_dir = f"/familia/model/{model_name}" emb_file = f"{model_name}_slda.model" # # if args.load_lda: # inference_engine_lda = InferenceEngineWrapper(model_dir, 'lda.conf', emb_file) inference_engine_slda = InferenceEngineWrapper(model_dir, 'slda.conf',emb_file) # if args.load_twe: # twe = TopicalWordEmbeddingsWrapper(model_dir, emb_file) def read_topic_words_from_file(topic_words_file_name='topic_words.lda.txt'): logger.info(f"reading topic_words from file: {topic_words_file_name}") topic_words = defaultdict(list) file_path = os.path.join(model_dir, topic_words_file_name) if not os.path.exists(file_path): logger.warn(f"topic_words file not found: {file_path}") return topic_words with open(file_path, 'r') as f: line = f.readline() while line: pos = line.find('=') line = line[pos + 2:] topic_id, num = line.strip().split('\t') topic_id, num = int(topic_id), int(num) f.readline() items = list() for i in range(num): data = f.readline() word, score = data.strip().split('\t') items.append([word, float(score)]) topic_words[topic_id] = items line = f.readline() return topic_words lda_topic_words = read_topic_words_from_file('topic_words.slda.txt') def get_param(request, param_name, default_value=None, is_list=False): param_value = (request.form.getlist(param_name) if is_list else request.form.getlist(param_name)) or \ request.args.get(param_name) or \ default_value if param_value is None: return param_value value_type = type(param_value) if is_list: return param_value if value_type == list else [param_value] return param_value[0] if value_type == list else param_value def strip_to_none(text: str): if text is None: return None text = text.strip() text = re.sub(RE_BACKSPACES, '', text) if len(text) == 0: return None if text == 'None': return None return text def response(success: bool = True, data=None, message=None): data = {'success': success, 'message': message, 'data': data} data = {k: v for k, v in data.items() if v is not None} try: return json(data, ensure_ascii=False) except Exception as err: logger.error(err, exc_info=True) msg = traceback.format_exc() data = {'success': success, 'message': msg} return json(data, ensure_ascii=False) def error_response(message='Invalid request'): return response(success=False, message=message) def handle_404(request, exception): return api_index(request) def handle_exception(request, exception): return error_response(str(exception)) @app.route('/') @doc.description("ping") def api_index(request): message = f"Familia API is running, check out the api doc at http://{request.host}/swagger/" return response(message=message) @app.exception(NotFound) async def ignore_404s(request, exception): message = f"Yep, I totally found the page: {request.url}" return response(message=message) # # @app.route('/tokenize', methods=["POST"]) # @doc.summary("分词") # @doc.description("简易的 FMM 分词工具,只针对主题模型中出现的词表进行正向匹配") # @doc.consumes(doc.String(name='text', description="文本"), required=True) # @doc.response(200, None, description="""返回一个 list 对象,其中每个元素为分词后的结果。""") # async def api_tokenize(request): # try: # text = get_param(request, 'text') # if text is None: # return error_response() # result = inference_engine_lda.tokenize(text) # return response(data=result) # except Exception as err: # logger.error(err, exc_info=True) # return error_response(str(err)) # @app.route('/lda', methods=["POST"]) # @doc.summary("LDA 模型推断") # @doc.description("使用 LDA 模型对输入文本进行推断,得到其主题分布") # @doc.consumes(doc.Integer(name='n', description="top n,默认 10"), required=False) # @doc.consumes(doc.String(name='text', description="文本"), required=True) # @doc.response(200, None, description="""返回一个 list 对象,存放输入文本对应的稀疏主题分布 # list 中每个元素为 tuple # 每个 tuple 包含一个主题 ID 以及该主题对应的概率,按照概率从大到小排序。 # 例如:[(15, 0.5), (10, 0.25), (1999, 0.25)]""") # async def api_lda(request): # try: # text = get_param(request, 'text') # n = int(get_param(request, 'n', 10)) # if text is None: # return error_response() # words = inference_engine_lda.tokenize(text) # result = inference_engine_lda.lda_infer(words) # result = result[:n] # result = [ # { # 'topic_id': topic_id, # 'score': score, # 'topic_words': twe.nearest_words_around_topic(topic_id), # 'topic_words_poly': lda_topic_words.get(topic_id), # } for topic_id, score in result # ] # return response(data=result) # except Exception as err: # traceback.print_exc() # logger.error(err, exc_info=True) # return error_response(str(err)) @app.route('/slda', methods=["POST"]) @doc.summary("SentenceLDA 模型推断") @doc.description("使用 SentenceLDA 模型对输入文本进行推断,得到其主题分布") @doc.consumes(doc.Integer(name='n', description="top n,默认 10"), required=False) @doc.consumes(doc.String(name='sep', description="多段文本之间的分割符,默认\\n 分割"), required=False) @doc.consumes(doc.String(name='text', description="多段文本"), required=True) @doc.response(200, None, description="""返回一个 list 对象,存放输入文本对应的稀疏主题分布 list 中每个元素为 tuple 每个 tuple 包含一个主题 ID 以及该主题对应的概率,按照概率从大到小排序。 例如:[(15, 0.5), (10, 0.25), (1999, 0.25)]""") async def api_slda(request): try: text = str(get_param(request, 'text')) sep = get_param(request, 'sep') n = int(get_param(request, 'n', 10)) if text is None: return error_response('Invalid request') sentences = text.splitlines() if sep is None else text.split(sep=sep) sentences = map(inference_engine_slda.tokenize, sentences) result = inference_engine_slda.slda_infer(sentences) result = result[:n] result = [ { 'topic_id': topic_id, 'score': score, # 'topic_words': twe.nearest_words_around_topic(topic_id), 'topic_words_poly': lda_topic_words.get(topic_id) } for topic_id, score in result ] return response(data=result) except Exception as err: logger.error(err, exc_info=True) return error_response(str(err)) # @app.route('/distance', methods=["POST"]) # @doc.summary("计算长文本与长文本之间的距离") # @doc.description("计算两个长文本的主题分布之间的距离,包括 jensen_shannon_divergence 和 hellinger_distance") # @doc.consumes(doc.String(name='b', description="文本 b"), required=True) # @doc.consumes(doc.String(name='a', description="文本 a"), required=True) # @doc.response(200, None, description="""返回一个 list 对象,其中有两个 float 元素 # 第一个表示 jensen_shannon_divergence 距离 # 第二个表示 hellinger_distance 距离 # 例如:[0.187232, 0.23431]""") # async def api_distance(request): # try: # a = get_param(request, 'a') # b = get_param(request, 'b') # if a is None or b is None: # return error_response('Invalid request') # words_a = inference_engine_lda.tokenize(a) # words_b = inference_engine_lda.tokenize(b) # result = inference_engine_lda.cal_doc_distance(words_a, words_b) # return response(data=result) # except Exception as err: # logger.error(err, exc_info=True) # return error_response(str(err)) # @app.route('/similarity/keywords', methods=["POST"]) # @doc.summary("关键词计算") # @doc.description("使用 LDA/TWE 模型计算候选关键词与文档的相关性") # @doc.consumes(doc.Boolean(name='use_twe', description="是否使用 TWE 模型, 默认不使用"), required=False) # @doc.consumes(doc.String(name='text', description="文本"), required=True) # @doc.consumes(doc.String(name='keywords', description="关键词列表, 空格分割"), required=True) # @doc.response(200, None, description="""返回一个 list 对象,每个元素为关键词以及其与文档相关性。""") # async def api_similarity_keywords(request): # try: # keywords = get_param(request, 'keywords') # text = get_param(request, 'text') # use_twe = get_param(request, 'use_twe') # if keywords is None or len(keywords) == 0 or text is None: # return error_response('Invalid request') # text = ' '.join(inference_engine_lda.tokenize(text)) # if use_twe: # result = inference_engine_lda.cal_keywords_twe_similarity(keywords, text) # else: # result = inference_engine_lda.cal_keywords_similarity(keywords, text) # return response(data=result) # except Exception as err: # logger.error(err, exc_info=True) # return error_response(str(err)) # @app.route('/similarity/query', methods=["POST"]) # @doc.summary("计算短文本与长文本之间的相关性") # @doc.description("使用 LDA 模型和 TWE 模型分别衡量短文本跟长文本之间的相关性") # @doc.consumes(doc.String(name='text', description="长文本"), required=True) # @doc.consumes(doc.String(name='query', description="短文本"), required=True) # @doc.response(200, None, description="""返回一个 list 对象,其中有两个 float 元素 # 第一个表示根据 LDA 模型得到的相关性 # 第二个表示通过 TWE 模型衡量得到的相关性 # 例如:[0.397232, 0.45431]""") # async def api_similarity_query(request): # try: # query = get_param(request, 'query') # text = get_param(request, 'text') # if query is None or text is None: # return error_response('Invalid request') # words_query = inference_engine_lda.tokenize(query) # words_text = inference_engine_lda.tokenize(text) # result = inference_engine_lda.cal_query_doc_similarity(words_query, words_text) # return response(data=result) # except Exception as err: # logger.error(err, exc_info=True) # return error_response(str(err)) # @app.route('/nearest-words', methods=["POST"]) # @doc.summary("寻求与目标主题最相关的词") # @doc.description("对模型中的所有词语进行检索,通过计算 cosine 相似度,返回最相关的 n 个词语") # @doc.consumes(doc.Integer(name='n', description="top n,默认 10"), required=False) # @doc.consumes(doc.Integer(name='topic_id', description="目标主题,可以为多个(一个和多个的返回结果结构不相同), word 与 topic_id 二选一"), required=False) # @doc.consumes(doc.String(name='word', description="目标词,可以为多个(一个和多个的返回结果结构不相同), word 与 topic 二选一"), required=False) # @doc.response(200, None, description="""返回一个 list 对象,长度为 n # list 中每个元素为 tuple,包含了返回词以及该词与目标词的 cosine 相关性,并按照相关性从高到低排序 # 例如输入"篮球"目标词返回前 10 个结果: # [(篮球队,0.833797), # (排球, 0.833721) # ..., # (篮球圈, 0.752021)] # 如果输入目标词不在词典中,则返回 None。""") # async def nearest_words(request): # try: # word = get_param(request, 'word', is_list=True) # topic_id = get_param(request, 'topic_id', is_list=True) # n = int(get_param(request, 'n', 10)) # if word: # result = {w: twe.nearest_words(w, n) for w in word} # return response(data=result) # if topic_id: # result = {int(_id): twe.nearest_words_around_topic(int(_id), n) for _id in topic_id} # return response(data=result) # return error_response() # except Exception as err: # logger.error(err, exc_info=True) # return error_response(str(err)) if __name__ == '__main__': logger.info(f"running familia api with {n_workers} workers") app.run(host='0.0.0.0', port=5000, workers=n_workers)