import _ from 'lodash';
import { AST } from '../../ast';
import { Constant, f, Expression, NullIf } from '../../builder';
import { Ty } from '../../ty';
import { TC } from '../../type-checker';
import { sqlExprMacro, pSqlMacro, intersperse } from '../sql-ast';
import { assertConstantString, SqlDialect } from './dialect';
import {
  PrimitiveAttributeTypeToPostgresType,
  PostgresDialect,
} from './postgres';

// https://coterah.slack.com/archives/C045H3HTQCQ/p1678816619808369
//
// Valid redshift
// ```
// select generate_series(0, 10);
// ```
//
// Not valid redshift
// ```
// -- Inserts run on compute nodes
// insert into foo (a) (
//     -- Leader only function
//     select generate_series(0, 10) as a
// )
// ```
//
// This is because `generate_series` is a leader only function, so it can’t be run on worker nodes
// https://docs.aws.amazon.com/redshift/latest/dg/c_SQL_functions_leader_node_only.html
// https://docs.aws.amazon.com/redshift/latest/dg/c_sql-functions-leader-node.html
// https://stackoverflow.com/questions/62716606/redshift-loading-data-issue-specified-types-or-functions-one-per-info-message
// https://stackoverflow.com/questions/17282276/using-sql-function-generate-series-in-redshift#comment96402527_22782384
//
// Recursive CTEs are NOT supported in subqueries
// ```
// -- Not valid
// select * from (
//    with recursive t(n) as (
//        select 1::integer union all select n + 1 from t where n < 100
//    ) select n from t
// );
// ```
//
// To get around both limitations, we use the approach dbt-utils takes for its ANSI SQL generate_series

// https://github.com/dbt-labs/dbt-utils/blob/main/macros/sql/generate_series.sql
/**
 * Builds a parenthesised SQL subquery that yields a single integer column
 * `generated_number` containing every value from 0 up to and including
 * `upperBound`, in ascending order.
 *
 * The rows come from cross joining twelve copies of a two-row CTE (one copy per
 * binary digit), so the unfiltered set is 0..4095; the accepted `upperBound` is
 * deliberately capped lower, at 2 ** 11. A negative `upperBound` is allowed and
 * simply filters every row out.
 *
 * @param upperBound Inclusive upper bound of the generated numbers.
 * @returns The SQL text of the subquery, wrapped in parentheses.
 * @throws Error when `upperBound` is not a finite number or exceeds 2 ** 11.
 */
const numbers = (upperBound: number) => {
  const maxUpperBound = 2 ** 11;

  // NaN and -Infinity slip through a plain `>` comparison and would be
  // interpolated verbatim into the SQL, so reject every non-finite value here.
  if (!Number.isFinite(upperBound) || upperBound > maxUpperBound) {
    throw new Error(
      `We only support generating series in Redshift where the upperBound is a finite number no greater than ${maxUpperBound} (got ${upperBound})`
    );
  }

  return `
(
  with p as (
    select 0::integer as generated_number union all select 1::integer
  ),
  unioned as (
    select
      (   p0.generated_number * power(2, 0) 
       +  p1.generated_number * power(2, 1) 
       +  p2.generated_number * power(2, 2) 
       +  p3.generated_number * power(2, 3) 
       +  p4.generated_number * power(2, 4) 
       +  p5.generated_number * power(2, 5) 
       +  p6.generated_number * power(2, 6) 
       +  p7.generated_number * power(2, 7) 
       +  p8.generated_number * power(2, 8) 
       +  p9.generated_number * power(2, 9) 
       +  p10.generated_number * power(2, 10) 
       +  p11.generated_number * power(2, 11) 
      ) as generated_number
    from
      p as p0
      cross join p as p1
      cross join p as p2
      cross join p as p3
      cross join p as p4
      cross join p as p5
      cross join p as p6
      cross join p as p7
      cross join p as p8
      cross join p as p9
      cross join p as p10
      cross join p as p11
  )
  select generated_number::integer from unioned where generated_number <= ${upperBound} order by generated_number
)
`;
};

// Adapted from here
// https://stackoverflow.com/questions/32965743/how-to-generate-a-uuidv4-in-mysql#32965744
//
// Every fragment scales RANDOM() by the number of distinct values the section
// can hold (16^4 = 65536 for four hex digits, 16^3 = 4096 for three), so that
// each hex string of that width can be produced and the LPAD width is never
// narrower than the value being padded.

// 4th section will start with a 4 indicating the version, followed by three
// random hex digits
const uuidForthSection = `CONCAT('4', LPAD(TO_HEX(FLOOR(RANDOM() * 4096)::int), 3, '0'))::text`;
// 5th section first half-byte can only be 8, 9 A or B, followed by three
// random hex digits
const uuidFifthSection = `CONCAT(TO_HEX(FLOOR(RANDOM() * 4 + 8)::int), LPAD(TO_HEX(FLOOR(RANDOM() * 4096)::int), 3, '0'))::text`;
// All other sections are just four random hex digits
const uuidRandomSection = `LPAD(TO_HEX(FLOOR(RANDOM() * 65536)::int), 4, '0')::text`;

// Sections are laid out 8-4-4-4-12 and lower-cased as a whole.
const genUuid = `LOWER(${uuidRandomSection} || ${uuidRandomSection} || '-'::text || ${uuidRandomSection} || '-'::text || ${uuidForthSection} || '-'::text || ${uuidFifthSection} || '-'::text || ${uuidRandomSection} || ${uuidRandomSection} || ${uuidRandomSection})`;

/**
 * SQL type name used for each primitive attribute type in this dialect.
 *
 * Starts from the Postgres mapping; the `string` and `super` entries listed
 * after the spread take precedence over any same-named Postgres entries.
 */
export const ScalarAttributeTypeToReshiftType: Record<
  Ty.PrimitiveAttributeType,
  string
> = {
  ...PrimitiveAttributeTypeToPostgresType,
  string: 'character varying',
  super: 'super',
};

/**
 * Resolves the SQL type name used for an attribute type.
 *
 * Struct, array, enum and record types all map to `text`; every other type is
 * looked up by its `t` tag in `ScalarAttributeTypeToReshiftType`.
 */
const typeMapping = (ty: Ty.AttributeType) => {
  switch (ty.k) {
    // Arrays, structs, enums and records are all represented as strings
    case 'struct':
    case 'array':
    case 'enum':
    case 'record':
      return 'text';
    default:
      return ScalarAttributeTypeToReshiftType[ty.t];
  }
};

/**
 * Builds the SQL expression that reads one property out of a struct value.
 *
 * The property is pulled out as text with `JSON_EXTRACT_PATH_TEXT`. A boolean
 * result is produced by comparing that text with `'true'`; any other wanted
 * type is produced by casting the text to the type given by `typeMapping`.
 *
 * @param expr The struct expression to read from.
 * @param name The property name, either a plain string (turned into a constant
 *   expression) or an already-built expression.
 * @param wantedTy The attribute type the caller wants back.
 */
const getPropertyFromStruct = (
  expr: AST.ExprIR,
  name: string | AST.ExprIR,
  wantedTy: Ty.AttributeType
) => {
  const pathKey = typeof name === 'string' ? Constant(name).ir() : name;
  const rawText = sqlExprMacro`JSON_EXTRACT_PATH_TEXT((${expr}), ${pathKey})`;
  const wantsBoolean = wantedTy.k === 'primitive' && wantedTy.t === 'boolean';

  return wantsBoolean
    ? sqlExprMacro`(${rawText}) = 'true'`
    : sqlExprMacro`cast(${rawText} as ${typeMapping(wantedTy)})`;
};

/**
 * SQL dialect for Redshift.
 *
 * Everything not listed here is inherited from `PostgresDialect` via the
 * spread; the members below replace the Postgres versions.
 */
export const RedshiftDialect: SqlDialect = {
  ...PostgresDialect,
  /**
   * Produces a subquery with a column `n` running from `start` to `stop`
   * inclusive, by offsetting the 0-based output of `numbers` by `start`.
   */
  generateSeries: ({ start, stop }) => {
    return sqlExprMacro`(select generated_number + ${start.toString()} as n from ${numbers(
      stop - start
    )})`;
  },
  scalarLiteralOverrides: {
    /**
     * Inlines a string literal only when every character is on the allow-list
     * below (which contains neither quotes nor backslashes); any other value,
     * including the empty string, is emitted as a parameter instead.
     */
    string: (val) => {
      return /^[[\]"{}!*%><&/a-zA-Z0-9-_ .:,()?=]+$/.test(val)
        ? pSqlMacro`cast('${val}' as character varying)`
        : pSqlMacro`${{ val, t: 'param' }}`;
    },
  },
  /**
   * Builds the SQL for casting `expr` to `targetTy`.
   *
   * @throws Error for struct-related casts other than struct -> string,
   *   struct -> super and super -> struct.
   */
  cast(expr, targetTy) {
    const { ty } = Expression.fromAst(expr);
    if (ty.ty.k === 'struct' || targetTy.k === 'struct') {
      // Structs are represented by JSON strings under the hood so no need to
      // adjust them
      if (targetTy.k === 'primitive') {
        if (targetTy.t === 'string') {
          return sqlExprMacro`${expr}`;
        }

        if (targetTy.t === 'super') {
          return sqlExprMacro`JSON_PARSE(${expr})`;
        }

        throw new Error('Unreachable');
      }

      if (ty.ty.k === 'primitive' && ty.ty.t === 'super') {
        // Accessing a field from json requires a string, so the right move
        // here is just to cast the super column to a string so that it works
        // with the json operations we've got. The value is already a json
        // encoded string, so there's no need to transform it in any way.
        return sqlExprMacro`CAST(${expr} AS VARCHAR)`;
      }

      throw new Error(`Unhandled cast case from or to struct.`);
    }

    // boolean -> string: spell the value out as the literal 'true' / 'false'
    if (ty.ty.t === 'boolean' && targetTy.t === 'string') {
      return sqlExprMacro`case when (${expr}) then 'true' else 'false' end`;
    }

    // string -> boolean: Redshift doesn't allow casting the string 'true' to a
    // boolean, so compare against the literal instead
    if (
      ty.ty.k === 'primitive' &&
      ty.ty.t === 'string' &&
      targetTy.t === 'boolean'
    ) {
      return sqlExprMacro`(${expr}) = 'true'`;
    }

    return sqlExprMacro`cast(${expr} as ${typeMapping(targetTy)})`;
  },
  /**
   * Builds an array value as a single concatenated string of the form
   * `[a, b, c]`. Non-string elements are cast to string first. The result is
   * passed through `NullIf(..., '[]')`, so an empty array yields null.
   */
  makeArray: (array) => {
    const arr = [
      Constant('['),
      ...intersperse(
        array.elements
          .map((elem) => Expression.fromAst(elem))
          .map((elem) =>
            !(elem.ty.ty.k === 'primitive' && elem.ty.ty.t === 'string')
              ? elem.cast('string')
              : elem
          ),
        Constant(', ')
      ),
      Constant(']'),
    ].reduce((l, r) => l.concat(r));

    return [NullIf(arr, '[]').ir()];
  },
  /** Records are built exactly like structs. Relies on `this` being the dialect object. */
  makeRecord(fields) {
    return this.makeStruct(fields);
  },
  /**
   * Builds a struct value as a concatenated string of the form
   * `{"a": 1, "b": "x"}`, with fields sorted by name. Values whose type
   * implements `string` or `timestamp` are wrapped in double quotes.
   */
  makeStruct(fields) {
    const sortedFields: Expression[] = _.chain(fields)
      .entries()
      .sortBy(([name, _expr]) => name)
      .map(([name, expr], i, arr) => {
        const val = Expression.fromAst(expr);
        const isValString = TC.implementsTy({
          subject: val.ty,
          req: 'string',
        });
        const isValTimestamp = TC.implementsTy({
          subject: val.ty,
          req: 'timestamp',
        });
        // Every field except the last is followed by a ', ' separator
        return f`"${name}": ${
          isValString || isValTimestamp ? f`"${val}"` : val
        }${i < arr.length - 1 ? ', ' : ''}`;
      })
      .value();

    const res = [Constant('{'), ...sortedFields, Constant('}')].reduce((l, r) =>
      l.concat(r)
    );
    return [res.ir()];
  },
  getPropertyFromStruct,
  typeMapping,
  functionOverrides: {
    get_from_record: ([arg0, arg1], wanted) =>
      getPropertyFromStruct(arg0!, arg1!, wanted),
    is_numeric_string: ([arg0]) => sqlExprMacro`${arg0!} ~ '^\\\\d+$'`,
    string_agg: ([arg0, arg1]) => sqlExprMacro`LISTAGG(${arg0!}, ${arg1!})`,
    // Aggregates into the same `[a, b, c]` string form used by makeArray
    array_agg: ([arg0]) => {
      const asStr = Expression.fromAst(arg0!).impure().cast('string');
      return sqlExprMacro`'[' || listagg(${asStr.ir()}, ', ') || ']'`;
    },
    avg: ([arg0]) => sqlExprMacro`avg((${arg0!})::float)::float`,
    count: ([arg0]) => sqlExprMacro`count(${arg0!})::int`,
    sum: ([arg0]) => sqlExprMacro`sum(${arg0!})::float`,
    // For both percentile functions arg0 is the ordered value and arg1 the percentile
    percentile_cont: ([arg0, arg1]) =>
      sqlExprMacro`percentile_cont(${arg1!}) within group (order by (${arg0!}) asc)`,
    percentile_disc: ([arg0, arg1]) =>
      sqlExprMacro`approximate percentile_disc(${arg1!}) within group (order by (${arg0!}) asc)`,
    gen_random_uuid: () => [genUuid],
    corr: new Error('Redshift does not support `Corr`'),
    date_diff: ([arg0, arg1, arg2]) => {
      // `unit` only selects the result cast below; the datepart handed to
      // datediff is the constant string from arg2 itself.
      const unit: string = { days: 'DAY', years: 'YEAR', seconds: 'SECOND' }[
        assertConstantString(arg2!, AST.DATE_DIFF_UNITS)
      ];

      switch (unit) {
        // Second differences are returned as float, every other unit as int
        case 'SECOND':
          return sqlExprMacro`datediff(${assertConstantString(
            arg2!,
            AST.DATE_DIFF_UNITS
          )}, ((${arg0!})::timestamptz at time zone 'UTC'), ((${arg1!})::timestamptz at time zone 'UTC'))::float`;
        default:
          return sqlExprMacro`datediff(${assertConstantString(
            arg2!,
            AST.DATE_DIFF_UNITS
          )}, ((${arg0!})::timestamptz at time zone 'UTC'), ((${arg1!})::timestamptz at time zone 'UTC'))::int`;
      }
    },
    // arg0 is the date, arg1 the amount and arg2 the constant unit string
    date_add: ([arg0, arg1, arg2]) =>
      sqlExprMacro`dateadd(${assertConstantString(
        arg2!,
        AST.DATE_ADD_UNITS
      )}, (${arg1!}), cast((${arg0!}) as timestamp))`,
    date_trunc: PostgresDialect.functionOverrides.date_trunc,
    date_part: PostgresDialect.functionOverrides.date_part,
    log_10: PostgresDialect.functionOverrides.log_10,
    log_2: PostgresDialect.functionOverrides.log_2,
    nan: () => sqlExprMacro`'NaN'::float`,
    is_nan: ([arg0]) => sqlExprMacro`(${arg0!}) = 'NaN'::float`,
    now: () => sqlExprMacro`current_timestamp`,
    cosine_distance: ([_arg0, _arg1]) => {
      throw new Error(
        `Function 'cosine_distance' is not implemented for the Redshift dialect.`
      );
    },
  },
  /**
   * Builds an inline table from literal rows.
   *
   * The rows are serialised to one JSON array string. A `numbers` subquery
   * supplies the indexes 0..values.length - 1, each index extracts one array
   * element as text, and every attribute is then pulled out of that element
   * and cast to its mapped type (booleans by comparison with `'true'`).
   * Columns are emitted in attribute-name order; with no attributes a single
   * `null` column is selected.
   *
   * @throws Error (from `numbers`) when there are more rows than it supports.
   */
  values({ values, attributes }) {
    // Attribute names end up inside a single-quoted SQL string and a
    // double-quoted identifier, so neutralise the characters that would
    // terminate either. Names without those characters are emitted unchanged.
    const escapeSqlLiteral = (raw: string) =>
      raw.replace(/\\/g, '\\\\').replace(/'/g, "''");
    const escapeSqlIdentifier = (raw: string) => raw.replace(/"/g, '""');

    const json = JSON.stringify(values);
    const exprsSql =
      Object.keys(attributes).length > 0
        ? _.chain(attributes)
            .entries()
            .sortBy(([name, _ty]) => name)
            .map(([name, { ty }]) => {
              const extracted = `json_extract_path_text(row, '${escapeSqlLiteral(
                name
              )}', true)`;
              const alias = `"${escapeSqlIdentifier(name)}"`;

              if (ty.k === 'primitive' && ty.t === 'boolean') {
                return `${extracted} = 'true' as ${alias}`;
              }

              return `cast(${extracted} as ${typeMapping(ty)}) as ${alias}`;
            })
            .join(',\n')
            .value()
        : sqlExprMacro`null`;

    return sqlExprMacro`(
      with 
        "numbers" as (${numbers(values.length - 1)}),
        "data" as (select ${Constant(json, {
          ty: 'string',
        }).ir()}::text as "json"),
        "rows" as (
          select 
            json_extract_array_element_text(data.json, numbers.generated_number, true) as row
          from data cross join numbers
        )
      select ${exprsSql} from "rows"
    )`;
  },
};
