use squawk_syntax::ast;

/// Reports whether a `CREATE TABLE` statement declares a foreign key.
///
/// Two forms are recognised among the table's arguments:
/// - a table-level `FOREIGN KEY (...) REFERENCES ...` constraint, and
/// - a column whose constraint is an inline `REFERENCES ...`.
///
/// A statement without an argument list (no `( ... )` body) has no foreign
/// keys and yields `false`.
fn has_foreign_key_constraint(create_table: &ast::CreateTable) -> bool {
    let Some(table_arg_list) = create_table.table_arg_list() else {
        return false;
    };
    table_arg_list.args().any(|table_arg| match table_arg {
        ast::TableArg::TableConstraint(ast::TableConstraint::ForeignKeyConstraint(_)) => true,
        // NOTE(review): `constraint()` yields a single column constraint; if a
        // column carries several (e.g. `not null references t(id)`), confirm a
        // REFERENCES that is not the first one is still detected.
        ast::TableArg::Column(column) => matches!(
            column.constraint(),
            Some(ast::ColumnConstraint::ReferencesConstraint(_))
        ),
        _ => false,
    })
}

/// Returns `true` if the statement might impede normal database queries.
///
/// We assume all DDL like Alter, Create, Drop could affect queries, along
/// with maintenance / locking statements such as `CLUSTER`, `LOCK`,
/// `REINDEX`, `TRUNCATE` and `VACUUM`.
///
/// We don't want to warn about DML style queries, like select, insert,
/// update, delete, etc., nor about transaction control or session settings.
/// This allows using squawk for normal SQL editing.
///
/// The match is deliberately exhaustive (no `_` arm) so that adding a new
/// `ast::Stmt` variant forces an explicit classification here.
pub fn possibly_slow_stmt(stmt: &ast::Stmt) -> bool {
    match stmt {
        // Without a foreign key constraint, creating a new table should be fast
        ast::Stmt::CreateTable(create_table) => has_foreign_key_constraint(create_table),
        ast::Stmt::AlterAggregate(_)
        | ast::Stmt::AlterCollation(_)
        | ast::Stmt::AlterConversion(_)
        | ast::Stmt::AlterDatabase(_)
        | ast::Stmt::AlterDefaultPrivileges(_)
        | ast::Stmt::AlterDomain(_)
        | ast::Stmt::AlterEventTrigger(_)
        | ast::Stmt::AlterExtension(_)
        | ast::Stmt::AlterForeignDataWrapper(_)
        | ast::Stmt::AlterForeignTable(_)
        | ast::Stmt::AlterFunction(_)
        | ast::Stmt::AlterGroup(_)
        | ast::Stmt::AlterIndex(_)
        | ast::Stmt::AlterLanguage(_)
        | ast::Stmt::AlterLargeObject(_)
        | ast::Stmt::AlterMaterializedView(_)
        | ast::Stmt::AlterOperator(_)
        | ast::Stmt::AlterOperatorClass(_)
        | ast::Stmt::AlterOperatorFamily(_)
        | ast::Stmt::AlterPolicy(_)
        | ast::Stmt::AlterProcedure(_)
        | ast::Stmt::AlterPublication(_)
        | ast::Stmt::AlterRole(_)
        | ast::Stmt::AlterRoutine(_)
        | ast::Stmt::AlterRule(_)
        | ast::Stmt::AlterSchema(_)
        | ast::Stmt::AlterSequence(_)
        | ast::Stmt::AlterServer(_)
        | ast::Stmt::AlterStatistics(_)
        | ast::Stmt::AlterSubscription(_)
        | ast::Stmt::AlterSystem(_)
        | ast::Stmt::AlterTable(_)
        | ast::Stmt::AlterTablespace(_)
        | ast::Stmt::AlterTextSearchConfiguration(_)
        | ast::Stmt::AlterTextSearchDictionary(_)
        | ast::Stmt::AlterTextSearchParser(_)
        | ast::Stmt::AlterTextSearchTemplate(_)
        | ast::Stmt::AlterTrigger(_)
        | ast::Stmt::AlterType(_)
        | ast::Stmt::AlterUser(_)
        | ast::Stmt::AlterUserMapping(_)
        | ast::Stmt::AlterView(_)
        | ast::Stmt::CreateAccessMethod(_)
        | ast::Stmt::CreateAggregate(_)
        | ast::Stmt::CreateCast(_)
        | ast::Stmt::CreateCollation(_)
        | ast::Stmt::CreateConversion(_)
        | ast::Stmt::CreateDatabase(_)
        | ast::Stmt::CreateDomain(_)
        | ast::Stmt::CreateEventTrigger(_)
        | ast::Stmt::CreateExtension(_)
        | ast::Stmt::CreateForeignDataWrapper(_)
        | ast::Stmt::CreateForeignTable(_)
        | ast::Stmt::CreateFunction(_)
        | ast::Stmt::CreateGroup(_)
        | ast::Stmt::CreateIndex(_)
        | ast::Stmt::CreateLanguage(_)
        | ast::Stmt::CreateMaterializedView(_)
        | ast::Stmt::CreateOperator(_)
        | ast::Stmt::CreateOperatorClass(_)
        | ast::Stmt::CreateOperatorFamily(_)
        | ast::Stmt::CreatePolicy(_)
        | ast::Stmt::CreateProcedure(_)
        | ast::Stmt::CreatePublication(_)
        | ast::Stmt::CreateRole(_)
        | ast::Stmt::CreateRule(_)
        | ast::Stmt::CreateSchema(_)
        | ast::Stmt::CreateSequence(_)
        | ast::Stmt::CreateServer(_)
        | ast::Stmt::CreateStatistics(_)
        | ast::Stmt::CreateSubscription(_)
        | ast::Stmt::CreateTableAs(_)
        | ast::Stmt::CreateTablespace(_)
        | ast::Stmt::CreateTextSearchConfiguration(_)
        | ast::Stmt::CreateTextSearchDictionary(_)
        | ast::Stmt::CreateTextSearchParser(_)
        | ast::Stmt::CreateTextSearchTemplate(_)
        | ast::Stmt::CreateTransform(_)
        | ast::Stmt::CreateTrigger(_)
        | ast::Stmt::CreateType(_)
        | ast::Stmt::CreateUser(_)
        | ast::Stmt::CreateUserMapping(_)
        | ast::Stmt::CreateView(_)
        | ast::Stmt::DropAccessMethod(_)
        | ast::Stmt::DropAggregate(_)
        | ast::Stmt::DropCast(_)
        | ast::Stmt::DropCollation(_)
        | ast::Stmt::DropConversion(_)
        | ast::Stmt::DropDatabase(_)
        | ast::Stmt::DropDomain(_)
        | ast::Stmt::DropEventTrigger(_)
        | ast::Stmt::DropExtension(_)
        | ast::Stmt::DropForeignDataWrapper(_)
        | ast::Stmt::DropForeignTable(_)
        | ast::Stmt::DropFunction(_)
        | ast::Stmt::DropGroup(_)
        | ast::Stmt::DropIndex(_)
        | ast::Stmt::DropLanguage(_)
        | ast::Stmt::DropMaterializedView(_)
        | ast::Stmt::DropOperator(_)
        | ast::Stmt::DropOperatorClass(_)
        | ast::Stmt::DropOperatorFamily(_)
        | ast::Stmt::DropOwned(_)
        | ast::Stmt::DropPolicy(_)
        | ast::Stmt::DropProcedure(_)
        | ast::Stmt::DropPublication(_)
        | ast::Stmt::DropRole(_)
        | ast::Stmt::DropRoutine(_)
        | ast::Stmt::DropRule(_)
        | ast::Stmt::DropSchema(_)
        | ast::Stmt::DropSequence(_)
        | ast::Stmt::DropServer(_)
        | ast::Stmt::DropStatistics(_)
        | ast::Stmt::DropSubscription(_)
        | ast::Stmt::DropTable(_)
        | ast::Stmt::DropTablespace(_)
        | ast::Stmt::DropTextSearchConfig(_)
        | ast::Stmt::DropTextSearchDict(_)
        | ast::Stmt::DropTextSearchParser(_)
        | ast::Stmt::DropTextSearchTemplate(_)
        | ast::Stmt::DropTransform(_)
        | ast::Stmt::DropTrigger(_)
        | ast::Stmt::DropType(_)
        | ast::Stmt::DropUser(_)
        | ast::Stmt::DropUserMapping(_)
        | ast::Stmt::DropView(_)
        // non-Alter, Create, Drop statements
        | ast::Stmt::Cluster(_)
        | ast::Stmt::CommentOn(_)
        | ast::Stmt::ImportForeignSchema(_)
        | ast::Stmt::Load(_)
        | ast::Stmt::Lock(_)
        // REASSIGN OWNED changes the owner of every matching object, the same
        // work as `ALTER ... OWNER TO`; kept consistent with `DropOwned` above.
        | ast::Stmt::Reassign(_)
        | ast::Stmt::Refresh(_)
        | ast::Stmt::Reindex(_)
        // SELECT ... INTO creates and populates a new table, just like
        // `CreateTableAs` above, so it is classified the same way.
        | ast::Stmt::SelectInto(_)
        | ast::Stmt::Truncate(_)
        | ast::Stmt::Vacuum(_) => true,
        ast::Stmt::Analyze(_)
        | ast::Stmt::Begin(_)
        | ast::Stmt::Call(_)
        | ast::Stmt::Checkpoint(_)
        | ast::Stmt::Close(_)
        | ast::Stmt::Commit(_)
        | ast::Stmt::Copy(_)
        | ast::Stmt::Deallocate(_)
        | ast::Stmt::Declare(_)
        | ast::Stmt::Delete(_)
        | ast::Stmt::Discard(_)
        | ast::Stmt::Do(_)
        | ast::Stmt::Execute(_)
        | ast::Stmt::Explain(_)
        | ast::Stmt::Fetch(_)
        | ast::Stmt::Grant(_)
        | ast::Stmt::Insert(_)
        | ast::Stmt::Listen(_)
        | ast::Stmt::Merge(_)
        | ast::Stmt::Move(_)
        | ast::Stmt::Notify(_)
        | ast::Stmt::ParenSelect(_)
        | ast::Stmt::Prepare(_)
        | ast::Stmt::PrepareTransaction(_)
        | ast::Stmt::ReleaseSavepoint(_)
        | ast::Stmt::Reset(_)
        | ast::Stmt::Revoke(_)
        | ast::Stmt::Rollback(_)
        | ast::Stmt::Savepoint(_)
        | ast::Stmt::SecurityLabel(_)
        | ast::Stmt::Select(_)
        | ast::Stmt::Set(_)
        | ast::Stmt::SetConstraints(_)
        | ast::Stmt::SetRole(_)
        | ast::Stmt::SetSessionAuth(_)
        | ast::Stmt::ResetSessionAuth(_)
        | ast::Stmt::SetTransaction(_)
        | ast::Stmt::Show(_)
        | ast::Stmt::Table(_)
        | ast::Stmt::Unlisten(_)
        | ast::Stmt::Update(_)
        | ast::Stmt::Values(_) => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use squawk_syntax::SourceFile;

    /// Parses `sql` and classifies its first statement with `possibly_slow_stmt`.
    fn first_stmt_is_slow(sql: &str) -> bool {
        let parsed = SourceFile::parse(sql);
        let first = parsed
            .tree()
            .stmts()
            .next()
            .expect("test SQL should contain at least one statement");
        possibly_slow_stmt(&first)
    }

    #[test]
    fn alter_table() {
        assert!(first_stmt_is_slow("ALTER TABLE users ADD COLUMN email TEXT;"));
    }

    #[test]
    fn select() {
        assert!(!first_stmt_is_slow("select 1;"));
    }

    #[test]
    fn create_table_without_foreign_key() {
        let ddl = "create table foo (id integer generated by default as identity primary key);";
        assert!(!first_stmt_is_slow(ddl));
    }

    #[test]
    fn create_table_with_foreign_key() {
        let ddl = "create table foo (id integer, user_id integer references users(id));";
        assert!(first_stmt_is_slow(ddl));
    }

    #[test]
    fn create_table_with_table_level_foreign_key() {
        let ddl = "create table foo (id integer, user_id integer, foreign key (user_id) references users(id));";
        assert!(first_stmt_is_slow(ddl));
    }
}
