8 Star 0 Fork 0

amer/ValRegWithGLM

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 17.88 KB
一键复制 编辑 原始数据 按行查看 历史
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author : junpeng_chen
# @Time : 2023/7/26 10:39
# @File : main
# @annotation : 值识别任务主函数,conda环境为163:2041 - smartbi_merge
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "6"
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
import json
import requests
from modules.cond_conn_op_recognition import CondConnOpRecognition
from modules.group_by_recognition import GroupByRecognition
from modules.limit_recognition import LimitRecognition
from modules.literal_recognition import LiteralRecognition
from modules.mdx_recognition import MDXRecognition
from modules.measure_recognition import MeasureRecognition
from modules.measure_type_recognition import MeasureTypeRecognition
from modules.order_by_recognition import OrderByRecognition
from modules.others_recognition import OthersRecognition
from modules.select_package import Sel_Package
from modules.time_recognition import TimeRecognition
from modules.agg_recognition import AggRecognition
from modules.header_recognition import HeaderRecognition
from modules.norm_recognition import NormRecognition
from modules.col_and_row_regulator import ColAndRowRegulator
from utils.log import logger
from utils.chatglm_utils import ChatGLMUtils
from utils.timer import Timer
class ValRegMain:
    """Container that owns the model handle and every value-recognition sub-component.

    Construction only stores configuration and sets every component slot to
    ``None``; call :meth:`configure_components` afterwards to instantiate them.
    """

    def __init__(self, model=None, data_dict_path=None, headers_path=None):
        """Store the model and resource paths and reset all component slots.

        :param model: model wrapper handed to the model-backed components; must not be ``None``
        :param data_dict_path: path passed to the components that take the data dictionary
        :param headers_path: path passed to the header and normalization components
        :raises ValueError: if ``model`` is ``None`` (logged, then re-raised)
        """
        try:
            self.headers_path = headers_path
            self.data_dict_path = data_dict_path
            self.model = model
            if self.model is None:
                print("模型赋值失败")
                raise ValueError("model must not be None")
            # Every component starts out unset so attribute access before
            # configure_components() yields None instead of AttributeError.
            self.header_reg = None
            self.norm_reg = None
            self.sel_reg = None
            self.data_dict = None
            self.literal_reg = None
            self.measure_type_reg = None
            self.limit_reg = None
            self.time_reg = None
            self.measure_reg = None
            self.other_reg = None
            self.agg_reg = None
            self.mdx_reg = None
            self.group_by_reg = None
            self.order_by_reg = None
            self.cond_conn_op_reg = None
            self.col_and_row_regulator = None
        except Exception as err:
            logger.error("[ValRegMain]Model initialization err[{}]: {}".format(type(err), str(err)))
            raise

    def configure_components(self):
        """Instantiate the component for each sub-task.

        Any failure is logged and re-raised unchanged.
        """
        try:
            # Header selection.
            self.header_reg = HeaderRecognition(self.model, self.headers_path)
            # Query normalization.
            self.norm_reg = NormRecognition(self.model, self.headers_path)
            # Components for the select clause.
            self.sel_reg = Sel_Package()
            # Turn the header list of the raw data dictionary into a dict so
            # attributes can be looked up by header.
            self.data_dict = self.sel_reg.list_to_dict()
            # Literal recognition.
            self.literal_reg = LiteralRecognition(self.model)
            # Measure-type recognition.
            self.measure_type_reg = MeasureTypeRecognition(self.data_dict_path)
            # Limit recognition.
            self.limit_reg = LimitRecognition()
            # Time-dimension condition conversion.
            self.time_reg = TimeRecognition(self.model)
            # Numeric condition recognition and conversion.
            self.measure_reg = MeasureRecognition(self.model)
            # Recognition and conversion of the remaining literal conditions.
            self.other_reg = OthersRecognition(self.data_dict_path)
            # Aggregation recognition.
            self.agg_reg = AggRecognition(self.model, self.data_dict_path)
            # MDX recognition.
            self.mdx_reg = MDXRecognition(self.model)
            # group_by recognition.
            self.group_by_reg = GroupByRecognition()
            # order_by recognition.
            self.order_by_reg = OrderByRecognition()
            # cond_conn_op recognition.
            self.cond_conn_op_reg = CondConnOpRecognition()
            # Recognition of the col and row fields.
            self.col_and_row_regulator = ColAndRowRegulator()
        except Exception as err:
            logger.error("[ValRegMain]Components initialization err[{}]: {}".format(type(err), str(err)))
            raise
# Absolute path of the directory that contains this file.
current_directory = os.path.dirname(os.path.abspath(__file__))

# Build the model wrapper and load the fine-tuned checkpoint at import time.
model = ChatGLMUtils()
model.load_model(
    cuda_index=7,
    model_path="/home/cike/ytc/GLM2/checkpoints/Lora7_25_combine/checkpoint-64000/pytorch_model.bin",
    local_loading=True,
)

# NOTE(review): this tests the truthiness of the wrapper object, not the
# outcome of load_model() — confirm that a failed load is actually detected.
if not model:
    print("模型加载失败")
    raise Exception
else:
    val_reg_main = ValRegMain(
        model=model,
        data_dict_path=f"{current_directory}/data/data_dict.json",
        headers_path=f"{current_directory}/data/new_header.json",
    )
    val_reg_main.configure_components()
def main_function(query: str) -> tuple:
    """Run the full value-recognition pipeline for one natural-language query.

    Each stage calls one component of the module-level ``val_reg_main``.
    A failing stage is logged and its exception is re-raised unchanged.

    :param query: natural-language query
    :return: ``(response, temp_response)`` — ``response`` is the SQL-like JSON
        sent to the database side (sel, agg, conds, group_by, order_by, limit,
        having, cond_conn_op, measure, row, col plus fixed default fields);
        ``temp_response`` carries the raw ('origin') outputs of the model-backed
        stages together with the selected headers and the normalized query
    """
    timer = Timer()
    global_timer = Timer()
    global_timer.start()
    conds = []
    having = []

    # Normalize the query.
    timer.start()
    query = val_reg_main.norm_reg.recognize(query)
    timer.stop('norm')

    # Resolve the matching column names.
    timer.start()
    column = val_reg_main.header_reg.recognize(query)
    timer.stop('header')

    # Build the sel field.
    timer.start()
    try:
        sel = val_reg_main.sel_reg.select_package(query, column, val_reg_main.data_dict)["sel"]
    except Exception as err:
        logger.error("[Main]select error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('select')

    # Literal and operator recognition.
    timer.start()
    try:
        literals_meta = val_reg_main.literal_reg.recognize(query, sel)
        literals = literals_meta['processed']['literals']
        literals_origin = literals_meta['origin']
    except Exception as err:
        logger.error("[Main]literals&opts error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('literals&opts')

    # Measure recognition.
    timer.start()
    try:
        measure = val_reg_main.measure_type_reg.measure_type(sel)['measure']
    except Exception as err:
        logger.error("[Main]measure error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('measure')

    # Limit recognition.
    timer.start()
    try:
        limit = val_reg_main.limit_reg.limit(query)['limit']
    except Exception as err:
        logger.error("[Main]limit error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('limit')

    # Time-dimension condition conversion.
    timer.start()
    try:
        # Keep only the time conditions among the recognized literals.
        time_literals = val_reg_main.time_reg.filter(literals)
        time_conds_meta = val_reg_main.time_reg.transform(query, time_literals)
        time_conds = time_conds_meta['processed']
        time_conds_origin = time_conds_meta['origin']
        # conds is the final conds field of the SQL-like structure.
        conds.extend(time_conds)
    except Exception as err:
        logger.error("[Main]time condition transformation error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('time condition transformation')

    # Conversion of the remaining literal types.
    timer.start()
    try:
        other_literals = val_reg_main.other_reg.filter(literals)
        other_conds = val_reg_main.other_reg.transform(query, other_literals)
        conds.extend(other_conds)
    except Exception as err:
        logger.error("[Main]other condition transformation error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('other condition transformation')

    # Aggregation recognition.
    timer.start()
    try:
        agg_meta = val_reg_main.agg_reg.recognition(query, sel, measure)
        agg = agg_meta['processed']['agg']
        agg_origin = agg_meta['origin']
    except Exception as err:
        logger.error("[Main]agg error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('agg_reg')

    # Numeric condition conversion.
    timer.start()
    try:
        # Keep only the numeric conditions among the recognized literals.
        measure_literals, measure_agg = val_reg_main.measure_reg.filter(query, literals, agg)
        measure_meta = val_reg_main.measure_reg.transform(query, measure_literals, measure, measure_agg)
        measure_conds, measure_having = measure_meta['processed']
        measure_origin = measure_meta['origin']
        # conds / having are the final fields of the SQL-like structure.
        conds.extend(measure_conds)
        having.extend(measure_having)
    except Exception as err:
        logger.error("[Main]measure condition transformation error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('measure condition transformation')

    # MDX recognition.
    timer.start()
    try:
        mdx_meta = val_reg_main.mdx_reg.recognition(query, sel, measure)
        # The processed value is read but not forwarded; only the raw output
        # is returned (in temp_response).
        mdx = mdx_meta['processed']['mdx']
        mdx_origin = mdx_meta['origin']
    except Exception as err:
        logger.error("[Main]mdx error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('mdx_reg')

    # group_by recognition.
    timer.start()
    try:
        group_by = val_reg_main.group_by_reg.group_by(query, sel, agg, measure)
    except Exception as err:
        logger.error("[Main]group by error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('group by')

    # order_by recognition. The timer is restarted here so that the
    # 'order by' measurement is not taken from a stale start time.
    timer.start()
    try:
        order_by = val_reg_main.order_by_reg.order_by(query, sel, having, measure)
    except Exception as err:
        logger.error("[Main]order by error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('order by')

    # cond_conn_op recognition.
    timer.start()
    try:
        cond_conn_op = val_reg_main.cond_conn_op_reg.conn_op_recognition(conds)
    except Exception as err:
        logger.error("[Main]cond_conn_op error[{}]: {}".format(type(err), str(err)))
        raise
    timer.stop('cond_conn_op')

    # col and row.
    try:
        col, row = val_reg_main.col_and_row_regulator.col_and_row(sel, measure)
    except Exception as err:
        logger.error("[Main]col and row error[{}]: {}".format(type(err), str(err)))
        raise

    response = dict()
    temp_response = dict()

    # Fill the JSON that is sent to the database side.
    response['sel'] = sel
    response['agg'] = agg
    response['conds'] = conds
    response['group_by'] = group_by
    response['order_by'] = order_by['order_by']  # temporary unwrapping
    response['limit'] = limit
    response['having'] = having
    response['cond_conn_op'] = cond_conn_op['cond_conn_op']  # temporary unwrapping
    response['measure'] = measure
    # has_time, forceType, type, rowNotEmpty and data_source are required by
    # the consumer for now and keep their default values.
    response['row'] = row
    response['col'] = col
    response['has_time'] = 'true'
    response['forceType'] = ''
    response['type'] = 'TABLE_CROSS'
    response['rowNotEmpty'] = 'true'
    response['data_source'] = 'AUGMENTED_DATASET'

    # Fill the auxiliary JSON with the raw stage outputs.
    temp_response['agg'] = agg_origin
    temp_response['literal_and_operator'] = literals_origin
    temp_response['measure_conds'] = measure_origin
    temp_response['time_conds'] = time_conds_origin
    temp_response['mdx'] = mdx_origin
    temp_response['header_reg'] = column
    temp_response['normalization'] = query

    global_timer.stop('Total')
    return response, temp_response
# app = Flask(__name__)
# @app.route('/getSQLs', methods=["POST"])
# def getSQLs():
# try:
# header_data = request.get_json()
# if "query" not in header_data or "column" not in header_data:
# raise Exception
# query = header_data['query']
# column = header_data['column']
# except Exception as err:
# logger.error("parameter passing error[{}]: {}".format(type(err),str(err)))
# raise
# # 核心识别函数
# sqls = main_function(query, column)
# return jsonify(sqls)
#理论上新的系统不需要再给Header了
# @app.route('/getSQLs', methods=["POST"])
# def getSQLs():
# try:
# header_data = request.get_json()
# print(header_data)
# if "query" not in header_data:
# raise Exception
# query = header_data['query']
# except Exception as err:
# logger.error("parameter passing error[{}]: {}".format(type(err),str(err)))
# raise
# # 核心识别函数
# sqls = main_function(query)
# return jsonify(sqls)
def Post(url, data, timeout=60):
    """POST form data to ``url`` and return the decoded JSON body.

    :param url: target URL
    :param data: payload passed to ``requests.post`` as ``data``
    :param timeout: seconds to wait for the server before giving up; without a
        timeout a stalled backend would block the calling worker indefinitely
    :return: the response body parsed with ``json.loads``
    :raises requests.exceptions.Timeout: if the server does not answer in time
    :raises json.JSONDecodeError: if the body is not valid JSON
    """
    response = requests.post(url, data=data, timeout=timeout)
    if response.status_code == 200:
        print('请求成功!')
    else:
        print('请求失败,状态码:', response.status_code)
    # The body is parsed for non-200 responses too, as callers inspect the
    # 'code' field of the returned document.
    return json.loads(response.text)
# Cached backend session token, shared by all request handlers in this module.
const_token = None
app = Flask(__name__)
cors = CORS(app)

# Remote endpoints and request constants used by both routes.
_URL_LOGIN = 'http://proj.smartbi.com.cn:9070/aiweb/api/v1/login'
_URL_SQL = 'http://proj.smartbi.com.cn:9070/aiweb/integration/api/v1/query_with_nl2sql'
_TABLE_ID = 'I8a8ae5ca0178549554951b9501785cefe3f00058'
# NOTE(review): credentials are hard-coded in source; consider moving them to
# configuration or the environment.
_LOGIN_PAYLOAD = {
    'userName': 'huagong1',
    'password': 'huagong1'
}
# Response code after which the handlers log in again and retry once
# (presumably an expired/invalid token — confirm against the backend API).
_RELOGIN_CODE = -2


def _login():
    """Log in to the backend and cache the returned token in ``const_token``."""
    global const_token
    token_data = Post(_URL_LOGIN, dict(_LOGIN_PAYLOAD))
    const_token = token_data['token']
    return const_token


def _post_query(str_sql):
    """Send one nl2sql request using the cached token and return the raw reply."""
    data_sql = {
        'table': _TABLE_ID,
        'token': const_token,
        'nl2sql': str_sql
    }
    return Post(_URL_SQL, data_sql)


def _unwrap_answer(reply):
    """Return ``reply['result']['html']`` as far as those levels exist.

    Only dicts are descended into: a string payload that merely contains the
    text "html" must be returned as-is rather than indexed.
    """
    answer = reply
    if isinstance(answer, dict) and answer.get('result') is not None:
        answer = answer['result']
    if isinstance(answer, dict) and 'html' in answer:
        answer = answer['html']
    return answer


def _query_backend(str_sql):
    """Run an nl2sql query, logging in first if needed and retrying once on ``_RELOGIN_CODE``."""
    if not const_token:
        _login()
    reply = _post_query(str_sql)
    if reply['code'] == _RELOGIN_CODE:
        _login()
        reply = _post_query(str_sql)
    return _unwrap_answer(reply)


@app.route('/getSQLs', methods=["POST"])
def getSQLs():
    """Recognize the question in the request, query the backend and return both results.

    Expects a JSON body of the form ``{"data": {"question": ...}}``. The reply
    contains the generated SQL-like JSON (``data``), the backend answer
    (``answer``), the raw model outputs (``model_output``) and ``status``.
    """
    try:
        header_data = request.get_json()['data']
        print(header_data)
        if "question" not in header_data:
            raise Exception("missing 'question' field in request data")
        query = header_data['question']
    except Exception as err:
        logger.error("parameter passing error[{}]: {}".format(type(err),str(err)))
        raise
    # Core recognition function.
    sqls, origin = main_function(query)
    str_sql = json.dumps(sqls)
    result = _query_backend(str_sql)
    final_result = {
        'data': [sqls],
        'answer': [result],
        'model_output': [origin],
        'status': 0
    }
    return jsonify(final_result)


@app.route('/resendJSON', methods=["POST"])
def resendJSON():
    """Forward an already-built SQL-like JSON (``{"data": ...}``) to the backend.

    The reply contains only the backend answer under ``answer``.
    """
    try:
        data = request.get_json()['data']
        str_sql = json.dumps(data)
    except Exception as err:
        logger.error("json error[{}]: {}".format(type(err),str(err)))
        raise
    result = _query_backend(str_sql)
    final_result = {
        'answer': [result],
    }
    return jsonify(final_result)
if __name__ == '__main__':
    # host is the address the backend binds to. 0.0.0.0 makes the API
    # reachable through any address of this server; a fixed address such as
    # 222.222.222.222 would restrict access to that address only. port is the
    # port the API listens on. debug=True enables debug mode, in which code
    # changes are picked up without restarting the service.
    # app.run(host='0.0.0.0', port=5000, debug=True)
    # The production deployment listens on port 5003.
    app.run(debug=False, port=5003, host='0.0.0.0')
    # The first query takes roughly 2s longer.
    # query = "东南地区上个月的MQL个数是多少?"
    # response = main_function(query)
    # print(response)
    # response = main_function(query)
    # print(response)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/cjp_scut/val-reg-with-glm.git
[email protected]:cjp_scut/val-reg-with-glm.git
cjp_scut
val-reg-with-glm
ValRegWithGLM
master

搜索帮助