123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
import json
from dotenv import load_dotenv
from vanna import add_ddl

# Populate os.environ from a local .env file before any configuration is read below.
load_dotenv()

import os
from functools import wraps

import flask
from flask import Flask, Response, jsonify, redirect, request, url_for
from flask_cors import CORS
from openai import OpenAI
from vanna.chromadb import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp
from vanna.openai.openai_chat import OpenAI_Chat

from cache import MemoryCache
from service.result import MyEncoder, failed, success

# Static files (index.html, JS bundles) are served from the web root.
app = Flask(__name__, static_url_path='')
# NOTE(review): CORS(app) is called with no options; flask-cors then presumably allows
# every origin, including for the training-data mutation endpoints — confirm intended.
CORS(app)

# SETUP
# Per-question state (question, sql, df, figure, follow-ups) keyed by a generated id.
cache = MemoryCache()
# LLM and database connection settings, read from the process environment
# (which load_dotenv() may have populated from a .env file).
# Every name is None when its variable is unset.
api_key, base_url, model_name = (os.getenv(key) for key in ('api_key', 'base_url', 'model'))
host, dbname, user, password = (os.getenv(key) for key in ('host', 'dbname', 'user', 'password'))
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    """Vanna agent pairing a ChromaDB vector store with an OpenAI-compatible chat model.

    Each base-class initialiser is called explicitly: both receive the same
    ``config`` dict, and only the chat half receives the API ``client``.
    """

    def __init__(self, client=None, config=None):
        shared_config = config
        ChromaDB_VectorStore.__init__(self, config=shared_config)
        OpenAI_Chat.__init__(self, client=client, config=shared_config)
# OpenAI-compatible client; base_url comes from the environment so a non-default
# compatible endpoint can be used.
client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)
vn = MyVanna(client=client, config={'model': model_name})

# Other backends (Vanna cloud models, Snowflake, ...) are available through the vanna
# package. Configure them via environment variables; never keep credentials in
# commented-out code.

# MySQL port, overridable through the `db_port` environment variable (default 3306,
# the value previously hard-coded here).
db_port = int(os.getenv('db_port', '3306'))
vn.connect_to_mysql(host=host, dbname=dbname, user=user, password=password, port=db_port)
- # NO NEED TO CHANGE ANYTHING BELOW THIS LINE
def requires_cache(fields):
    """Build a view decorator that injects cached per-question values.

    The wrapped view reads the ``id`` query parameter, fetches every name in
    *fields* from ``cache`` for that id, and calls the view with those values
    plus ``id`` as keyword arguments.

    Args:
        fields: Cache field names the view needs (e.g. ``['df', 'question']``).

    Returns:
        A decorator. When the id is missing or blank, or any field is not cached,
        the decorated view returns the JSON text of a ``failed`` envelope instead
        of calling the view.
    """
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            cache_id = request.args.get('id')
            if not cache_id:
                res = failed({}, {"error": "No id provided"})
                return json.dumps(res, ensure_ascii=False, cls=MyEncoder)

            # Fetch each field exactly once: the check and the value we pass on
            # come from the same lookup.
            field_values = {}
            for field in fields:
                value = cache.get(id=cache_id, field=field)
                if value is None:
                    res = failed({}, {"error": f"No {field} found"})
                    return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
                field_values[field] = value

            # The view also receives the id itself.
            field_values['id'] = cache_id
            return f(*args, **field_values, **kwargs)
        return decorated
    return decorator
@app.route('/api/v0/generate_questions', methods=['GET'])
def generate_questions():
    """Return LLM-suggested starter questions as JSON text in the ``success`` envelope.

    Any exception (including one raised while serialising) is reported through the
    ``failed`` envelope instead of an HTTP 500.
    """
    try:
        payload = dict(
            type="question_list",
            questions=vn.generate_questions(),
            header="Here are some questions you can ask:",
        )
        return json.dumps(success(payload, "success"), ensure_ascii=False, cls=MyEncoder)
    except Exception as e:
        return json.dumps(failed({}, {"error": f"{e}"}), ensure_ascii=False, cls=MyEncoder)
@app.route('/api/v0/generate_sql', methods=['GET'])
def generate_sql():
    """Generate SQL for the ``question`` query parameter and cache it under a new id.

    Returns:
        JSON text of the ``success`` envelope wrapping
        ``{"type": "sql", "id": <cache id>, "text": <sql>}``. Later calls such as
        run_sql use that id. Returns the ``failed`` envelope when the question is
        missing or blank, or when SQL generation raises.
    """
    question = flask.request.args.get('question')
    if not question:
        res = failed({}, {"error": "No question provided"})
        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)

    try:
        # LLM call; it may raise (e.g. on an API or network error). Report the error
        # in the same envelope as the other endpoints instead of an HTML 500 page.
        sql = vn.generate_sql(question=question)
    except Exception as e:
        res = failed({}, {"error": f"{e}"})
        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)

    cache_id = cache.generate_id(question=question)
    cache.set(id=cache_id, field='question', value=question)
    cache.set(id=cache_id, field='sql', value=sql)
    response_data = {
        "type": "sql",
        "id": cache_id,
        "text": sql,
    }
    res = success(response_data, "success")
    return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
@app.route('/api/v0/run_sql', methods=['GET'])
@requires_cache(['sql'])
def run_sql(id: str, sql: str):
    """Execute the cached SQL for *id*, cache the resulting DataFrame, and preview it.

    Returns:
        JSON text of the ``success`` envelope holding the first 10 rows as a
        records-oriented JSON string, or of the ``failed`` envelope on any error.
    """
    try:
        result_df = vn.run_sql(sql=sql)
        cache.set(id=id, field='df', value=result_df)
        preview = result_df.head(10).to_json(orient='records')
        envelope = success({"type": "df", "id": id, "df": preview}, "success")
        return json.dumps(envelope, ensure_ascii=False, cls=MyEncoder)
    except Exception as e:
        return json.dumps(failed({}, {"error": f"{e}"}), ensure_ascii=False, cls=MyEncoder)
@app.route('/api/v0/download_csv', methods=['GET'])
@requires_cache(['df'])
def download_csv(id: str, df):
    """Send the cached DataFrame for *id* as a CSV attachment named ``<id>.csv``."""
    body = df.to_csv()
    disposition = f"attachment; filename={id}.csv"
    return Response(body, mimetype="text/csv", headers={"Content-disposition": disposition})
@app.route('/api/v0/generate_plotly_figure', methods=['GET'])
@requires_cache(['df', 'question', 'sql'])
def generate_plotly_figure(id: str, df, question, sql):
    """Ask the LLM for Plotly code for the cached result, build the figure, and cache it.

    Returns:
        A jsonify response ``{"type": "plotly_figure", "id", "fig"}`` where ``fig``
        is the figure's JSON string, or ``{"type": "error", "error": ...}`` when no
        figure could be produced or any step raised (the traceback is printed).
    """
    try:
        code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
        # NOTE(review): get_plotly_figure presumably executes the LLM-generated `code`;
        # treat that code as untrusted.
        fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
        if fig is None:
            # Report this clearly instead of failing later with an AttributeError on None.
            return jsonify({"type": "error", "error": "No figure could be generated"})
        fig_json = fig.to_json()
        cache.set(id=id, field='fig_json', value=fig_json)
        return jsonify(
            {
                "type": "plotly_figure",
                "id": id,
                "fig": fig_json,
            })
    except Exception as e:
        # Print the stack trace for server-side debugging.
        import traceback
        traceback.print_exc()
        return jsonify({"type": "error", "error": str(e)})
@app.route('/api/v0/get_training_data', methods=['GET'])
def get_training_data():
    """Return the first 25 training-data rows as a records-oriented JSON string."""
    training_df = vn.get_training_data()
    preview = training_df.head(25).to_json(orient='records')
    return jsonify({"type": "df", "id": "training_data", "df": preview})
@app.route('/api/v0/remove_training_data', methods=['POST'])
def remove_training_data():
    """Delete the training-data entry whose ``id`` is given in the JSON request body.

    Returns:
        ``{"success": true}`` when removal succeeds; otherwise an
        ``{"type": "error", ...}`` object. A missing, non-JSON, or non-object body
        is reported as "No id provided" rather than raising.
    """
    # silent=True: a non-JSON body yields None instead of an HTTP 415/AttributeError.
    payload = flask.request.get_json(silent=True)
    training_id = payload.get('id') if isinstance(payload, dict) else None
    if training_id is None:
        return jsonify({"type": "error", "error": "No id provided"})
    if vn.remove_training_data(id=training_id):
        return jsonify({"success": True})
    return jsonify({"type": "error", "error": "Couldn't remove training data"})
@app.route('/api/v0/train', methods=['POST'])
def add_training_data():
    """Add training data (question/sql, ddl, documentation) from the JSON body.

    Returns:
        JSON text of the ``success`` envelope with the id returned by ``vn.train``,
        or of the ``failed`` envelope when the body carries no training data or
        training raises.
    """
    # silent=True: a non-JSON body yields None instead of an HTTP 415/AttributeError.
    payload = flask.request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    question = payload.get('question')
    sql = payload.get('sql')
    ddl = payload.get('ddl')
    documentation = payload.get('documentation')

    if not any((question, sql, ddl, documentation)):
        res = failed({}, {"error": "No training data provided"})
        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)

    try:
        training_id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
        # NOTE(review): vn.train() appears to store only one kind of item per call,
        # with documentation, then sql, then ddl taking precedence. A DDL-only request
        # is therefore already stored by the call above, so re-train the DDL only when
        # another item took precedence. This avoids storing it twice.
        if ddl and (sql or documentation):
            vn.train(ddl=ddl)
        res = success({"id": training_id}, "success")
        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
    except Exception as e:
        print("TRAINING ERROR", e)
        res = failed({}, {"error": f"{e}"})
        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
@app.route('/api/v0/generate_followup_questions', methods=['GET'])
@requires_cache(['df', 'question', 'sql'])
def generate_followup_questions(id: str, df, question, sql):
    """Generate follow-up questions for the cached result, cache them, and return them."""
    questions = vn.generate_followup_questions(question=question, sql=sql, df=df)
    cache.set(id=id, field='followup_questions', value=questions)
    payload = dict(
        type="question_list",
        id=id,
        questions=questions,
        header="Here are some followup questions you can ask:",
    )
    return jsonify(payload)
@app.route('/api/v0/load_question', methods=['GET'])
@requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
def load_question(id: str, question, sql, df, fig_json, followup_questions):
    """Return everything cached for *id* so the client can restore a past question.

    The DataFrame is previewed as its first 10 rows; any error while building the
    response is returned as ``{"type": "error", "error": ...}``.
    """
    try:
        snapshot = dict(
            type="question_cache",
            id=id,
            question=question,
            sql=sql,
            df=df.head(10).to_json(orient='records'),
            fig=fig_json,
            followup_questions=followup_questions,
        )
        return jsonify(snapshot)
    except Exception as e:
        return jsonify({"type": "error", "error": str(e)})
@app.route('/api/v0/get_question_history', methods=['GET'])
def get_question_history():
    """List every cached question (as returned by ``cache.get_all``)."""
    history = cache.get_all(field_list=['question'])
    return jsonify({"type": "question_history", "questions": history})
@app.route('/')
def root():
    """Serve the front-end entry page from the static folder."""
    entry_page = 'index.html'
    return app.send_static_file(entry_page)
def env_flag(name, default=False):
    """Return True when environment variable *name* holds a truthy value.

    "1", "true", "yes" and "on" count as true (case-insensitive, surrounding
    whitespace ignored); any other value is false. An unset variable yields *default*.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ('1', 'true', 'yes', 'on')


if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution from the
    # browser. Since the server listens on all interfaces, debug mode is opt-in
    # (FLASK_DEBUG=1) rather than always on.
    app.run(debug=env_flag('FLASK_DEBUG'), host='0.0.0.0', port=3005)
|