"""Flask HTTP API around a Vanna text-to-SQL model.

The module builds a ``MyVanna`` instance (ChromaDB vector store + OpenAI chat
client), connects it to a MySQL database whose settings come from environment
variables, and exposes the ``/api/v0/*`` endpoints defined below. Intermediate
results (question, SQL, DataFrame, figure, follow-up questions) are kept in a
``MemoryCache`` keyed by an id that clients pass back as the ``id`` query
parameter.
"""
import json
import os
import traceback
from functools import wraps

from dotenv import load_dotenv
from vanna import add_ddl

# Populate os.environ from .env; kept ahead of the remaining imports and of
# the os.getenv() calls below, matching the original execution order.
load_dotenv()

import flask
from flask import Flask, jsonify, Response, request, redirect, url_for
from flask_cors import CORS
from openai import OpenAI
from vanna.chromadb import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp
from vanna.openai.openai_chat import OpenAI_Chat

from cache import MemoryCache
from service.result import success, failed, MyEncoder

app = Flask(__name__, static_url_path='')
CORS(app)

# SETUP
cache = MemoryCache()

# All connection settings are read from the environment (see load_dotenv
# above); never hard-code credentials in this file, not even in comments.
api_key = os.getenv('api_key')
base_url = os.getenv('base_url')
model_name = os.getenv('model')
host = os.getenv('host')
dbname = os.getenv('dbname')
user = os.getenv('user')
password = os.getenv('password')


class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    """Vanna model combining the ChromaDB vector store with the OpenAI chat client."""

    def __init__(self, client=None, config=None):
        # Both bases are initialised explicitly, each with the shared config.
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, client=client, config=config)


client = OpenAI(
    api_key=api_key,
    base_url=base_url,
)

vn = MyVanna(client=client, config={'model': model_name})
vn.connect_to_mysql(host=host, dbname=dbname, user=user, password=password, port=3306)


# NO NEED TO CHANGE ANYTHING BELOW THIS LINE
def _json_response(payload):
    """Serialise a ``success``/``failed`` envelope to a JSON string.

    ``ensure_ascii=False`` keeps non-ASCII text readable in the output and
    ``MyEncoder`` is the project's JSON encoder. The string is returned as-is
    from the view, exactly as the endpoints did before this helper existed.
    """
    return json.dumps(payload, ensure_ascii=False, cls=MyEncoder)


def requires_cache(fields):
    """Build a view decorator that injects cached values as keyword arguments.

    The wrapped view is called with one keyword argument per name in
    ``fields`` plus ``id``. The id is taken from the ``id`` query parameter.
    If the id is missing, or any requested field is not in the cache for that
    id, a ``failed`` JSON envelope is returned and the view is not called.
    """
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            cache_id = request.args.get('id')
            if cache_id is None:
                return _json_response(failed({}, {"error": "No id provided"}))

            # Read every field once: validate and collect in the same pass.
            field_values = {}
            for field in fields:
                value = cache.get(id=cache_id, field=field)
                if value is None:
                    return _json_response(failed({}, {"error": f"No {field} found"}))
                field_values[field] = value

            # Add the id to the field_values
            field_values['id'] = cache_id

            return f(*args, **field_values, **kwargs)
        return decorated
    return decorator


@app.route('/api/v0/generate_questions', methods=['GET'])
def generate_questions():
    """Return a list of suggested questions produced by the model."""
    try:
        response_data = {
            "type": "question_list",
            "questions": vn.generate_questions(),
            "header": "Here are some questions you can ask:"
        }
        return _json_response(success(response_data, "success"))
    except Exception as e:
        return _json_response(failed({}, {"error": f"{e}"}))


@app.route('/api/v0/generate_sql', methods=['GET'])
def generate_sql():
    """Generate SQL for the ``question`` query parameter and cache both.

    Returns a ``success`` envelope holding the new cache id and the SQL text,
    or a ``failed`` envelope when the question is missing or generation fails.
    """
    question = flask.request.args.get('question')

    if question is None:
        return _json_response(failed({}, {"error": "No question provided"}))

    # Guarded like the sibling endpoints so a model/cache failure is reported
    # in the API's own error envelope.
    try:
        question_id = cache.generate_id(question=question)
        sql = vn.generate_sql(question=question)

        cache.set(id=question_id, field='question', value=question)
        cache.set(id=question_id, field='sql', value=sql)
    except Exception as e:
        return _json_response(failed({}, {"error": f"{e}"}))

    response_data = {
        "type": "sql",
        "id": question_id,
        "text": sql,
    }
    return _json_response(success(response_data, "success"))


@app.route('/api/v0/run_sql', methods=['GET'])
@requires_cache(['sql'])
def run_sql(id: str, sql: str):
    """Run the cached SQL, cache the resulting DataFrame, return its first 10 rows."""
    try:
        df = vn.run_sql(sql=sql)

        cache.set(id=id, field='df', value=df)

        response_data = {
            "type": "df",
            "id": id,
            "df": df.head(10).to_json(orient='records'),
        }
        return _json_response(success(response_data, "success"))
    except Exception as e:
        return _json_response(failed({}, {"error": f"{e}"}))


@app.route('/api/v0/download_csv', methods=['GET'])
@requires_cache(['df'])
def download_csv(id: str, df):
    """Return the cached DataFrame as a CSV attachment named ``<id>.csv``."""
    csv = df.to_csv()

    return Response(
        csv,
        mimetype="text/csv",
        headers={"Content-disposition": f"attachment; filename={id}.csv"})


@app.route('/api/v0/generate_plotly_figure', methods=['GET'])
@requires_cache(['df', 'question', 'sql'])
def generate_plotly_figure(id: str, df, question, sql):
    """Generate a Plotly figure for the cached result and cache its JSON form."""
    try:
        code = vn.generate_plotly_code(
            question=question,
            sql=sql,
            df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
        fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
        fig_json = fig.to_json()

        cache.set(id=id, field='fig_json', value=fig_json)

        return jsonify(
            {
                "type": "plotly_figure",
                "id": id,
                "fig": fig_json,
            })
    except Exception as e:
        # Print the stack trace
        traceback.print_exc()

        return jsonify({"type": "error", "error": str(e)})


@app.route('/api/v0/get_training_data', methods=['GET'])
def get_training_data():
    """Return the first 25 rows of the model's training data."""
    df = vn.get_training_data()

    return jsonify(
        {
            "type": "df",
            "id": "training_data",
            "df": df.head(25).to_json(orient='records'),
        })


@app.route('/api/v0/remove_training_data', methods=['POST'])
def remove_training_data():
    """Remove the training-data entry whose id is given in the JSON body."""
    # silent=True yields None instead of an HTTP error for a missing or
    # malformed body, so the client gets this API's own error payload.
    payload = flask.request.get_json(silent=True)
    training_id = payload.get('id') if isinstance(payload, dict) else None

    if training_id is None:
        return jsonify({"type": "error", "error": "No id provided"})

    if vn.remove_training_data(id=training_id):
        return jsonify({"success": True})

    return jsonify({"type": "error", "error": "Couldn't remove training data"})


@app.route('/api/v0/train', methods=['POST'])
def add_training_data():
    """Train the model with the question/sql/ddl/documentation in the JSON body.

    Returns a ``success`` envelope with the id returned by ``vn.train``, or a
    ``failed`` envelope when the body is not a JSON object or training raises.
    """
    payload = flask.request.get_json(silent=True)
    if not isinstance(payload, dict):
        return _json_response(failed({}, {"error": "No JSON body provided"}))

    question = payload.get('question')
    sql = payload.get('sql')
    ddl = payload.get('ddl')
    documentation = payload.get('documentation')

    try:
        training_id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
        # NOTE(review): ddl is submitted a second time on its own; presumably
        # the combined call above does not always store it - confirm against
        # the installed vanna version before removing this.
        if ddl:
            vn.train(ddl=ddl)

        return _json_response(success({"id": training_id}, "success"))
    except Exception as e:
        print("TRAINING ERROR", e)
        return _json_response(failed({}, {"error": f"{e}"}))


@app.route('/api/v0/generate_followup_questions', methods=['GET'])
@requires_cache(['df', 'question', 'sql'])
def generate_followup_questions(id: str, df, question, sql):
    """Generate follow-up questions for the cached result and cache them."""
    followup_questions = vn.generate_followup_questions(question=question, sql=sql, df=df)

    cache.set(id=id, field='followup_questions', value=followup_questions)

    return jsonify(
        {
            "type": "question_list",
            "id": id,
            "questions": followup_questions,
            "header": "Here are some followup questions you can ask:"
        })


@app.route('/api/v0/load_question', methods=['GET'])
@requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
def load_question(id: str, question, sql, df, fig_json, followup_questions):
    """Return everything cached for one question id in a single payload."""
    try:
        return jsonify(
            {
                "type": "question_cache",
                "id": id,
                "question": question,
                "sql": sql,
                "df": df.head(10).to_json(orient='records'),
                "fig": fig_json,
                "followup_questions": followup_questions,
            })
    except Exception as e:
        return jsonify({"type": "error", "error": str(e)})


@app.route('/api/v0/get_question_history', methods=['GET'])
def get_question_history():
    """Return the ``question`` field of every cached entry."""
    return jsonify({
        "type": "question_history",
        "questions": cache.get_all(field_list=['question']),
    })


@app.route('/')
def root():
    """Serve the static single-page front end."""
    return app.send_static_file('index.html')


if __name__ == '__main__':
    # NOTE(review): debug=True together with host='0.0.0.0' exposes the
    # Werkzeug interactive debugger on every network interface. Behaviour is
    # kept as-is here; switch debug off before running outside development.
    app.run(debug=True, host='0.0.0.0', port=3005)