Parcourir la source

第一次提交

yulongyan_citu il y a 1 an
Parent
commit
e44151a564
4 fichiers modifiés avec 132 ajouts et 30 suppressions
  1. 98 30
      app.py
  2. 0 0
      service/__init__.py
  3. 34 0
      service/result.py
  4. 0 0
      static/assets/index-b1a5a2f1.css

+ 98 - 30
app.py

@@ -1,4 +1,8 @@
+import json
+
 from dotenv import load_dotenv
+from vanna import add_ddl
+
 load_dotenv()
 
 from functools import wraps
@@ -6,8 +10,11 @@ from flask import Flask, jsonify, Response, request, redirect, url_for
 import flask
 import os
 from cache import MemoryCache
+from flask_cors import CORS
+from service.result import success, failed, MyEncoder
 
 app = Flask(__name__, static_url_path='')
+CORS(app)
 
 # SETUP
 cache = MemoryCache()
@@ -15,16 +22,39 @@ cache = MemoryCache()
 # from vanna.local import LocalContext_OpenAI
 # vn = LocalContext_OpenAI()
 
-from vanna.remote import VannaDefault
-vn = VannaDefault(model=os.environ['VANNA_MODEL'], api_key=os.environ['VANNA_API_KEY'])
+from vanna.openai.openai_chat import OpenAI_Chat
+from vanna.chromadb import ChromaDB_VectorStore
+from vanna.flask import VannaFlaskApp
+from openai import OpenAI
+
+api_key = os.getenv('api_key')
+base_url = os.getenv('base_url')
+model_name = os.getenv('model')
+
+class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
+    def __init__(self, client=None,config=None):
+        ChromaDB_VectorStore.__init__(self, config=config)
+        OpenAI_Chat.__init__(self, client=client, config=config)
+
+client = OpenAI(
+    api_key=api_key,
+    base_url=base_url,
+)   
+vn =MyVanna(client=client,config={'model':model_name})
 
-vn.connect_to_snowflake(
-    account=os.environ['SNOWFLAKE_ACCOUNT'],
-    username=os.environ['SNOWFLAKE_USERNAME'],
-    password=os.environ['SNOWFLAKE_PASSWORD'],
-    database=os.environ['SNOWFLAKE_DATABASE'],
-    warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
-)
+# from vanna.remote import VannaDefault
+# vn = VannaDefault(model=os.environ['VANNA_MODEL'], api_key=os.environ['VANNA_API_KEY'])
+
+# vn.connect_to_snowflake(
+#     account=os.environ['SNOWFLAKE_ACCOUNT'],
+#     username=os.environ['SNOWFLAKE_USERNAME'],
+#     password=os.environ['SNOWFLAKE_PASSWORD'],
+#     database=os.environ['SNOWFLAKE_DATABASE'],
+#     warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
+# )
+
+# vn.connect_to_mysql(host='192.168.3.86', dbname='digitization_data', user='root', password='123456', port=3306)
+vn.connect_to_mysql(host='192.168.3.91', dbname='digitization_data', user='root', password='2099citu##$$**.com', port=3306)
 
 # NO NEED TO CHANGE ANYTHING BELOW THIS LINE
 def requires_cache(fields):
@@ -34,11 +64,14 @@ def requires_cache(fields):
             id = request.args.get('id')
 
             if id is None:
-                return jsonify({"type": "error", "error": "No id provided"})
-            
+                # return jsonify({"type": "error", "error": "No id provided"})
+                res = failed({}, {"error":"No id provided"})
+                return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
             for field in fields:
                 if cache.get(id=id, field=field) is None:
-                    return jsonify({"type": "error", "error": f"No {field} found"})
+                    # return jsonify({"type": "error", "error": f"No {field} found"})
+                    res = failed({}, {"error": f"No {field} found"})
+                    return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
             
             field_values = {field: cache.get(id=id, field=field) for field in fields}
             
@@ -51,18 +84,31 @@ def requires_cache(fields):
 
 @app.route('/api/v0/generate_questions', methods=['GET'])
 def generate_questions():
-    return jsonify({
-        "type": "question_list", 
-        "questions": vn.generate_questions(),
-        "header": "Here are some questions you can ask:"
-        })
+    # return jsonify({
+    #     "type": "question_list",
+    #     "questions": vn.generate_questions(),
+    #     "header": "Here are some questions you can ask:"
+    #     })
+    try:
+        response_data = {
+            "type": "question_list",
+            "questions": vn.generate_questions(),
+            "header": "Here are some questions you can ask:"
+            }
+        res = success(response_data, "success")
+        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
+    except Exception as e:
+        res = failed({}, {"error": f"{e}"})
+        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
 
 @app.route('/api/v0/generate_sql', methods=['GET'])
 def generate_sql():
     question = flask.request.args.get('question')
 
     if question is None:
-        return jsonify({"type": "error", "error": "No question provided"})
+        # return jsonify({"type": "error", "error": "No question provided"})
+        res = failed({}, {"error": "No question provided"})
+        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
 
     id = cache.generate_id(question=question)
     sql = vn.generate_sql(question=question)
@@ -70,12 +116,19 @@ def generate_sql():
     cache.set(id=id, field='question', value=question)
     cache.set(id=id, field='sql', value=sql)
 
-    return jsonify(
-        {
-            "type": "sql", 
+    # return jsonify(
+    #     {
+    #         "type": "sql",
+    #         "id": id,
+    #         "text": sql,
+    #     })
+    response_data = {
+            "type": "sql",
             "id": id,
             "text": sql,
-        })
+        }
+    res = success(response_data, "success")
+    return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
 
 @app.route('/api/v0/run_sql', methods=['GET'])
 @requires_cache(['sql'])
@@ -85,15 +138,24 @@ def run_sql(id: str, sql: str):
 
         cache.set(id=id, field='df', value=df)
 
-        return jsonify(
-            {
-                "type": "df", 
+        # return jsonify(
+        #     {
+        #         "type": "df",
+        #         "id": id,
+        #         "df": df.head(10).to_json(orient='records'),
+        #     })
+        response_data = {
+                "type": "df",
                 "id": id,
                 "df": df.head(10).to_json(orient='records'),
-            })
+            }
+        res = success(response_data, "success")
+        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
 
     except Exception as e:
-        return jsonify({"type": "error", "error": str(e)})
+        # return jsonify({"type": "error", "error": str(e)})
+        res = failed({}, {"error": f"{e}"})
+        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
 
 @app.route('/api/v0/download_csv', methods=['GET'])
 @requires_cache(['df'])
@@ -162,11 +224,17 @@ def add_training_data():
 
     try:
         id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
+        if ddl:
+            vn.train(ddl=ddl)
 
-        return jsonify({"id": id})
+        # return jsonify({"id": id})
+        res = success({"id": id}, "success")
+        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
     except Exception as e:
         print("TRAINING ERROR", e)
-        return jsonify({"type": "error", "error": str(e)})
+        # return jsonify({"type": "error", "error": str(e)})
+        res = failed({}, {"error": f"{e}"})
+        return json.dumps(res, ensure_ascii=False, cls=MyEncoder)
 
 @app.route('/api/v0/generate_followup_questions', methods=['GET'])
 @requires_cache(['df', 'question', 'sql'])
@@ -210,4 +278,4 @@ def root():
     return app.send_static_file('index.html')
 
 if __name__ == '__main__':
-    app.run(debug=True)
+    app.run(debug=True,host='0.0.0.0',port = 3005)

+ 0 - 0
service/__init__.py


+ 34 - 0
service/result.py

@@ -0,0 +1,34 @@
+import logging
+import numpy as np
+import json
+import datetime
+import decimal
+
+
+class MyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        elif isinstance(obj,decimal.Decimal):
+            return float(obj)
+        elif isinstance(obj, datetime.datetime):
+            return obj.strftime("%Y-%m-%d %H:%M:%S")
+        else:
+            return super(MyEncoder, self).default(obj)
+
+
+
+SUCCESS_CODE=20000
+FAILED_CODE=50000
+logger = logging.getLogger('app')
+
+def success(data, msg):
+    return {"code": SUCCESS_CODE, "data": data, "level": 0, "msg": msg}
+
+
+def failed(data, msg):
+    return {"code": FAILED_CODE, "data": data, "level": 1, "msg": msg}

Fichier diff supprimé car celui-ci est trop grand
+ 0 - 0
static/assets/index-b1a5a2f1.css


Certains fichiers n'ont pas été affichés car il y a eu trop de fichiers modifiés dans ce diff