|
@@ -0,0 +1,213 @@
|
|
|
+from dotenv import load_dotenv
|
|
|
+load_dotenv()
|
|
|
+
|
|
|
+from functools import wraps
|
|
|
+from flask import Flask, jsonify, Response, request, redirect, url_for
|
|
|
+import flask
|
|
|
+import os
|
|
|
+from cache import MemoryCache
|
|
|
+
|
|
|
+app = Flask(__name__, static_url_path='')
|
|
|
+
|
|
|
+# SETUP
|
|
|
+cache = MemoryCache()
|
|
|
+
|
|
|
+# from vanna.local import LocalContext_OpenAI
|
|
|
+# vn = LocalContext_OpenAI()
|
|
|
+
|
|
|
+from vanna.remote import VannaDefault
|
|
|
+vn = VannaDefault(model=os.environ['VANNA_MODEL'], api_key=os.environ['VANNA_API_KEY'])
|
|
|
+
|
|
|
+vn.connect_to_snowflake(
|
|
|
+ account=os.environ['SNOWFLAKE_ACCOUNT'],
|
|
|
+ username=os.environ['SNOWFLAKE_USERNAME'],
|
|
|
+ password=os.environ['SNOWFLAKE_PASSWORD'],
|
|
|
+ database=os.environ['SNOWFLAKE_DATABASE'],
|
|
|
+ warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
|
|
|
+)
|
|
|
+
|
|
|
+# NO NEED TO CHANGE ANYTHING BELOW THIS LINE
|
|
|
+def requires_cache(fields):
|
|
|
+ def decorator(f):
|
|
|
+ @wraps(f)
|
|
|
+ def decorated(*args, **kwargs):
|
|
|
+ id = request.args.get('id')
|
|
|
+
|
|
|
+ if id is None:
|
|
|
+ return jsonify({"type": "error", "error": "No id provided"})
|
|
|
+
|
|
|
+ for field in fields:
|
|
|
+ if cache.get(id=id, field=field) is None:
|
|
|
+ return jsonify({"type": "error", "error": f"No {field} found"})
|
|
|
+
|
|
|
+ field_values = {field: cache.get(id=id, field=field) for field in fields}
|
|
|
+
|
|
|
+ # Add the id to the field_values
|
|
|
+ field_values['id'] = id
|
|
|
+
|
|
|
+ return f(*args, **field_values, **kwargs)
|
|
|
+ return decorated
|
|
|
+ return decorator
|
|
|
+
|
|
|
+@app.route('/api/v0/generate_questions', methods=['GET'])
|
|
|
+def generate_questions():
|
|
|
+ return jsonify({
|
|
|
+ "type": "question_list",
|
|
|
+ "questions": vn.generate_questions(),
|
|
|
+ "header": "Here are some questions you can ask:"
|
|
|
+ })
|
|
|
+
|
|
|
+@app.route('/api/v0/generate_sql', methods=['GET'])
|
|
|
+def generate_sql():
|
|
|
+ question = flask.request.args.get('question')
|
|
|
+
|
|
|
+ if question is None:
|
|
|
+ return jsonify({"type": "error", "error": "No question provided"})
|
|
|
+
|
|
|
+ id = cache.generate_id(question=question)
|
|
|
+ sql = vn.generate_sql(question=question)
|
|
|
+
|
|
|
+ cache.set(id=id, field='question', value=question)
|
|
|
+ cache.set(id=id, field='sql', value=sql)
|
|
|
+
|
|
|
+ return jsonify(
|
|
|
+ {
|
|
|
+ "type": "sql",
|
|
|
+ "id": id,
|
|
|
+ "text": sql,
|
|
|
+ })
|
|
|
+
|
|
|
+@app.route('/api/v0/run_sql', methods=['GET'])
|
|
|
+@requires_cache(['sql'])
|
|
|
+def run_sql(id: str, sql: str):
|
|
|
+ try:
|
|
|
+ df = vn.run_sql(sql=sql)
|
|
|
+
|
|
|
+ cache.set(id=id, field='df', value=df)
|
|
|
+
|
|
|
+ return jsonify(
|
|
|
+ {
|
|
|
+ "type": "df",
|
|
|
+ "id": id,
|
|
|
+ "df": df.head(10).to_json(orient='records'),
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return jsonify({"type": "error", "error": str(e)})
|
|
|
+
|
|
|
+@app.route('/api/v0/download_csv', methods=['GET'])
|
|
|
+@requires_cache(['df'])
|
|
|
+def download_csv(id: str, df):
|
|
|
+ csv = df.to_csv()
|
|
|
+
|
|
|
+ return Response(
|
|
|
+ csv,
|
|
|
+ mimetype="text/csv",
|
|
|
+ headers={"Content-disposition":
|
|
|
+ f"attachment; filename={id}.csv"})
|
|
|
+
|
|
|
+@app.route('/api/v0/generate_plotly_figure', methods=['GET'])
|
|
|
+@requires_cache(['df', 'question', 'sql'])
|
|
|
+def generate_plotly_figure(id: str, df, question, sql):
|
|
|
+ try:
|
|
|
+ code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
|
|
|
+ fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
|
|
|
+ fig_json = fig.to_json()
|
|
|
+
|
|
|
+ cache.set(id=id, field='fig_json', value=fig_json)
|
|
|
+
|
|
|
+ return jsonify(
|
|
|
+ {
|
|
|
+ "type": "plotly_figure",
|
|
|
+ "id": id,
|
|
|
+ "fig": fig_json,
|
|
|
+ })
|
|
|
+ except Exception as e:
|
|
|
+ # Print the stack trace
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
+ return jsonify({"type": "error", "error": str(e)})
|
|
|
+
|
|
|
+@app.route('/api/v0/get_training_data', methods=['GET'])
|
|
|
+def get_training_data():
|
|
|
+ df = vn.get_training_data()
|
|
|
+
|
|
|
+ return jsonify(
|
|
|
+ {
|
|
|
+ "type": "df",
|
|
|
+ "id": "training_data",
|
|
|
+ "df": df.head(25).to_json(orient='records'),
|
|
|
+ })
|
|
|
+
|
|
|
+@app.route('/api/v0/remove_training_data', methods=['POST'])
|
|
|
+def remove_training_data():
|
|
|
+ # Get id from the JSON body
|
|
|
+ id = flask.request.json.get('id')
|
|
|
+
|
|
|
+ if id is None:
|
|
|
+ return jsonify({"type": "error", "error": "No id provided"})
|
|
|
+
|
|
|
+ if vn.remove_training_data(id=id):
|
|
|
+ return jsonify({"success": True})
|
|
|
+ else:
|
|
|
+ return jsonify({"type": "error", "error": "Couldn't remove training data"})
|
|
|
+
|
|
|
+@app.route('/api/v0/train', methods=['POST'])
|
|
|
+def add_training_data():
|
|
|
+ question = flask.request.json.get('question')
|
|
|
+ sql = flask.request.json.get('sql')
|
|
|
+ ddl = flask.request.json.get('ddl')
|
|
|
+ documentation = flask.request.json.get('documentation')
|
|
|
+
|
|
|
+ try:
|
|
|
+ id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
|
|
|
+
|
|
|
+ return jsonify({"id": id})
|
|
|
+ except Exception as e:
|
|
|
+ print("TRAINING ERROR", e)
|
|
|
+ return jsonify({"type": "error", "error": str(e)})
|
|
|
+
|
|
|
+@app.route('/api/v0/generate_followup_questions', methods=['GET'])
|
|
|
+@requires_cache(['df', 'question'])
|
|
|
+def generate_followup_questions(id: str, df, question):
|
|
|
+ followup_questions = vn.generate_followup_questions(question=question, df=df)
|
|
|
+
|
|
|
+ cache.set(id=id, field='followup_questions', value=followup_questions)
|
|
|
+
|
|
|
+ return jsonify(
|
|
|
+ {
|
|
|
+ "type": "question_list",
|
|
|
+ "id": id,
|
|
|
+ "questions": followup_questions,
|
|
|
+ "header": "Here are some followup questions you can ask:"
|
|
|
+ })
|
|
|
+
|
|
|
+@app.route('/api/v0/load_question', methods=['GET'])
|
|
|
+@requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
|
|
|
+def load_question(id: str, question, sql, df, fig_json, followup_questions):
|
|
|
+ try:
|
|
|
+ return jsonify(
|
|
|
+ {
|
|
|
+ "type": "question_cache",
|
|
|
+ "id": id,
|
|
|
+ "question": question,
|
|
|
+ "sql": sql,
|
|
|
+ "df": df.head(10).to_json(orient='records'),
|
|
|
+ "fig": fig_json,
|
|
|
+ "followup_questions": followup_questions,
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return jsonify({"type": "error", "error": str(e)})
|
|
|
+
|
|
|
+@app.route('/api/v0/get_question_history', methods=['GET'])
|
|
|
+def get_question_history():
|
|
|
+ return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question']) })
|
|
|
+
|
|
|
+@app.route('/')
|
|
|
+def root():
|
|
|
+ return app.send_static_file('index.html')
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ app.run(debug=True)
|