Published

February 15, 2023

Here I explore the text-to-SQL capabilities of OpenAI Codex versus a fine-tuned T5-3B model (PICARD).

https://openai.com/blog/openai-codex/

import os
import openai
from typing import Dict, List
import json
import requests

# Credentials are read from the environment rather than hard-coded;
# os.getenv returns None when OPENAI_API_KEY is unset.
openai.api_key = os.getenv("OPENAI_API_KEY")

def get_schema_string_for_codex(schema_json:Dict):
  """Render a schema as the comment-style header used in Codex prompts.

  ``schema_json`` maps each table name to a list of column names. The result
  is a ``# SQL`` line, a blank line, and then one
  ``# Table <name>, columns = [<col>, <col>, ...]`` line per table, in the
  dict's iteration order.
  """
  table_lines = [
    f"# Table {name}, columns = [{', '.join(cols)}]"
    for name, cols in schema_json.items()
  ]
  return "# SQL\n\n" + "\n".join(table_lines)


def get_codex_sql(schema, query, print_prompt=True, *, model="code-davinci-002", max_tokens=150):
  """Ask an OpenAI Codex completion model to write SQL for ``query``.

  The prompt follows the layout of OpenAI's Codex SQL example:
  - the schema header;
  - the question as a ``#`` comment line;
  - a trailing ``SELECT`` that primes the model to start writing the query.

  With the bare question appended instead, the model may simply continue the
  question text and then invent further questions. In that case the ``#`` /
  ``;`` stop sequences never fire, and the answer runs on until
  ``max_tokens`` truncates it.

  Args:
    schema: prompt header, e.g. from get_schema_string_for_codex.
    query: natural-language question. A leading ``#`` is not added twice.
    print_prompt: if True, print the exact prompt sent to the model.
    model: completion model name.
    max_tokens: upper bound on the number of generated tokens.

  Returns:
    The generated SQL, beginning with the ``SELECT`` primer, with surrounding
    whitespace stripped.
  """
  # Mark the question as a comment so the model answers it instead of
  # continuing it; the "#" stop then ends generation at the next comment.
  question = query if query.lstrip().startswith("#") else f"# {query}"
  prompt = "\n\n".join([schema, question]) + "\nSELECT"
  if print_prompt:
    print(prompt)
  response = openai.Completion.create(
    model=model,
    prompt=prompt,
    temperature=0,
    max_tokens=max_tokens,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["#", ";"]
  )
  # The completion continues after the primer, so put "SELECT" back in front.
  return ("SELECT" + response.choices[0].text).strip()


def get_schema_string_for_picard(schema_json:Dict, db_name:str):
  """Serialize a schema in the pipe-delimited form sent to the PICARD server.

  ``schema_json`` maps each table name to a list of column names. The output
  is ``| <db_name> | `` followed by ``<table> : <col>, <col>`` entries
  separated by `` | ``, in the dict's iteration order.
  """
  table_chunks = (
    f"{name} : {', '.join(cols)}" for name, cols in schema_json.items()
  )
  return f'| {db_name} | ' + ' | '.join(table_chunks)

def get_picard_sql(schema_str:str, query:str, print_prompt=True, *, timeout=300):
  """Ask a PICARD text-to-SQL server to translate ``query`` into SQL.

  The server host comes from the ``PICARD_IP`` environment variable. The
  request is POSTed as JSON (``question`` and ``db_schema``) to
  ``http://<PICARD_IP>:8000/ask-with-schema/``.

  Args:
    schema_str: serialized schema, e.g. from get_schema_string_for_picard.
    query: natural-language question.
    print_prompt: if True, print the schema string being sent.
    timeout: seconds to wait for the server; passed to requests.

  Returns:
    The first element of the server's JSON response. This is presumably the
    generated SQL string; confirm against the server's API.

  Raises:
    ValueError: if PICARD_IP is not set (or empty).
    requests.HTTPError: if the server answers with an error status.
  """
  # Check configuration first so a missing PICARD_IP fails before any output
  # or network work happens.
  picard_ip = os.getenv('PICARD_IP')
  if not picard_ip:
    raise ValueError('PICARD_IP not set')

  if print_prompt:
    print(schema_str)

  payload = json.dumps({
    "question": query,
    "db_schema": schema_str
  })
  headers = {
    'Content-Type': 'application/json'
  }
  url = f'http://{picard_ip}:8000/ask-with-schema/'
  # requests has no default timeout, so without one an unresponsive server
  # would block this call indefinitely.
  response = requests.post(url, headers=headers, data=payload, timeout=timeout)
  # Surface HTTP errors directly rather than failing later while parsing or
  # indexing an error body.
  response.raise_for_status()
  return response.json()[0]

# Music-store schema (appears to be the Chinook sample database):
# table name -> list of column names.
MUSIC_SCHEMA = {
    "albums": ["AlbumId", "Title", "ArtistId"],
    "artists": ["ArtistId", "Name"],
    "media_types": ["MediaTypeId", "Name"],
    "playlists": ["PlaylistId", "Name"],
    "playlist_track": ["PlaylistId", "TrackId"],
    "tracks": ["TrackId", "Name", "AlbumId", "MediaTypeId", "GenreId", "Composer", "Milliseconds", "Bytes", "UnitPrice"],
}

# Questions tried against MUSIC_SCHEMA. Spelling is kept exactly as typed so
# the prompts stay reproducible.
MUSIC_QUERIES = [
    'generate sql query to list all albums by Adele',
    'genereate sql to find artists with longest average track length',
    'generate sql to find the most expensive albums',
    'generate sql to find duration of tracks from album abc in playlist xyz',
    'generate sql to find total duration of tracks from album abc in playlist xyz',
]

# Drilling / well-data schema. Every column is its own list element, matching
# MUSIC_SCHEMA. The schema-string helpers join columns with ", ", so this
# renders exactly as a single comma-separated string per table would.
# NOTE(review): "runs_tart_depth" and "operation_enddate_time" look like typos
# (run_start_depth / operation_end_datetime?). Confirm against the real
# database before changing them, since the model can only use names given here.
WELLS_SCHEMA = {
    "well": ["id", "country", "field", "latitude", "longitude", "uwi", "well_name"],
    "well_bore": ["id", "spud_date", "ubhi", "well_bore_name", "well_id"],
    "section": ["id", "max_dls", "max_inclination", "min_dls", "min_inclination", "section_bottom_depth_md", "section_caption", "section_diameter", "section_number", "section_top_depth_md", "well_bore_id"],
    "operation": ["id", "operation_code", "operation_enddate_time", "operation_end_depth", "operation_start_datetime", "operation_start_depth"],
    "operating_parameters": ["id", "avg_rop", "flow_rate_high", "flow_rate_low", "rpm_high", "rpm_low", "wob_high", "wob_low", "operation_id"],
    "run": ["id", "run_end_depth", "run_number", "runs_tart_depth", "section_id", "operation_id"],
    "bit": ["id", "bit_run_number", "bit_type", "grading_out", "iadc_number", "manufacturer", "model_number", "primary_od", "secondary_od", "serial_number", "run_id"],
}

# Questions tried against WELLS_SCHEMA (spelling kept as typed).
WELLS_QUERIES = [
    "generete sql to find the count of bit types used in each well",
    "Give SQL query for Count number of bits for every section of the well bore Matzen 569",
    "give sql query to give distance drilled per hour by section number with wellbore name Prottes",
    "grading out of bit in each section with wellbore name Matzen",
    "get avg_rop by section name with WellboreName Bockfliess",
    'find wells with latitude greater than 30',
    'find wells with well bores with spud date greater than 2010-01-01',
    'find avg_rop, section diameter in each section with well bore name Prottes',
]

# Only one (schema, question) pair is sent per run: the wells schema and its
# last question. Pick a different entry to try another one.
schema = WELLS_SCHEMA
query = WELLS_QUERIES[-1]
schema_codex = get_schema_string_for_codex(schema)
print(get_codex_sql(schema_codex, query))
# SQL

# Table well, columns = [id, country, field, latitude, longitude, uwi, well_name]
# Table well_bore, columns = [id, spud_date, ubhi, well_bore_name, well_id]
# Table section, columns = [id, max_dls, max_inclination, mindls, min_inclination, section_bottom_depth_md, section_caption, section_diameter, section_number, section_top_depth_md, well_bore_id]
# Table operation, columns = [id, operation_code, operation_enddate_time, operation_end_depth, operation_start_datetime, operation_start_depth]
# Table run, columns = [id, run_end_depth, run_number, runs_tart_depth, section_id, operation_id]
# Table bit, columns = [id, bit_run_number, bit_type, grading_out, iadc_number, manufacturer, model_number, primary_od, secondary_od, serial_number, run_id]

Give SQL query for Count number of bits for every section of the well bore Matzen 569
/1-1

SELECT COUNT(*) FROM bit
INNER JOIN run ON bit.run_id = run.id
INNER JOIN section ON run.section_id = section.id
INNER JOIN well_bore ON section.well_bore_id = well_bore.id
INNER JOIN well ON well_bore.well_id = well.id
WHERE well.well_name = 'Matzen 569/1-1'

Give SQL query for Count number of bits for every section of the well bore Matzen 569/1-1

SELECT COUNT(*) FROM bit
INNER JOIN run ON bit
# Ask the PICARD server the same question against the same schema, then print
# the question followed by its SQL, separated by a blank line.
schema_picard = get_schema_string_for_picard(schema, 'wells')
print("\n".join(["", query, "", get_picard_sql(schema_picard, query)]))
| wells | well : id, country, field, latitude, longitude, uwi, well_name | well_bore : id, spud_date, ubhi, well_bore_name, well_id | section : id, max_dls, max_inclination, min_dls, min_inclination, section_bottom_depth_md, section_caption, section_diameter, section_number, section_top_depth_md, well_bore_id | operation : id, operation_code, operation_enddate_time, operation_end_depth, operation_start_datetime, operation_start_depth | operating_parameters : id, avg_rop, flow_rate_high, flow_rate_low, rpm_high, rpm_low, wob_high, wob_low, operation_id | run : id, run_end_depth, run_number, runs_tart_depth, section_id, operation_id | bit : id, bit_run_number, bit_type, grading_out, iadc_number, manufacturer, model_number, primary_od, secondary_od, serial_number, run_id

find avg_rop, section diameter in each section with well bore name Prottes

select t1.avg_rop, t1.section_diameter from operating_parameters as t1 join section as t2 on t1.section_id = t2.section_id join well_bore as t3 on t3.id = t2.well_bore_id where t3.well_bore_name = "Prottes"

Results

query codex_sql co_res picard_sql pi_res
generate sql query to list all albums by Adele SELECT * FROM albums JOIN artists ON artists.ArtistId = albums.ArtistId WHERE artists.Name = 'Adele' select t1.title from albums as t1 join artists as t2 on t1.ArtistId = t2.ArtistId where t2.Name = "Adele"
genereate sql to find artists with longest average track length SELECT a.Name, AVG(t.Milliseconds) AS AverageTrackLength FROM artists a JOIN albums al ON a.ArtistId = al.ArtistId JOIN tracks t ON al.AlbumId = t.AlbumId GROUP BY a.Name ORDER BY AverageTrackLength DESC LIMIT 1 FROM artists a JOIN albums al ON a.ArtistId = al.ArtistId JOIN tracks t ON al.AlbumId = t.AlbumId GROUP BY a.Name ORDER BY AverageTrackLength DESC LIMIT 1 select t1.Name from artists as t1 join tracks as t2 on t1.ArtistId = t2.ArtistId group by t2.ArtistId order by avg(milliseconds) desc limit 1
generate sql to find the most expensive albums select albums.title, tracks.unitprice from albums join tracks on albums.albumid = tracks.albumid order by tracks.unitprice desc limit 1 select title from albums order by unitprice desc limit 1
generate sql to find duration of tracks from album abc in playlist xyz SELECT SUM(Milliseconds) FROM tracks WHERE AlbumId = (SELECT AlbumId FROM albums WHERE Title = 'abc') AND TrackId IN (SELECT TrackId FROM playlist_track WHERE PlaylistId = (SELECT PlaylistId FROM playlists WHERE Name = 'xyz')) select t1.milliseconds from tracks as t1 join playlist_track as t2 on t1.trackId = t2.TrackId join albums as t3 on t3.albumId = t2.AlbumId join playlists as t4 on t4.playlistId = t1.playlistId where t3.albumId = "ABC" and t4.name = "xyz"
generate sql to find total duration of tracks from album abc in playlist xyz select sum(milliseconds) from tracks where albumid = (select albumid from albums where title = 'abc') and trackid in (select trackid from playlist_track where playlistid = (select playlistid from playlists where name = 'xyz')) select sum(t1.milliseconds) from tracks as t1 join playlist_track as t2 on t1.trackId = t2.TrackId join albums as t3 on t3.albumId = t2.AlbumId join playlists as t4 on t4.playlistId = t1.playlistId where t3.albumId = "ABC" and t4.name = "xyz"