import { AST } from '../../ast';
import { Assert } from '@cotera/utilities';
import { checkExpr } from './check-expr';
import { TyStackTrace } from '../ty-stack-trace';
import { Expression } from '../../builder/expression';

/** The kinds of dependency `exprPurityAnalysis` reports for an expression. */
type AnalysisCategory = 'now' | 'rand' | 'other';

/** One flag per category; `true` means the expression depends on it. */
type Reqs = Record<AnalysisCategory, boolean>;

/**
 * Memoized `exprPurityAnalysis` results, keyed by AST node identity.
 * A WeakMap is used so cached entries do not keep AST nodes reachable.
 */
const CONST_REQS_CACHE: WeakMap<AST.Expr, Reqs> = new WeakMap();

/**
 * The "no dependencies" result. This single object is returned (and cached)
 * for every leaf node and is the starting point of `mergeReqs`. It is frozen
 * so a caller mutating a returned result cannot corrupt every other analysis.
 */
const PURE: Reqs = Object.freeze({
  now: false,
  rand: false,
  other: false,
});

/**
 * Decides whether `expr` is interpretable under the following rules:
 * - it must type-check (`checkExpr` must not return a `TyStackTrace`);
 * - it must have no attribute requirements, must not be aggregated or
 *   windowed, and must carry no `other` impurity;
 * - a dependency on `now` or `rand` is rejected unless `opts.allow` permits it;
 * - a var without a default is rejected unless
 *   `opts.allow.undefaultedVars` is set.
 *
 * @param expr - The expression to inspect.
 * @param opts - Optional relaxations of the `now`, `rand` and
 *   undefaulted-var restrictions.
 * @returns `true` only when every rule above is satisfied.
 */
export const isExprInterpretable = (
  expr: AST.Expr,
  opts?: {
    allow?: { now?: boolean; rand?: boolean; undefaultedVars?: boolean };
  }
): boolean => {
  const typeCheck = checkExpr(expr);
  if (typeCheck instanceof TyStackTrace) return false;

  const allow = opts?.allow;
  const purity = exprPurityAnalysis(expr);

  // Time- and randomness-dependent expressions pass only when opted into.
  if ((purity.now && !allow?.now) || (purity.rand && !allow?.rand)) {
    return false;
  }

  const readsAttributes = Object.values(typeCheck.attrReqs).some(
    (attrs) => Object.keys(attrs).length > 0
  );
  const blockedByCheck =
    readsAttributes || typeCheck.aggregated || typeCheck.windowed;

  if (blockedByCheck || purity.other) return false;

  const hasUndefaultedVar = Object.values(typeCheck.vars).some(({ exprs }) =>
    Object.values(exprs).some((varExpr) => !varExpr.defaulted)
  );

  return !hasUndefaultedVar || Boolean(allow?.undefaultedVars);
};

/**
 * Reports which categories of dependency occur anywhere in `expr`:
 * - `now` / `rand` / `other`: set when some `function-call` in the tree has
 *   that flag in its `FN_REQS` entry;
 * - `other` is additionally set by any cast that is not a simple
 *   primitive-to-primitive cast (see the `cast` case).
 *
 * Children's flags are OR-ed together via `mergeReqs`. Attribute references
 * and scalars contribute nothing here. Whether the expression reads
 * attributes, aggregates, or is windowed is not reported by this function.
 *
 * Results are memoized per AST node in `CONST_REQS_CACHE`. Each result is
 * frozen before it is cached because the same object is handed to every
 * caller that analyses that node (and `PURE` is shared by all leaves). A
 * caller mutating a result would otherwise corrupt later analyses.
 *
 * @param expr - The expression tree to analyse.
 * @returns One flag per category; `true` means the expression depends on it.
 */
export const exprPurityAnalysis = (expr: AST.Expr): Reqs => {
  const existing = CONST_REQS_CACHE.get(expr);

  if (existing) {
    return existing;
  }

  const { t } = expr;

  let reqs: Reqs;

  switch (t) {
    case 'attr':
    case 'scalar': {
      reqs = PURE;
      break;
    }
    case 'expr-var': {
      // A var contributes its default's requirements, if it has a default.
      reqs = expr.default === null ? PURE : exprPurityAnalysis(expr.default);
      break;
    }
    case 'cast': {
      const { expr: child, targetTy, ...rest } = expr;

      const { ty: childTy } = Expression.fromAst(child);

      // Only primitive -> primitive casts (no `super` on either side, and no
      // `timestamp` target) are considered simple; anything else is `other`.
      const isSimpleCast =
        childTy.ty.k === 'primitive' &&
        childTy.ty.t !== 'super' &&
        targetTy.k === 'primitive' &&
        targetTy.t !== 'super' &&
        targetTy.t !== 'timestamp';

      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = mergeReqs(exprPurityAnalysis(child), { other: !isSimpleCast });
      break;
    }
    case 'case': {
      const { cases, else: else_, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);

      reqs = mergeReqs(
        ...cases.flatMap(({ when, then }) => [
          exprPurityAnalysis(when),
          exprPurityAnalysis(then),
        ]),
        ...(else_ ? [exprPurityAnalysis(else_)] : [])
      );
      break;
    }
    case 'match': {
      const { cases, expr: child, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);

      reqs = mergeReqs(
        exprPurityAnalysis(child),
        ...Object.values(cases).map((x) => exprPurityAnalysis(x))
      );
      break;
    }
    case 'window': {
      // Only the window's arguments and its `orderBy` / `partitionBy`
      // expressions are analysed here.
      const { args, over, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = mergeReqs(
        ...args.map((arg) => exprPurityAnalysis(arg)),
        ...over.orderBy.map((orderBy) => exprPurityAnalysis(orderBy.expr)),
        ...over.partitionBy.map((partitionBy) =>
          exprPurityAnalysis(partitionBy)
        )
      );
      break;
    }
    case 'make-struct': {
      const { fields, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = mergeReqs(
        ...Object.values(fields).map((field) => exprPurityAnalysis(field))
      );
      break;
    }
    case 'make-array': {
      const { elements, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = mergeReqs(...elements.map((elem) => exprPurityAnalysis(elem)));
      break;
    }
    case 'function-call': {
      // Arguments' requirements plus the function's own `FN_REQS` entry;
      // functions without an entry add nothing.
      const { op, args, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = mergeReqs(
        ...args.map((arg) => exprPurityAnalysis(arg)),
        FN_REQS[op] ?? {}
      );
      break;
    }
    case 'get-field': {
      const { expr: child, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = exprPurityAnalysis(child);
      break;
    }
    case 'invariants': {
      const { expr: child, invariants, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = mergeReqs(
        exprPurityAnalysis(child),
        ...Object.values(invariants).map((invariant) =>
          exprPurityAnalysis(invariant)
        )
      );
      break;
    }
    case 'macro-expr-case': {
      const { cases, else: else_, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);
      reqs = mergeReqs(
        exprPurityAnalysis(else_),
        ...cases.flatMap(({ when, then }) => [
          exprPurityAnalysis(when),
          exprPurityAnalysis(then),
        ])
      );
      break;
    }
    case 'macro-apply-vars-to-expr': {
      // Only `sources.from` and the bound var expressions are analysed.
      const { vars, sources, ...rest } = expr;
      Assert.matchesType<NoExprs<typeof rest>>(rest);

      reqs = mergeReqs(
        exprPurityAnalysis(sources.from),
        ...Object.values(vars.exprs).map((x) => exprPurityAnalysis(x))
      );

      break;
    }
    default:
      return Assert.unreachable(t);
  }

  // Freeze before caching so the shared result cannot be mutated by callers.
  // Freezing an already-frozen object (e.g. a child's cached result) is a no-op.
  const frozen: Reqs = Object.freeze(reqs);
  CONST_REQS_CACHE.set(expr, frozen);
  return frozen;
};

/**
 * Combines any number of (possibly partial) requirement sets by OR-ing each
 * category. A category absent from an input counts as `false`. With no
 * inputs, the shared `PURE` value itself is returned.
 */
const mergeReqs = (...reqs: Partial<Reqs>[]): Reqs => {
  let merged: Reqs = PURE;
  for (const next of reqs) {
    merged = {
      now: merged.now || (next.now ?? false),
      rand: merged.rand || (next.rand ?? false),
      other: merged.other || (next.other ?? false),
    };
  }
  return merged;
};

/**
 * Requirements a `function-call` adds on top of those of its arguments.
 * Functions with no entry here add nothing.
 */
const FN_REQS: Partial<Record<AST.FunctionIdentifier, Partial<Reqs>>> = {
  // Never interpretable.
  impure: { other: true },
  // Flagged `other` conservatively; these may well be pure (TODO: revisit).
  round: { other: true },
  nan: { other: true },
  // Nondeterministic.
  gen_random_uuid: { rand: true },
  random: { rand: true },
  // Depends on the current time.
  now: { now: true },
};

/**
 * Compile-time exhaustiveness guard used with `Assert.matchesType` in
 * `exprPurityAnalysis`. Every property of `T` holding an `AST.Expr` (directly,
 * in a union, inside an array, or nested in an object) is mapped to `never`.
 * A case that forgets to destructure a child expression out of `rest` should
 * then fail to type-check. (Presumes `Assert.matchesType<X>(v)` requires `v`
 * to be assignable to `X` — confirm against `@cotera/utilities`.)
 */
type NoExprs<T extends object> = {
  [Key in keyof T]: NoExprsValue<T[Key]>;
};

/**
 * Per-property helper for `NoExprs`. `V` is a naked type parameter, so this
 * conditional distributes over unions. A field typed `AST.Expr | null`, or an
 * optional `AST.Expr`, therefore maps to `null` / `undefined` and is rejected.
 * A non-distributive `T[Key] extends AST.Expr` check lets such fields through
 * unchanged.
 */
type NoExprsValue<V> = V extends AST.Expr
  ? never
  : V extends object
  ? NoExprs<V>
  : V;
