app.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. import json
  2. from dotenv import load_dotenv
  3. from vanna import add_ddl
  4. load_dotenv()
  5. from functools import wraps
  6. from flask import Flask, jsonify, Response, request, redirect, url_for
  7. import flask
  8. import os
  9. from cache import MemoryCache
  10. from flask_cors import CORS
  11. from service.result import success, failed, MyEncoder
  12. app = Flask(__name__, static_url_path='')
  13. CORS(app)
  14. # SETUP
  15. cache = MemoryCache()
  16. # from vanna.local import LocalContext_OpenAI
  17. # vn = LocalContext_OpenAI()
  18. from vanna.openai.openai_chat import OpenAI_Chat
  19. from vanna.chromadb import ChromaDB_VectorStore
  20. from vanna.flask import VannaFlaskApp
  21. from openai import OpenAI
  22. api_key = os.getenv('api_key')
  23. base_url = os.getenv('base_url')
  24. model_name = os.getenv('model')
  25. host = os.getenv('host')
  26. dbname = os.getenv('dbname')
  27. user = os.getenv('user')
  28. password = os.getenv('password')
  29. class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
  30. def __init__(self, client=None,config=None):
  31. ChromaDB_VectorStore.__init__(self, config=config)
  32. OpenAI_Chat.__init__(self, client=client, config=config)
  33. client = OpenAI(
  34. api_key=api_key,
  35. base_url=base_url,
  36. )
  37. vn =MyVanna(client=client,config={'model':model_name})
  38. # from vanna.remote import VannaDefault
  39. # vn = VannaDefault(model=os.environ['VANNA_MODEL'], api_key=os.environ['VANNA_API_KEY'])
  40. # vn.connect_to_snowflake(
  41. # account=os.environ['SNOWFLAKE_ACCOUNT'],
  42. # username=os.environ['SNOWFLAKE_USERNAME'],
  43. # password=os.environ['SNOWFLAKE_PASSWORD'],
  44. # database=os.environ['SNOWFLAKE_DATABASE'],
  45. # warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
  46. # )
  47. # vn.connect_to_mysql(host='192.168.3.86', dbname='digitization_data', user='root', password='123456', port=3306)
  48. # vn.connect_to_mysql(host='192.168.3.91', dbname='digitization_data', user='root', password='2099citu##$$**.com', port=3306)
  49. vn.connect_to_mysql(host=host, dbname=dbname, user=user, password=password,port=3306)
  50. # NO NEED TO CHANGE ANYTHING BELOW THIS LINE
  51. def requires_cache(fields):
  52. def decorator(f):
  53. @wraps(f)
  54. def decorated(*args, **kwargs):
  55. id = request.args.get('id')
  56. if id is None:
  57. # return jsonify({"type": "error", "error": "No id provided"})
  58. res = failed({}, {"error":"No id provided"})
  59. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  60. for field in fields:
  61. if cache.get(id=id, field=field) is None:
  62. # return jsonify({"type": "error", "error": f"No {field} found"})
  63. res = failed({}, {"error": f"No {field} found"})
  64. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  65. field_values = {field: cache.get(id=id, field=field) for field in fields}
  66. # Add the id to the field_values
  67. field_values['id'] = id
  68. return f(*args, **field_values, **kwargs)
  69. return decorated
  70. return decorator
  71. @app.route('/api/v0/generate_questions', methods=['GET'])
  72. def generate_questions():
  73. # return jsonify({
  74. # "type": "question_list",
  75. # "questions": vn.generate_questions(),
  76. # "header": "Here are some questions you can ask:"
  77. # })
  78. try:
  79. response_data = {
  80. "type": "question_list",
  81. "questions": vn.generate_questions(),
  82. "header": "Here are some questions you can ask:"
  83. }
  84. res = success(response_data, "success")
  85. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  86. except Exception as e:
  87. res = failed({}, {"error": f"{e}"})
  88. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  89. @app.route('/api/v0/generate_sql', methods=['GET'])
  90. def generate_sql():
  91. question = flask.request.args.get('question')
  92. if question is None:
  93. # return jsonify({"type": "error", "error": "No question provided"})
  94. res = failed({}, {"error": "No question provided"})
  95. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  96. id = cache.generate_id(question=question)
  97. sql = vn.generate_sql(question=question)
  98. cache.set(id=id, field='question', value=question)
  99. cache.set(id=id, field='sql', value=sql)
  100. # return jsonify(
  101. # {
  102. # "type": "sql",
  103. # "id": id,
  104. # "text": sql,
  105. # })
  106. response_data = {
  107. "type": "sql",
  108. "id": id,
  109. "text": sql,
  110. }
  111. res = success(response_data, "success")
  112. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  113. @app.route('/api/v0/run_sql', methods=['GET'])
  114. @requires_cache(['sql'])
  115. def run_sql(id: str, sql: str):
  116. try:
  117. df = vn.run_sql(sql=sql)
  118. cache.set(id=id, field='df', value=df)
  119. # return jsonify(
  120. # {
  121. # "type": "df",
  122. # "id": id,
  123. # "df": df.head(10).to_json(orient='records'),
  124. # })
  125. response_data = {
  126. "type": "df",
  127. "id": id,
  128. "df": df.head(10).to_json(orient='records'),
  129. }
  130. res = success(response_data, "success")
  131. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  132. except Exception as e:
  133. # return jsonify({"type": "error", "error": str(e)})
  134. res = failed({}, {"error": f"{e}"})
  135. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  136. @app.route('/api/v0/download_csv', methods=['GET'])
  137. @requires_cache(['df'])
  138. def download_csv(id: str, df):
  139. csv = df.to_csv()
  140. return Response(
  141. csv,
  142. mimetype="text/csv",
  143. headers={"Content-disposition":
  144. f"attachment; filename={id}.csv"})
  145. @app.route('/api/v0/generate_plotly_figure', methods=['GET'])
  146. @requires_cache(['df', 'question', 'sql'])
  147. def generate_plotly_figure(id: str, df, question, sql):
  148. try:
  149. code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
  150. fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
  151. fig_json = fig.to_json()
  152. cache.set(id=id, field='fig_json', value=fig_json)
  153. return jsonify(
  154. {
  155. "type": "plotly_figure",
  156. "id": id,
  157. "fig": fig_json,
  158. })
  159. except Exception as e:
  160. # Print the stack trace
  161. import traceback
  162. traceback.print_exc()
  163. return jsonify({"type": "error", "error": str(e)})
  164. @app.route('/api/v0/get_training_data', methods=['GET'])
  165. def get_training_data():
  166. df = vn.get_training_data()
  167. return jsonify(
  168. {
  169. "type": "df",
  170. "id": "training_data",
  171. "df": df.head(25).to_json(orient='records'),
  172. })
  173. @app.route('/api/v0/remove_training_data', methods=['POST'])
  174. def remove_training_data():
  175. # Get id from the JSON body
  176. id = flask.request.json.get('id')
  177. if id is None:
  178. return jsonify({"type": "error", "error": "No id provided"})
  179. if vn.remove_training_data(id=id):
  180. return jsonify({"success": True})
  181. else:
  182. return jsonify({"type": "error", "error": "Couldn't remove training data"})
  183. @app.route('/api/v0/train', methods=['POST'])
  184. def add_training_data():
  185. question = flask.request.json.get('question')
  186. sql = flask.request.json.get('sql')
  187. ddl = flask.request.json.get('ddl')
  188. documentation = flask.request.json.get('documentation')
  189. try:
  190. id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
  191. if ddl:
  192. vn.train(ddl=ddl)
  193. # return jsonify({"id": id})
  194. res = success({"id": id}, "success")
  195. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  196. except Exception as e:
  197. print("TRAINING ERROR", e)
  198. # return jsonify({"type": "error", "error": str(e)})
  199. res = failed({}, {"error": f"{e}"})
  200. return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
  201. @app.route('/api/v0/generate_followup_questions', methods=['GET'])
  202. @requires_cache(['df', 'question', 'sql'])
  203. def generate_followup_questions(id: str, df, question, sql):
  204. followup_questions = vn.generate_followup_questions(question=question, sql=sql, df=df)
  205. cache.set(id=id, field='followup_questions', value=followup_questions)
  206. return jsonify(
  207. {
  208. "type": "question_list",
  209. "id": id,
  210. "questions": followup_questions,
  211. "header": "Here are some followup questions you can ask:"
  212. })
  213. @app.route('/api/v0/load_question', methods=['GET'])
  214. @requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
  215. def load_question(id: str, question, sql, df, fig_json, followup_questions):
  216. try:
  217. return jsonify(
  218. {
  219. "type": "question_cache",
  220. "id": id,
  221. "question": question,
  222. "sql": sql,
  223. "df": df.head(10).to_json(orient='records'),
  224. "fig": fig_json,
  225. "followup_questions": followup_questions,
  226. })
  227. except Exception as e:
  228. return jsonify({"type": "error", "error": str(e)})
  229. @app.route('/api/v0/get_question_history', methods=['GET'])
  230. def get_question_history():
  231. return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question']) })
  232. @app.route('/')
  233. def root():
  234. return app.send_static_file('index.html')
  235. if __name__ == '__main__':
  236. app.run(debug=True,host='0.0.0.0',port = 3005)