Zain Hoda 2 vuotta sitten
vanhempi
commit
99ce723630
9 muutettua tiedostoa jossa 341 lisäystä ja 0 poistoa
  1. 4 0
      .gitignore
  2. 24 0
      README.md
  3. 213 0
      app.py
  4. 62 0
      cache.py
  5. 4 0
      requirements.txt
  6. 0 0
      static/assets/index-b1a5a2f1.css
  7. 0 0
      static/assets/index-d29524f4.js
  8. 17 0
      static/index.html
  9. 17 0
      static/vanna.svg

+ 4 - 0
.gitignore

@@ -0,0 +1,4 @@
+.env
+venv
+.DS_Store
+__pycache__

+ 24 - 0
README.md

@@ -1,2 +1,26 @@
 # vanna-flask
 Web server for chatting with your database
+
+# Setup
+
+## Set your environment variables
+```
+VANNA_MODEL=
+VANNA_API_KEY=
+SNOWFLAKE_ACCOUNT=
+SNOWFLAKE_USERNAME=
+SNOWFLAKE_PASSWORD=
+SNOWFLAKE_DATABASE=
+SNOWFLAKE_WAREHOUSE=
+```
+
+## Install dependencies
+```
+pip install -r requirements.txt
+```
+
+## Run the server
+```
+python app.py
+```
+

+ 213 - 0
app.py

@@ -0,0 +1,213 @@
+from dotenv import load_dotenv
+load_dotenv()
+
+from functools import wraps
+from flask import Flask, jsonify, Response, request, redirect, url_for
+import flask
+import os
+from cache import MemoryCache
+
+app = Flask(__name__, static_url_path='')
+
+# SETUP
+cache = MemoryCache()
+
+# from vanna.local import LocalContext_OpenAI
+# vn = LocalContext_OpenAI()
+
+from vanna.remote import VannaDefault
+vn = VannaDefault(model=os.environ['VANNA_MODEL'], api_key=os.environ['VANNA_API_KEY'])
+
+vn.connect_to_snowflake(
+    account=os.environ['SNOWFLAKE_ACCOUNT'],
+    username=os.environ['SNOWFLAKE_USERNAME'],
+    password=os.environ['SNOWFLAKE_PASSWORD'],
+    database=os.environ['SNOWFLAKE_DATABASE'],
+    warehouse=os.environ['SNOWFLAKE_WAREHOUSE'],
+)
+
+# NO NEED TO CHANGE ANYTHING BELOW THIS LINE
+def requires_cache(fields):
+    def decorator(f):
+        @wraps(f)
+        def decorated(*args, **kwargs):
+            id = request.args.get('id')
+
+            if id is None:
+                return jsonify({"type": "error", "error": "No id provided"})
+            
+            for field in fields:
+                if cache.get(id=id, field=field) is None:
+                    return jsonify({"type": "error", "error": f"No {field} found"})
+            
+            field_values = {field: cache.get(id=id, field=field) for field in fields}
+            
+            # Add the id to the field_values
+            field_values['id'] = id
+
+            return f(*args, **field_values, **kwargs)
+        return decorated
+    return decorator
+
+@app.route('/api/v0/generate_questions', methods=['GET'])
+def generate_questions():
+    return jsonify({
+        "type": "question_list", 
+        "questions": vn.generate_questions(),
+        "header": "Here are some questions you can ask:"
+        })
+
+@app.route('/api/v0/generate_sql', methods=['GET'])
+def generate_sql():
+    question = flask.request.args.get('question')
+
+    if question is None:
+        return jsonify({"type": "error", "error": "No question provided"})
+
+    id = cache.generate_id(question=question)
+    sql = vn.generate_sql(question=question)
+
+    cache.set(id=id, field='question', value=question)
+    cache.set(id=id, field='sql', value=sql)
+
+    return jsonify(
+        {
+            "type": "sql", 
+            "id": id,
+            "text": sql,
+        })
+
+@app.route('/api/v0/run_sql', methods=['GET'])
+@requires_cache(['sql'])
+def run_sql(id: str, sql: str):
+    try:
+        df = vn.run_sql(sql=sql)
+
+        cache.set(id=id, field='df', value=df)
+
+        return jsonify(
+            {
+                "type": "df", 
+                "id": id,
+                "df": df.head(10).to_json(orient='records'),
+            })
+
+    except Exception as e:
+        return jsonify({"type": "error", "error": str(e)})
+
+@app.route('/api/v0/download_csv', methods=['GET'])
+@requires_cache(['df'])
+def download_csv(id: str, df):
+    csv = df.to_csv()
+
+    return Response(
+        csv,
+        mimetype="text/csv",
+        headers={"Content-disposition":
+                 f"attachment; filename={id}.csv"})
+
+@app.route('/api/v0/generate_plotly_figure', methods=['GET'])
+@requires_cache(['df', 'question', 'sql'])
+def generate_plotly_figure(id: str, df, question, sql):
+    try:
+        code = vn.generate_plotly_code(question=question, sql=sql, df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
+        fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
+        fig_json = fig.to_json()
+
+        cache.set(id=id, field='fig_json', value=fig_json)
+
+        return jsonify(
+            {
+                "type": "plotly_figure", 
+                "id": id,
+                "fig": fig_json,
+            })
+    except Exception as e:
+        # Print the stack trace
+        import traceback
+        traceback.print_exc()
+
+        return jsonify({"type": "error", "error": str(e)})
+
+@app.route('/api/v0/get_training_data', methods=['GET'])
+def get_training_data():
+    df = vn.get_training_data()
+
+    return jsonify(
+    {
+        "type": "df", 
+        "id": "training_data",
+        "df": df.head(25).to_json(orient='records'),
+    })
+
+@app.route('/api/v0/remove_training_data', methods=['POST'])
+def remove_training_data():
+    # Get id from the JSON body
+    id = flask.request.json.get('id')
+
+    if id is None:
+        return jsonify({"type": "error", "error": "No id provided"})
+
+    if vn.remove_training_data(id=id):
+        return jsonify({"success": True})
+    else:
+        return jsonify({"type": "error", "error": "Couldn't remove training data"})
+
+@app.route('/api/v0/train', methods=['POST'])
+def add_training_data():
+    question = flask.request.json.get('question')
+    sql = flask.request.json.get('sql')
+    ddl = flask.request.json.get('ddl')
+    documentation = flask.request.json.get('documentation')
+
+    try:
+        id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
+
+        return jsonify({"id": id})
+    except Exception as e:
+        print("TRAINING ERROR", e)
+        return jsonify({"type": "error", "error": str(e)})
+
+@app.route('/api/v0/generate_followup_questions', methods=['GET'])
+@requires_cache(['df', 'question'])
+def generate_followup_questions(id: str, df, question):
+    followup_questions = vn.generate_followup_questions(question=question, df=df)
+
+    cache.set(id=id, field='followup_questions', value=followup_questions)
+
+    return jsonify(
+        {
+            "type": "question_list", 
+            "id": id,
+            "questions": followup_questions,
+            "header": "Here are some followup questions you can ask:"
+        })
+
+@app.route('/api/v0/load_question', methods=['GET'])
+@requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
+def load_question(id: str, question, sql, df, fig_json, followup_questions):
+    try:
+        return jsonify(
+            {
+                "type": "question_cache", 
+                "id": id,
+                "question": question,
+                "sql": sql,
+                "df": df.head(10).to_json(orient='records'),
+                "fig": fig_json,
+                "followup_questions": followup_questions,
+            })
+
+    except Exception as e:
+        return jsonify({"type": "error", "error": str(e)})
+
+@app.route('/api/v0/get_question_history', methods=['GET'])
+def get_question_history():
+    return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question']) })
+
+@app.route('/')
+def root():
+    return app.send_static_file('index.html')
+
+if __name__ == '__main__':
+    app.run(debug=True)

+ 62 - 0
cache.py

@@ -0,0 +1,62 @@
+from abc import ABC, abstractmethod
+import uuid
+
+class Cache(ABC):
+    @abstractmethod
+    def generate_id(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def get(self, id, field):
+        pass
+
+    @abstractmethod
+    def get_all(self, field_list) -> list:
+        pass
+
+    @abstractmethod
+    def set(self, id, field, value):
+        pass
+
+    @abstractmethod
+    def delete(self, id):
+        pass
+
+
+class MemoryCache(Cache):
+    def __init__(self):
+        self.cache = {}
+
+    def generate_id(self, *args, **kwargs):
+        return str(uuid.uuid4())
+
+    def set(self, id, field, value):
+        if id not in self.cache:
+            self.cache[id] = {}
+
+        self.cache[id][field] = value
+
+    def get(self, id, field):
+        if id not in self.cache:
+            return None
+
+        if field not in self.cache[id]:
+            return None
+
+        return self.cache[id][field]
+
+    def get_all(self, field_list) -> list:
+        return [
+            {
+                "id": id,
+                **{
+                    field: self.get(id=id, field=field)
+                    for field in field_list
+                }
+            }
+            for id in self.cache
+        ]
+
+    def delete(self, id):
+        if id in self.cache:
+            del self.cache[id]

+ 4 - 0
requirements.txt

@@ -0,0 +1,4 @@
+flask
+vanna[snowflake]
+db-dtypes
+python-dotenv

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 0 - 0
static/assets/index-b1a5a2f1.css


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 0 - 0
static/assets/index-d29524f4.js


+ 17 - 0
static/index.html

@@ -0,0 +1,17 @@
+<!doctype html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <link rel="icon" type="image/svg+xml" href="/vanna.svg" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@350&display=swap" rel="stylesheet">
+    <script src="https://cdn.plot.ly/plotly-latest.min.js" type="text/javascript"></script>
+    <title>Vanna.AI</title>
+    <script type="module" crossorigin src="/assets/index-d29524f4.js"></script>
+    <link rel="stylesheet" href="/assets/index-b1a5a2f1.css">
+  </head>
+  <body class="bg-white dark:bg-slate-900">
+    <div id="app"></div>
+    
+  </body>
+</html>

+ 17 - 0
static/vanna.svg

@@ -0,0 +1,17 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Created with Vectornator (http://vectornator.io/) -->
+<svg height="100%" stroke-miterlimit="10" style="fill-rule:nonzero;clip-rule:evenodd;stroke-linecap:round;stroke-linejoin:round;" version="1.1" viewBox="0 0 1024 1024" width="100%" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" xmlns:vectornator="http://vectornator.io" xmlns:xlink="http://www.w3.org/1999/xlink">
+<defs>
+<linearGradient gradientTransform="matrix(1.09331 0 0 1.09331 -47.1838 -88.8946)" gradientUnits="userSpaceOnUse" id="LinearGradient" x1="237.82" x2="785.097" y1="549.609" y2="549.609">
+<stop offset="0" stop-color="#009efd"/>
+<stop offset="1" stop-color="#2af598"/>
+</linearGradient>
+</defs>
+<g id="Layer-1" vectornator:layerName="Layer 1">
+<g opacity="1" vectornator:layerName="Group 1">
+<path d="M117.718 228.798C117.718 119.455 206.358 30.8151 315.701 30.8151L708.299 30.8151C817.642 30.8151 906.282 119.455 906.282 228.798L906.282 795.202C906.282 904.545 817.642 993.185 708.299 993.185L315.701 993.185C206.358 993.185 117.718 904.545 117.718 795.202L117.718 228.798Z" fill="#0f172a" fill-rule="nonzero" opacity="1" stroke="#374151" stroke-linecap="butt" stroke-linejoin="round" stroke-width="20" vectornator:layerName="Rectangle 1"/>
+<path d="M212.828 215.239C213.095 281.169 213.629 413.028 213.629 413.028C213.629 413.028 511.51 808.257 513.993 809.681C612.915 677.809 810.759 414.065 810.759 414.065C810.759 414.065 811.034 280.901 811.172 214.319C662.105 362.973 662.105 362.973 513.038 511.627C362.933 363.433 362.933 363.433 212.828 215.239Z" fill="url(#LinearGradient)" fill-rule="nonzero" opacity="1" stroke="none"/>
+</g>
+</g>
+</svg>

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä