app.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from dotenv import load_dotenv
  2. load_dotenv()
  3. from functools import wraps
  4. from flask import Flask, jsonify, Response, request, redirect, url_for
  5. import flask
  6. import os
  7. from cache import MemoryCache
  # --- Setup ---------------------------------------------------------------
  # static_url_path='' serves files from the static folder at the site root.
  app = Flask(__name__, static_url_path='')

  # Per-question state (question, sql, df, ...) keyed by ids produced by
  # cache.generate_id(); the endpoints below read and write it.
  cache = MemoryCache()

  # To use a local model instead of the hosted one, swap in:
  #     from vanna.local import LocalContext_OpenAI
  #     vn = LocalContext_OpenAI()
  from vanna.remote import VannaDefault

  vn = VannaDefault(
      model=os.environ['VANNA_MODEL'],
      api_key=os.environ['VANNA_API_KEY'],
  )

  # Credentials come from the environment (load_dotenv() at the top of the
  # file may have populated it from a .env file). A missing variable raises
  # KeyError while the module is being imported.
  vn.connect_to_snowflake(
      account=os.environ['SNOWFLAKE_ACCOUNT'],
      username=os.environ['SNOWFLAKE_USERNAME'],
      password=os.environ['SNOWFLAKE_PASSWORD'],
      database=os.environ['SNOWFLAKE_DATABASE'],
      warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
  )

  # Everything below this point is backend-agnostic and normally needs no edits.
  def requires_cache(fields):
      """Build a decorator that injects cached per-question values into a view.

      The decorated view is resolved through the ``id`` query-string parameter.
      Each name in *fields* is read from ``cache`` for that id and passed to
      the view as a keyword argument, together with ``id`` itself.

      Args:
          fields: Names of the cached fields the view needs.

      Returns:
          A decorator. If the request has no ``id`` parameter, or any requested
          field is not cached (``cache.get`` returns ``None``), the view is not
          called and a JSON error payload is returned instead.
      """
      def decorator(f):
          @wraps(f)
          def decorated(*args, **kwargs):
              question_id = request.args.get('id')
              if question_id is None:
                  return jsonify({"type": "error", "error": "No id provided"})

              # Read each field exactly once. The previous check-then-re-read
              # doubled cache lookups and did not guarantee that the value
              # handed to the view was the one that passed the None check.
              field_values = {}
              for field in fields:
                  value = cache.get(id=question_id, field=field)
                  if value is None:
                      return jsonify({"type": "error", "error": f"No {field} found"})
                  field_values[field] = value

              # Views declare the id parameter under the keyword name ``id``.
              field_values['id'] = question_id
              return f(*args, **field_values, **kwargs)
          return decorated
      return decorator
  @app.route('/api/v0/generate_questions', methods=['GET'])
  def generate_questions():
      """Return starter questions suggested by the Vanna model as a question list."""
      suggestions = vn.generate_questions()
      payload = dict(
          type="question_list",
          questions=suggestions,
          header="Here are some questions you can ask:",
      )
      return jsonify(payload)
  @app.route('/api/v0/generate_sql', methods=['GET'])
  def generate_sql():
      """Generate SQL for the ``question`` query parameter and cache both.

      Returns:
          JSON ``{"type": "sql", "id": ..., "text": ...}``. The ``id`` keys the
          cached question and SQL for the follow-up endpoints (run_sql, ...).
          If the question is missing or empty, or SQL generation raises, returns
          a JSON error payload ``{"type": "error", "error": ...}`` instead.
      """
      question = request.args.get('question')
      # An empty ``question=`` is as unusable as a missing parameter.
      if not question:
          return jsonify({"type": "error", "error": "No question provided"})

      question_id = cache.generate_id(question=question)
      try:
          sql = vn.generate_sql(question=question)
      except Exception as e:
          # Report failures as JSON like the other endpoints, not as a 500 page.
          return jsonify({"type": "error", "error": str(e)})

      cache.set(id=question_id, field='question', value=question)
      cache.set(id=question_id, field='sql', value=sql)
      return jsonify(
          {
              "type": "sql",
              "id": question_id,
              "text": sql,
          })
  @app.route('/api/v0/run_sql', methods=['GET'])
  @requires_cache(['sql'])
  def run_sql(id: str, sql: str):
      """Execute the cached SQL for *id*, cache the DataFrame, and return a preview.

      The preview is the first 10 rows serialized as JSON records. Any exception
      raised while running, caching, or serializing is reported as a JSON error.
      """
      try:
          result_df = vn.run_sql(sql=sql)
          cache.set(id=id, field='df', value=result_df)
          preview = result_df.head(10).to_json(orient='records')
          body = {"type": "df", "id": id, "df": preview}
          return jsonify(body)
      except Exception as exc:
          return jsonify({"type": "error", "error": str(exc)})
  @app.route('/api/v0/download_csv', methods=['GET'])
  @requires_cache(['df'])
  def download_csv(id: str, df):
      """Send the cached result DataFrame for *id* as a ``<id>.csv`` attachment."""
      csv_text = df.to_csv()
      disposition = f"attachment; filename={id}.csv"
      return Response(
          csv_text,
          mimetype="text/csv",
          headers={"Content-disposition": disposition},
      )
  @app.route('/api/v0/generate_plotly_figure', methods=['GET'])
  @requires_cache(['df', 'question', 'sql'])
  def generate_plotly_figure(id: str, df, question, sql):
      """Have the model write Plotly code for the cached result and return the figure.

      The model sees the question, the SQL, and the DataFrame's dtypes. The
      resulting figure JSON is cached under *id* as ``fig_json``. On failure
      the stack trace is printed to stderr and a JSON error payload is returned.
      """
      try:
          dtype_summary = f"Running df.dtypes gives:\n {df.dtypes}"
          plotly_code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=dtype_summary)
          # NOTE(review): plotly_code is model-generated; confirm how
          # get_plotly_figure evaluates it before exposing this server publicly.
          figure = vn.get_plotly_figure(plotly_code=plotly_code, df=df, dark_mode=False)
          fig_json = figure.to_json()
          cache.set(id=id, field='fig_json', value=fig_json)
          return jsonify({"type": "plotly_figure", "id": id, "fig": fig_json})
      except Exception as e:
          import traceback
          traceback.print_exc()
          return jsonify({"type": "error", "error": str(e)})
  @app.route('/api/v0/get_training_data', methods=['GET'])
  def get_training_data():
      """Return the first 25 training-data rows as JSON records under a fixed id."""
      training_df = vn.get_training_data()
      first_rows = training_df.head(25)
      return jsonify({
          "type": "df",
          "id": "training_data",
          "df": first_rows.to_json(orient='records'),
      })
  @app.route('/api/v0/remove_training_data', methods=['POST'])
  def remove_training_data():
      """Delete one training-data entry identified by ``id`` in the JSON body.

      Returns:
          ``{"success": True}`` when ``vn.remove_training_data`` reports success,
          otherwise a JSON error payload. A missing, non-JSON, or non-object
          body is treated like a body without ``id``.
      """
      # get_json(silent=True) returns None instead of raising when the body is
      # absent or not JSON; a JSON array/scalar body has no keys to look up.
      payload = request.get_json(silent=True)
      if not isinstance(payload, dict):
          payload = {}

      training_id = payload.get('id')
      if training_id is None:
          return jsonify({"type": "error", "error": "No id provided"})
      if vn.remove_training_data(id=training_id):
          return jsonify({"success": True})
      return jsonify({"type": "error", "error": "Couldn't remove training data"})
  @app.route('/api/v0/train', methods=['POST'])
  def add_training_data():
      """Add training data from the JSON body (``question``/``sql``/``ddl``/``documentation``).

      Any of the four keys may be omitted; absent keys are passed to
      ``vn.train`` as ``None``.

      Returns:
          ``{"id": ...}`` with the value returned by ``vn.train``, or a JSON
          error payload if training raises (the error is also printed).
      """
      # Parse the body once. silent=True yields None rather than raising for a
      # missing or non-JSON body; a non-object body contributes no fields.
      payload = request.get_json(silent=True)
      if not isinstance(payload, dict):
          payload = {}

      question = payload.get('question')
      sql = payload.get('sql')
      ddl = payload.get('ddl')
      documentation = payload.get('documentation')
      try:
          training_id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
          return jsonify({"id": training_id})
      except Exception as e:
          print("TRAINING ERROR", e)
          return jsonify({"type": "error", "error": str(e)})
  @app.route('/api/v0/generate_followup_questions', methods=['GET'])
  @requires_cache(['df', 'question', 'sql'])
  def generate_followup_questions(id: str, df, question, sql):
      """Ask the model for follow-up questions about a cached result and cache them."""
      followups = vn.generate_followup_questions(question=question, sql=sql, df=df)
      cache.set(id=id, field='followup_questions', value=followups)
      response_body = {
          "type": "question_list",
          "id": id,
          "questions": followups,
          "header": "Here are some followup questions you can ask:",
      }
      return jsonify(response_body)
  @app.route('/api/v0/load_question', methods=['GET'])
  @requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
  def load_question(id: str, question, sql, df, fig_json, followup_questions):
      """Return everything cached for *id* so the client can restore a past question.

      The DataFrame is reduced to its first 10 rows as JSON records. Any
      exception during serialization is reported as a JSON error payload.
      """
      try:
          df_preview = df.head(10).to_json(orient='records')
          return jsonify({
              "type": "question_cache",
              "id": id,
              "question": question,
              "sql": sql,
              "df": df_preview,
              "fig": fig_json,
              "followup_questions": followup_questions,
          })
      except Exception as e:
          return jsonify({"type": "error", "error": str(e)})
  @app.route('/api/v0/get_question_history', methods=['GET'])
  def get_question_history():
      """Return the cached ``question`` field of every entry via ``cache.get_all``."""
      history = cache.get_all(field_list=['question'])
      return jsonify({"type": "question_history", "questions": history})
  @app.route('/')
  def root():
      """Serve index.html from the app's static folder at the site root."""
      entry_page = 'index.html'
      return app.send_static_file(entry_page)
  if __name__ == '__main__':
      # Debug mode turns on the interactive Werkzeug debugger, which can execute
      # arbitrary code and must never be reachable by untrusted clients.
      # A hard-coded debug=True also overrides FLASK_DEBUG. Keep debug on by
      # default, as before, but let FLASK_DEBUG=0/false/no (or empty) turn it
      # off, using the same truthiness rule Flask applies to that variable.
      debug_flag = os.environ.get('FLASK_DEBUG')
      if debug_flag is None:
          debug = True
      else:
          debug = bool(debug_flag) and debug_flag.lower() not in ('0', 'false', 'no')
      app.run(debug=debug)