|
@@ -1,4 +1,8 @@
|
|
|
|
+import json
|
|
|
|
+
|
|
from dotenv import load_dotenv
|
|
from dotenv import load_dotenv
|
|
|
|
+from vanna import add_ddl
|
|
|
|
+
|
|
load_dotenv()
|
|
load_dotenv()
|
|
|
|
|
|
from functools import wraps
|
|
from functools import wraps
|
|
@@ -6,8 +10,11 @@ from flask import Flask, jsonify, Response, request, redirect, url_for
|
|
import flask
|
|
import flask
|
|
import os
|
|
import os
|
|
from cache import MemoryCache
|
|
from cache import MemoryCache
|
|
|
|
+from flask_cors import CORS
|
|
|
|
+from service.result import success, failed, MyEncoder
|
|
|
|
|
|
app = Flask(__name__, static_url_path='')
|
|
app = Flask(__name__, static_url_path='')
|
|
|
|
+CORS(app)
|
|
|
|
|
|
# SETUP
|
|
# SETUP
|
|
cache = MemoryCache()
|
|
cache = MemoryCache()
|
|
@@ -15,16 +22,39 @@ cache = MemoryCache()
|
|
# from vanna.local import LocalContext_OpenAI
|
|
# from vanna.local import LocalContext_OpenAI
|
|
# vn = LocalContext_OpenAI()
|
|
# vn = LocalContext_OpenAI()
|
|
|
|
|
|
-from vanna.remote import VannaDefault
|
|
|
|
-vn = VannaDefault(model=os.environ['VANNA_MODEL'], api_key=os.environ['VANNA_API_KEY'])
|
|
|
|
|
|
+from vanna.openai.openai_chat import OpenAI_Chat
|
|
|
|
+from vanna.chromadb import ChromaDB_VectorStore
|
|
|
|
+from vanna.flask import VannaFlaskApp
|
|
|
|
+from openai import OpenAI
|
|
|
|
+
|
|
|
|
+api_key = os.getenv('api_key')
|
|
|
|
+base_url = os.getenv('base_url')
|
|
|
|
+model_name = os.getenv('model')
|
|
|
|
+
|
|
|
|
+class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
|
|
|
|
+ def __init__(self, client=None,config=None):
|
|
|
|
+ ChromaDB_VectorStore.__init__(self, config=config)
|
|
|
|
+ OpenAI_Chat.__init__(self, client=client, config=config)
|
|
|
|
+
|
|
|
|
+client = OpenAI(
|
|
|
|
+ api_key=api_key,
|
|
|
|
+ base_url=base_url,
|
|
|
|
+)
|
|
|
|
+vn =MyVanna(client=client,config={'model':model_name})
|
|
|
|
|
|
-vn.connect_to_snowflake(
|
|
|
|
- account=os.environ['SNOWFLAKE_ACCOUNT'],
|
|
|
|
- username=os.environ['SNOWFLAKE_USERNAME'],
|
|
|
|
- password=os.environ['SNOWFLAKE_PASSWORD'],
|
|
|
|
- database=os.environ['SNOWFLAKE_DATABASE'],
|
|
|
|
- warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
|
|
|
|
-)
|
|
|
|
|
|
+# from vanna.remote import VannaDefault
|
|
|
|
+# vn = VannaDefault(model=os.environ['VANNA_MODEL'], api_key=os.environ['VANNA_API_KEY'])
|
|
|
|
+
|
|
|
|
+# vn.connect_to_snowflake(
|
|
|
|
+# account=os.environ['SNOWFLAKE_ACCOUNT'],
|
|
|
|
+# username=os.environ['SNOWFLAKE_USERNAME'],
|
|
|
|
+# password=os.environ['SNOWFLAKE_PASSWORD'],
|
|
|
|
+# database=os.environ['SNOWFLAKE_DATABASE'],
|
|
|
|
+# warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
|
|
|
|
+# )
|
|
|
|
+
|
|
|
|
+# vn.connect_to_mysql(host='192.168.3.86', dbname='digitization_data', user='root', password='123456', port=3306)
|
|
|
|
+vn.connect_to_mysql(host='192.168.3.91', dbname='digitization_data', user='root', password='2099citu##$$**.com', port=3306)
|
|
|
|
|
|
# NO NEED TO CHANGE ANYTHING BELOW THIS LINE
|
|
# NO NEED TO CHANGE ANYTHING BELOW THIS LINE
|
|
def requires_cache(fields):
|
|
def requires_cache(fields):
|
|
@@ -34,11 +64,14 @@ def requires_cache(fields):
|
|
id = request.args.get('id')
|
|
id = request.args.get('id')
|
|
|
|
|
|
if id is None:
|
|
if id is None:
|
|
- return jsonify({"type": "error", "error": "No id provided"})
|
|
|
|
-
|
|
|
|
|
|
+ # return jsonify({"type": "error", "error": "No id provided"})
|
|
|
|
+ res = failed({}, {"error":"No id provided"})
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
for field in fields:
|
|
for field in fields:
|
|
if cache.get(id=id, field=field) is None:
|
|
if cache.get(id=id, field=field) is None:
|
|
- return jsonify({"type": "error", "error": f"No {field} found"})
|
|
|
|
|
|
+ # return jsonify({"type": "error", "error": f"No {field} found"})
|
|
|
|
+ res = failed({}, {"error": f"No {field} found"})
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
|
|
field_values = {field: cache.get(id=id, field=field) for field in fields}
|
|
field_values = {field: cache.get(id=id, field=field) for field in fields}
|
|
|
|
|
|
@@ -51,18 +84,31 @@ def requires_cache(fields):
|
|
|
|
|
|
@app.route('/api/v0/generate_questions', methods=['GET'])
|
|
@app.route('/api/v0/generate_questions', methods=['GET'])
|
|
def generate_questions():
|
|
def generate_questions():
|
|
- return jsonify({
|
|
|
|
- "type": "question_list",
|
|
|
|
- "questions": vn.generate_questions(),
|
|
|
|
- "header": "Here are some questions you can ask:"
|
|
|
|
- })
|
|
|
|
|
|
+ # return jsonify({
|
|
|
|
+ # "type": "question_list",
|
|
|
|
+ # "questions": vn.generate_questions(),
|
|
|
|
+ # "header": "Here are some questions you can ask:"
|
|
|
|
+ # })
|
|
|
|
+ try:
|
|
|
|
+ response_data = {
|
|
|
|
+ "type": "question_list",
|
|
|
|
+ "questions": vn.generate_questions(),
|
|
|
|
+ "header": "Here are some questions you can ask:"
|
|
|
|
+ }
|
|
|
|
+ res = success(response_data, "success")
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ res = failed({}, {"error": f"{e}"})
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
|
|
@app.route('/api/v0/generate_sql', methods=['GET'])
|
|
@app.route('/api/v0/generate_sql', methods=['GET'])
|
|
def generate_sql():
|
|
def generate_sql():
|
|
question = flask.request.args.get('question')
|
|
question = flask.request.args.get('question')
|
|
|
|
|
|
if question is None:
|
|
if question is None:
|
|
- return jsonify({"type": "error", "error": "No question provided"})
|
|
|
|
|
|
+ # return jsonify({"type": "error", "error": "No question provided"})
|
|
|
|
+ res = failed({}, {"error": "No question provided"})
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
|
|
id = cache.generate_id(question=question)
|
|
id = cache.generate_id(question=question)
|
|
sql = vn.generate_sql(question=question)
|
|
sql = vn.generate_sql(question=question)
|
|
@@ -70,12 +116,19 @@ def generate_sql():
|
|
cache.set(id=id, field='question', value=question)
|
|
cache.set(id=id, field='question', value=question)
|
|
cache.set(id=id, field='sql', value=sql)
|
|
cache.set(id=id, field='sql', value=sql)
|
|
|
|
|
|
- return jsonify(
|
|
|
|
- {
|
|
|
|
- "type": "sql",
|
|
|
|
|
|
+ # return jsonify(
|
|
|
|
+ # {
|
|
|
|
+ # "type": "sql",
|
|
|
|
+ # "id": id,
|
|
|
|
+ # "text": sql,
|
|
|
|
+ # })
|
|
|
|
+ response_data = {
|
|
|
|
+ "type": "sql",
|
|
"id": id,
|
|
"id": id,
|
|
"text": sql,
|
|
"text": sql,
|
|
- })
|
|
|
|
|
|
+ }
|
|
|
|
+ res = success(response_data, "success")
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
|
|
@app.route('/api/v0/run_sql', methods=['GET'])
|
|
@app.route('/api/v0/run_sql', methods=['GET'])
|
|
@requires_cache(['sql'])
|
|
@requires_cache(['sql'])
|
|
@@ -85,15 +138,24 @@ def run_sql(id: str, sql: str):
|
|
|
|
|
|
cache.set(id=id, field='df', value=df)
|
|
cache.set(id=id, field='df', value=df)
|
|
|
|
|
|
- return jsonify(
|
|
|
|
- {
|
|
|
|
- "type": "df",
|
|
|
|
|
|
+ # return jsonify(
|
|
|
|
+ # {
|
|
|
|
+ # "type": "df",
|
|
|
|
+ # "id": id,
|
|
|
|
+ # "df": df.head(10).to_json(orient='records'),
|
|
|
|
+ # })
|
|
|
|
+ response_data = {
|
|
|
|
+ "type": "df",
|
|
"id": id,
|
|
"id": id,
|
|
"df": df.head(10).to_json(orient='records'),
|
|
"df": df.head(10).to_json(orient='records'),
|
|
- })
|
|
|
|
|
|
+ }
|
|
|
|
+ res = success(response_data, "success")
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- return jsonify({"type": "error", "error": str(e)})
|
|
|
|
|
|
+ # return jsonify({"type": "error", "error": str(e)})
|
|
|
|
+ res = failed({}, {"error": f"{e}"})
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
|
|
@app.route('/api/v0/download_csv', methods=['GET'])
|
|
@app.route('/api/v0/download_csv', methods=['GET'])
|
|
@requires_cache(['df'])
|
|
@requires_cache(['df'])
|
|
@@ -162,11 +224,17 @@ def add_training_data():
|
|
|
|
|
|
try:
|
|
try:
|
|
id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
|
|
id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
|
|
|
|
+ if ddl:
|
|
|
|
+ vn.train(ddl=ddl)
|
|
|
|
|
|
- return jsonify({"id": id})
|
|
|
|
|
|
+ # return jsonify({"id": id})
|
|
|
|
+ res = success({"id": id}, "success")
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
print("TRAINING ERROR", e)
|
|
print("TRAINING ERROR", e)
|
|
- return jsonify({"type": "error", "error": str(e)})
|
|
|
|
|
|
+ # return jsonify({"type": "error", "error": str(e)})
|
|
|
|
+ res = failed({}, {"error": f"{e}"})
|
|
|
|
+ return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
|
|
|
|
|
|
@app.route('/api/v0/generate_followup_questions', methods=['GET'])
|
|
@app.route('/api/v0/generate_followup_questions', methods=['GET'])
|
|
@requires_cache(['df', 'question', 'sql'])
|
|
@requires_cache(['df', 'question', 'sql'])
|
|
@@ -210,4 +278,4 @@ def root():
|
|
return app.send_static_file('index.html')
|
|
return app.send_static_file('index.html')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
- app.run(debug=True)
|
|
|
|
|
|
+ app.run(debug=True,host='0.0.0.0',port = 3005)
|