import os
import openai
from typing import Dict, List
import json
import requests
# Read the OpenAI credential from the environment rather than hard-coding it.
# os.getenv returns None when OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")
def get_schema_string_for_codex(schema_json: Dict):
    """Render a schema as the comment-style header used in the Codex prompt.

    Args:
        schema_json: Mapping whose keys are table names and whose values are
            lists of column-name strings. The list entries are joined with
            ", ", so a one-element list holding an already comma-separated
            string renders the same way.

    Returns:
        A string that starts with "# SQL" and a blank line, followed by one
        "# Table <name>, columns = [<cols>]" line per table, in the mapping's
        iteration order.
    """
    table_lines = [
        f"# Table {table}, columns = [{', '.join(columns)}]"
        for table, columns in schema_json.items()
    ]
    return "# SQL\n\n" + "\n".join(table_lines)
def get_codex_sql(schema, query, print_prompt=True):
    """Ask the OpenAI completion endpoint to turn a question into SQL.

    Args:
        schema: Schema header text placed at the top of the prompt.
        query: Natural-language request appended after the schema, separated
            from it by a blank line.
        print_prompt: When true, print the assembled prompt before sending it.

    Returns:
        The text of the first completion choice in the response.
    """
    prompt = "\n\n".join([schema, query])
    if print_prompt:
        print(prompt)
    response = openai.Completion.create(
        model="code-davinci-002",
        prompt=prompt,
        temperature=0,
        max_tokens=150,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        # Cut the completion at the next comment marker or statement
        # terminator so only a single SQL statement is returned.
        stop=["#", ";"],
    )
    return response.choices[0].text
def get_schema_string_for_picard(schema_json: Dict, db_name: str):
    """Render a schema as the single-line, pipe-separated string sent to PICARD.

    Args:
        schema_json: Mapping whose keys are table names and whose values are
            lists of column-name strings (joined with ", ").
        db_name: Database name placed in the leading "| <db_name> | " segment.

    Returns:
        "| <db_name> | " followed by "<table> : <cols>" segments joined with
        " | ", in the mapping's iteration order.
    """
    table_parts = [
        f"{table} : {', '.join(columns)}"
        for table, columns in schema_json.items()
    ]
    return f"| {db_name} | " + " | ".join(table_parts)
def get_picard_sql(schema_str, query, print_prompt=True):
    """POST a question and schema to the PICARD service and return its answer.

    The service host is read from the PICARD_IP environment variable; the
    request goes to port 8000, path /ask-with-schema/.

    Args:
        schema_str: Schema string sent as the "db_schema" field.
        query: Natural-language question sent as the "question" field.
        print_prompt: When true, print schema_str before sending the request.

    Returns:
        The first element of the JSON-decoded response body.

    Raises:
        ValueError: If PICARD_IP is unset or empty.
    """
    payload = json.dumps({
        "question": query,
        "db_schema": schema_str,
    })
    if print_prompt:
        print(schema_str)
    headers = {
        'Content-Type': 'application/json',
    }
    picard_ip = os.getenv('PICARD_IP')
    if not picard_ip:
        raise ValueError('PICARD_IP not set')
    url = f'http://{picard_ip}:8000/ask-with-schema/'
    response = requests.request("POST", url, headers=headers, data=payload)
    return response.json()[0]
Here I am exploring the Text-to-SQL capabilities of OpenAI Codex vs. a fine-tuned T5-3B model (PICARD).
https://openai.com/blog/openai-codex/
# Example music-store schema: table name -> list of column names.
schema = {
    "albums": ["AlbumId", "Title", "ArtistId"],
    "artists": ["ArtistId", "Name"],
    "media_types": ["MediaTypeId", "Name"],
    "playlists": ["PlaylistId", "Name"],
    "playlist_track": ["PlaylistId", "TrackId"],
    "tracks": ["TrackId", "Name", "AlbumId", "MediaTypeId", "GenreId", "Composer", "Milliseconds", "Bytes", "UnitPrice"]
}

# Prompts tried against this schema. Each assignment rebinds `query`, so only
# the last one is in effect after this cell runs; the earlier ones are kept
# verbatim as a record of what was sent.
query = 'generate sql query to list all albums by Adele'
query = 'genereate sql to find artists with longest average track length'
query = 'generate sql to find the most expensive albums'
query = 'generate sql to find duration of tracks from album abc in playlist xyz'
query = 'generate sql to find total duration of tracks from album abc in playlist xyz'
# Drilling-data schema: table name -> list of column names. Most tables hold a
# single pre-joined, comma-separated string; "operating_parameters" lists its
# columns individually. Both forms render identically once joined with ", ".
schema = {
    "well" : ["id, country, field, latitude, longitude, uwi, well_name"],
    "well_bore": ["id, spud_date, ubhi, well_bore_name, well_id"],
    "section": ["id, max_dls, max_inclination, min_dls, min_inclination, section_bottom_depth_md, section_caption, section_diameter, section_number, section_top_depth_md, well_bore_id"],
    "operation": ["id, operation_code, operation_enddate_time, operation_end_depth, operation_start_datetime, operation_start_depth"],
    "operating_parameters": ['id', 'avg_rop', 'flow_rate_high', 'flow_rate_low', 'rpm_high', 'rpm_low', 'wob_high', 'wob_low', 'operation_id'],
    "run": ["id, run_end_depth, run_number, runs_tart_depth, section_id, operation_id"],
    "bit": ["id, bit_run_number, bit_type, grading_out, iadc_number, manufacturer, model_number, primary_od, secondary_od, serial_number, run_id"]
}

# Prompts tried against this schema. Each assignment rebinds `query`, so only
# the last one is in effect after this cell runs; the earlier ones are kept
# verbatim as a record of what was sent.
query = "generete sql to find the count of bit types used in each well"
query = "Give SQL query for Count number of bits for every section of the well bore Matzen 569"
query = "give sql query to give distance drilled per hour by section number with wellbore name Prottes"
query = "grading out of bit in each section with wellbore name Matzen"
query = "get avg_rop by section name with WellboreName Bockfliess"
query = 'find wells with latitude greater than 30'
query = 'find wells with well bores with spud date greater than 2010-01-01'
query = 'find avg_rop, section diameter in each section with well bore name Prottes'
# Render the current schema as a Codex prompt header, then print the SQL that
# Codex generates for the current query.
schema_codex = get_schema_string_for_codex(schema)
print(get_codex_sql(schema_codex, query))
# SQL
# Table well, columns = [id, country, field, latitude, longitude, uwi, well_name]
# Table well_bore, columns = [id, spud_date, ubhi, well_bore_name, well_id]
# Table section, columns = [id, max_dls, max_inclination, mindls, min_inclination, section_bottom_depth_md, section_caption, section_diameter, section_number, section_top_depth_md, well_bore_id]
# Table operation, columns = [id, operation_code, operation_enddate_time, operation_end_depth, operation_start_datetime, operation_start_depth]
# Table run, columns = [id, run_end_depth, run_number, runs_tart_depth, section_id, operation_id]
# Table bit, columns = [id, bit_run_number, bit_type, grading_out, iadc_number, manufacturer, model_number, primary_od, secondary_od, serial_number, run_id]
Give SQL query for Count number of bits for every section of the well bore Matzen 569
/1-1
SELECT COUNT(*) FROM bit
INNER JOIN run ON bit.run_id = run.id
INNER JOIN section ON run.section_id = section.id
INNER JOIN well_bore ON section.well_bore_id = well_bore.id
INNER JOIN well ON well_bore.well_id = well.id
WHERE well.well_name = 'Matzen 569/1-1'
Give SQL query for Count number of bits for every section of the well bore Matzen 569/1-1
SELECT COUNT(*) FROM bit
INNER JOIN run ON bit
# Render the current schema in PICARD's pipe-separated format under the
# database name 'wells', then print the query followed by the SQL PICARD returns.
schema_picard = get_schema_string_for_picard(schema, 'wells')
print('\n' + query + '\n\n' + get_picard_sql(schema_picard, query))
| wells | well : id, country, field, latitude, longitude, uwi, well_name | well_bore : id, spud_date, ubhi, well_bore_name, well_id | section : id, max_dls, max_inclination, min_dls, min_inclination, section_bottom_depth_md, section_caption, section_diameter, section_number, section_top_depth_md, well_bore_id | operation : id, operation_code, operation_enddate_time, operation_end_depth, operation_start_datetime, operation_start_depth | operating_parameters : id, avg_rop, flow_rate_high, flow_rate_low, rpm_high, rpm_low, wob_high, wob_low, operation_id | run : id, run_end_depth, run_number, runs_tart_depth, section_id, operation_id | bit : id, bit_run_number, bit_type, grading_out, iadc_number, manufacturer, model_number, primary_od, secondary_od, serial_number, run_id
find avg_rop, section diameter in each section with well bore name Prottes
select t1.avg_rop, t1.section_diameter from operating_parameters as t1 join section as t2 on t1.section_id = t2.section_id join well_bore as t3 on t3.id = t2.well_bore_id where t3.well_bore_name = "Prottes"
Results
query | codex_sql | co_res | picard_sql | pi_res |
---|---|---|---|---|
generate sql query to list all albums by Adele | SELECT * FROM albums JOIN artists ON artists.ArtistId = albums.ArtistId WHERE artists.Name = ‘Adele’ | select t1.title from albums as t1 join artists as t2 on t1.ArtistId = t2.ArtistId where t2.Name = “Adele” | ||
genereate sql to find artists with longest average track length | SELECT a.Name, AVG(t.Milliseconds) AS AverageTrackLength FROM artists a JOIN albums al ON a.ArtistId = al.ArtistId JOIN tracks t ON al.AlbumId = t.AlbumId GROUP BY a.Name ORDER BY AverageTrackLength DESC LIMIT 1 FROM artists a JOIN albums al ON a.ArtistId = al.ArtistId JOIN tracks t ON al.AlbumId = t.AlbumId GROUP BY a.Name ORDER BY AverageTrackLength DESC LIMIT 1 | select t1.Name from artists as t1 join tracks as t2 on t1.ArtistId = t2.ArtistId group by t2.ArtistId order by avg(milliseconds) desc limit 1 | ||
generate sql to find the most expensive albums | select albums.title, tracks.unitprice from albums join tracks on albums.albumid = tracks.albumid order by tracks.unitprice desc limit 1 | select title from albums order by unitprice desc limit 1 | ||
generate sql to find duration of tracks from album abc in playlist xyz | SELECT SUM(Milliseconds) FROM tracks WHERE AlbumId = (SELECT AlbumId FROM albums WHERE Title = ‘abc’) AND TrackId IN (SELECT TrackId FROM playlist_track WHERE PlaylistId = (SELECT PlaylistId FROM playlists WHERE Name = ‘xyz’)) | select t1.milliseconds from tracks as t1 join playlist_track as t2 on t1.trackId = t2.TrackId join albums as t3 on t3.albumId = t2.AlbumId join playlists as t4 on t4.playlistId = t1.playlistId where t3.albumId = “ABC” and t4.name = “xyz” | ||
generate sql to find total duration of tracks from album abc in playlist xyz | select sum(milliseconds) from tracks where albumid = (select albumid from albums where title = ‘abc’) and trackid in (select trackid from playlist_track where playlistid = (select playlistid from playlists where name = ‘xyz’)) | select sum(t1.milliseconds) from tracks as t1 join playlist_track as t2 on t1.trackId = t2.TrackId join albums as t3 on t3.albumId = t2.AlbumId join playlists as t4 on t4.playlistId = t1.playlistId where t3.albumId = “ABC” and t4.name = “xyz” |