Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion preql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@

# import importlib.metadata as importlib_metadata
# __version__ = importlib_metadata.version("prql")
__version__ = "0.2.8"
__version__ = "0.2.9"
__branch__ = ""
58 changes: 38 additions & 20 deletions preql/core/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
postgres = 'postgres'
bigquery = 'bigquery'
mysql = 'mysql'
mssql = 'mssql'
snowflake = 'snowflake'

class QueryBuilder:
Expand Down Expand Up @@ -498,7 +499,7 @@ class AddIndex(SqlStatement):

def _compile(self, qb):
stmt = f"CREATE {'UNIQUE' if self.unique else ''} INDEX "
if qb.target != mysql:
if qb.target != mysql and qb.target != mssql:
stmt += "IF NOT EXISTS "
return [ stmt + f"{quote_id(self.index_name)} ON {quote_id(self.table_name)}({self.column})"]

Expand Down Expand Up @@ -607,7 +608,7 @@ def _compile(self, qb):
if qb.target == mysql:
def row_func(x):
return ['ROW('] + x + [')']
elif qb.target == snowflake:
elif qb.target == snowflake or qb.target == mssql:
return join_sep([['SELECT '] + v.code for v in values], ' UNION ALL ')
else:
row_func = parens
Expand Down Expand Up @@ -739,22 +740,30 @@ def _compile(self, qb):
if self.order:
sql += [' ORDER BY '] + join_comma(o.compile_wrap(qb).code for o in self.order)

if self.limit is not None:
sql += [' LIMIT ', str(self.limit)]
elif self.offset is not None:
if qb.target == sqlite:
sql += [' LIMIT -1'] # Sqlite only (and only old versions of it)
elif qb.target == mysql:
# MySQL requires a specific limit, always!
# See: https://stackoverflow.com/questions/255517/mysql-offset-infinite-rows
sql += [' LIMIT 18446744073709551615']
elif qb.target == bigquery:
# BigQuery requires a specific limit, always!
sql += [' LIMIT 9223372036854775807']


if self.offset is not None:
sql += [' OFFSET ', str(self.offset)]
if qb.target == mssql:
if self.offset or self.limit:
if not self.order:
sql += [' ORDER BY ', list(self.table.type.elems)[0]] # XXX hacky!
sql += [' OFFSET ', str(self.offset or 0),' ROWS ']
if self.limit:
sql += [' FETCH NEXT ', str(self.limit), ' ROWS ONLY ']
else:
if self.limit is not None:
sql += [' LIMIT ', str(self.limit)]
elif self.offset is not None:
if qb.target == sqlite:
sql += [' LIMIT -1'] # Sqlite only (and only old versions of it)
elif qb.target == mysql:
# MySQL requires a specific limit, always!
# See: https://stackoverflow.com/questions/255517/mysql-offset-infinite-rows
sql += [' LIMIT 18446744073709551615']
elif qb.target == bigquery:
# BigQuery requires a specific limit, always!
sql += [' LIMIT 9223372036854775807']


if self.offset is not None:
sql += [' OFFSET ', str(self.offset)]

return sql

Expand Down Expand Up @@ -899,6 +908,8 @@ def _repr(_t: T.number, x):

@dp_type
def _repr(_t: T.bool, x):
    "SQL literal for a boolean; MS-SQL has no true/false keywords, so emit tautology/contradiction predicates."
    if get_db().target == mssql:
        return '1=1' if x else '0=1'
    return 'true' if x else 'false'

@dp_type
Expand Down Expand Up @@ -971,6 +982,9 @@ class P2S_Postgres(Types_PqlToSql):
class P2S_Snowflake(Types_PqlToSql):
    """PQL-to-SQL type mappings for Snowflake; no overrides — uses the base defaults unchanged."""
    pass

class P2S_MsSql(Types_PqlToSql):
    """PQL-to-SQL type mappings for MS SQL Server; no overrides — uses the base defaults unchanged."""
    pass

_pql_to_sql_by_target = {
bigquery: P2S_BigQuery,
mysql: P2S_MySql,
Expand All @@ -979,6 +993,7 @@ class P2S_Snowflake(Types_PqlToSql):
duck: P2S_Sqlite,
postgres: P2S_Postgres,
snowflake: P2S_Snowflake,
mssql: P2S_MsSql,
}


Expand Down Expand Up @@ -1056,10 +1071,13 @@ def compile_type_def(table_name, table) -> Sql:
posts.append(f"PRIMARY KEY ({names})")

# Consistent among SQL databases
tmp = table.options.get('temporary', False)
if db.target == 'bigquery':
    # BigQuery has no TEMPORARY tables via this path; temp tables lose the
    # IF NOT EXISTS guard so a stale one is reported rather than reused.
    command = "CREATE TABLE" if tmp else "CREATE TABLE IF NOT EXISTS"
elif db.target == 'mssql':
    # BUGFIX: this was a plain `if`, so for the bigquery target the `else`
    # branch below immediately clobbered the command chosen above, making
    # the bigquery branch dead code.
    # NOTE(review): T-SQL has no CREATE TEMPORARY TABLE (temp tables use a
    # `#name` prefix) — confirm against the SQL Server docs.
    command = "CREATE TEMPORARY TABLE" if tmp else "CREATE TABLE"
else:
    command = "CREATE TEMPORARY TABLE" if tmp else "CREATE TABLE IF NOT EXISTS"

return RawSql(T.nulltype, f'{command} {quote_id(table_name)} (' + ', '.join(columns + posts) + ')')

Expand Down
19 changes: 14 additions & 5 deletions preql/modules/__builtins__.pql
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,6 @@ func first_or_null(obj: union[table, projected]) {
func limit(tbl: table, n: int) = tbl[..n]
"""Returns the first 'n' rows in the table."""

func limit_offset(tbl: table, lim: int, offset: int) {
"Returns the first 'n' rows in the table at the given offset."
return SQL(type(tbl), "SELECT * FROM $tbl LIMIT $lim OFFSET $offset")
}



func upper(s: string) = SQL(string, "upper($s)")
Expand Down Expand Up @@ -287,6 +282,18 @@ func _count_true(field: aggregated) = SQL(int, "sum(cast($field!=0 as int))")

db_type = get_db_type()

if (db_type == "mssql") {
	func limit_offset(tbl: table, lim: int, offset: int) {
		"Returns the first 'lim' rows in the table at the given offset."
		// BUGFIX: `SELECT TOP $lim` ignored the offset entirely. T-SQL has no
		// LIMIT/OFFSET; use OFFSET ... FETCH, which requires an ORDER BY --
		// ORDER BY (SELECT NULL) satisfies that without imposing an order.
		return SQL(type(tbl), "SELECT * FROM $tbl ORDER BY (SELECT NULL) OFFSET $offset ROWS FETCH NEXT $lim ROWS ONLY")
	}
} else {
	func limit_offset(tbl: table, lim: int, offset: int) {
		"Returns the first 'lim' rows in the table at the given offset."
		return SQL(type(tbl), "SELECT * FROM $tbl LIMIT $lim OFFSET $offset")
	}
}

if (db_type == 'sqlite' and PY("sqlite3.sqlite_version_info < (3, 25)", "import sqlite3")) {
func enum(tbl: table) {
throw new NotImplementedError("Sqlite doesn't support window functions before version 3.25")
Expand Down Expand Up @@ -397,6 +404,8 @@ if (db_type == "postgres") {
func year(x) = SQL(int, "EXTRACT (YEAR FROM $x)")
} else if (db_type == "snowflake") {
//
} else if (db_type == "mssql") {
//
} else {
throw new TypeError("Unexpected")
}
Expand Down
64 changes: 63 additions & 1 deletion preql/sql_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .loggers import sql_log
from .context import context

from .core.sql import Sql, QueryBuilder, sqlite, postgres, mysql, duck, bigquery, quote_id, snowflake
from .core.sql import Sql, QueryBuilder, sqlite, postgres, mysql, mssql, duck, bigquery, quote_id, snowflake
from .core.sql_import_result import sql_result_to_python, type_from_sql
from .core.pql_types import T, Type, Object, Id
from .core.exceptions import DatabaseQueryError, Signal
Expand Down Expand Up @@ -211,6 +211,66 @@ def ping(self):
assert n == 1


class MssqlInterface(SqlInterfaceCursor):
    """Database adapter for Microsoft SQL Server, connected through pymssql."""

    target = mssql

    id_type_decl = "INT IDENTITY (1, 1)"
    requires_subquery_name = True  # XXX not sure

    def __init__(self, host, port, database, user, password, print_sql=False):
        self._print_sql = print_sql

        # Forward only the arguments the caller actually provided, so pymssql
        # falls back to its own defaults for the rest.
        args = dict(server=host, port=port, database=database, user=user, password=password)
        self._args = {k: v for k, v in args.items() if v is not None}
        super().__init__(print_sql)

    def _create_connection(self):
        import pymssql

        try:
            return pymssql.connect(**self._args)
        except pymssql.Error as e:
            # BUGFIX: previously caught `mysql.connector.Error`, but `mysql`
            # in this module is the target-name constant imported from
            # core.sql (a string), not the MySQL driver module — that except
            # clause raised AttributeError instead of matching anything.
            raise ConnectError(*e.args) from e

    def quote_name(self, name):
        # T-SQL quotes identifiers with square brackets.
        return f'[{name}]'

    def list_tables(self):
        # TODO import more schemas?
        sql_code = "SELECT Distinct TABLE_NAME FROM information_schema.TABLES"
        names = self._execute_sql(T.list[T.string], sql_code)
        return list(map(Id, names))

    # Column-description row type, same layout as the postgres interface.
    _schema_columns_t = T.table(dict(
        schema=T.string,
        table=T.string,
        name=T.string,
        pos=T.int,
        nullable=T.bool,
        type=T.string,
    ))

    def import_table_type(self, table_id, columns_whitelist=None):
        """Read a table's column definitions from information_schema and
        build the corresponding preql table type.

        NOTE(review): table_id.name is interpolated directly into the query;
        assumed to come from trusted metadata, not user input — confirm.
        """
        columns_q = f"""SELECT table_schema, table_name, column_name, ordinal_position, is_nullable, data_type
            FROM information_schema.columns
            WHERE table_name = '{table_id.name}'
            """
        sql_columns = self._execute_sql(self._schema_columns_t, columns_q)

        if columns_whitelist:
            wl = set(columns_whitelist)
            sql_columns = [c for c in sql_columns if c['name'] in wl]

        # Sort by ordinal position so the resulting dict preserves the
        # table's declared column order.
        cols = [(c['pos'], c['name'], type_from_sql(c['type'], c['nullable'])) for c in sql_columns]
        cols.sort()
        cols = dict(c[1:] for c in cols)

        return T.table(cols, name=table_id)

    def table_exists(self, name):
        # Case-insensitive lookup, matching SQL Server's default collation.
        assert isinstance(name, Id)
        tables = [t.lower() for t in self.list_tables()]
        return name.lower() in tables

class MysqlInterface(SqlInterfaceCursor):
target = mysql
Expand Down Expand Up @@ -858,6 +918,8 @@ def create_engine(db_uri, print_sql, auto_create):
return PostgresInterface(dsn.host, dsn.port, path, dsn.user, dsn.password, print_sql=print_sql)
elif scheme == 'mysql':
return MysqlInterface(dsn.host, dsn.port, path, dsn.user, dsn.password, print_sql=print_sql)
elif scheme == 'mssql':
return MssqlInterface(dsn.host, dsn.port, path, dsn.user, dsn.password, print_sql=print_sql)
elif scheme == 'git':
return GitInterface(path, print_sql=print_sql)
elif scheme == 'duck':
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "preql"
version = "0.2.8"
version = "0.2.9"
description = "An interpreted relational query language that compiles to SQL"
authors = ["Erez Shin <erezshin@gmail.com>"]
license = "Interface-Protection Clause + MIT"
Expand Down