import { err, ok } from 'neverthrow';
import { AST } from '../../ast';
import { Constant } from '../../builder/utilities';
import { Assert } from '../../utils';
import { WithMeta } from '../lex/types';
import {
  Assoc,
  Call,
  CompileOutput,
  EraQlAst,
  List,
  ParseCtx,
  ParseRes,
} from './ast';
import { IdentToken, LiteralToken, SymbolToken, TagToken } from '../lex/tokens';
import {
  FunctionIdentifier,
  isFunctionIdentifier,
} from '../../ast/func-identifier';
import _ from 'lodash';
import { Ty } from '../../ty';
import { SPECIAL_FORMS } from './special-forms';

/**
 * Compiles a single eraQL AST node (paired with its source metadata) into a
 * `CompileOutput`, dispatching on the node's `t` discriminant.
 *
 * Results and errors from the per-kind compilers are returned unchanged.
 */
export const compileAst = (
  [node, nodeMeta]: WithMeta<EraQlAst>,
  ctx: ParseCtx
): ParseRes<CompileOutput> => {
  switch (node.t) {
    // Leaf tokens that are compiled without the parse context.
    case 'literal':
      return compileLiteral([node, nodeMeta]);
    case 'symbol':
      return compileSym([node, nodeMeta]);
    // The remaining kinds receive the parse context.
    case 'ident':
      return compileIdent([node, nodeMeta], ctx);
    case 'list':
      return compileList([node, nodeMeta], ctx);
    case 'assoc':
      return compileAssoc([node, nodeMeta], ctx);
    case 'call':
      return compileCall([node, nodeMeta], ctx);
    default:
      return Assert.unreachable(node);
  }
};

/**
 * Compiles an assoc (key/value) literal into a `make-struct` expression.
 *
 * Each value is compiled with `expectExpr`. The first value that fails
 * aborts compilation and its error is returned. Because fields are assigned
 * in order, a repeated key overwrites the earlier entry.
 */
const compileAssoc = (
  [assoc, _assocMeta]: WithMeta<Assoc>,
  ctx: ParseCtx
): ParseRes<CompileOutput> => {
  const structFields: Record<string, AST.ExprIR> = {};

  for (const [entry] of assoc.kvs) {
    const valueRes = expectExpr(entry.val, ctx);
    if (valueRes.isErr()) return err(valueRes.error);

    const fieldName = entry.key[0];
    structFields[fieldName] = valueRes.value;
  }

  const makeStruct: AST._MakeStruct<AST.IRChildren> = {
    t: 'make-struct',
    fields: structFields,
  };

  return ok({ t: 'expr', expr: makeStruct });
};

/**
 * Compiles a list literal into a `make-array` expression, preserving
 * element order.
 *
 * Elements are compiled one at a time with `expectExpr`. Compilation stops
 * at the first element that fails, and that error is returned.
 */
const compileList = (
  [list, _listMeta]: WithMeta<List>,
  ctx: ParseCtx
): ParseRes<CompileOutput> => {
  const compiled: AST.ExprIR[] = [];

  for (const element of list.elems) {
    const elementRes = expectExpr(element, ctx);
    if (elementRes.isErr()) return err(elementRes.error);
    compiled.push(elementRes.value);
  }

  const makeArray: AST._MakeArray<AST.IRChildren> = {
    t: 'make-array',
    elements: compiled,
  };

  return ok({ t: 'expr', expr: makeArray });
};

/**
 * Compiles a call form: a head token applied to `args`.
 *
 * Dispatch on the head token:
 * - `tag`: delegated to `callTag`.
 * - `symbol` / `operator`: the name is first looked up in `SPECIAL_FORMS`.
 *   A special form is arity-checked (`'any'`, an exact number, or a list of
 *   allowed counts) and then invoked with the raw argument ASTs. Otherwise,
 *   if the name is a `FunctionIdentifier`, the call is compiled by `callFn`.
 *   Anything else is `cant-call`.
 * - any other node kind in head position: `cant-call`.
 *
 * `invalid-arity` and `cant-call` errors are reported at the head token's
 * location.
 */
const compileCall = (
  [ast, _meta]: WithMeta<Call>,
  ctx: ParseCtx
): ParseRes<CompileOutput> => {
  const { op, args } = ast;
  const [opToken, opMeta] = op;

  switch (opToken.t) {
    case 'tag':
      return callTag([opToken, opMeta], args, ctx);
    case 'operator':
    case 'symbol': {
      const sym = opToken.t === 'symbol' ? opToken.sym : opToken.op;

      // Only consult SPECIAL_FORMS' own keys. A bare index such as
      // SPECIAL_FORMS['toString'] would return an inherited Object.prototype
      // member. That value is truthy, and destructuring it below throws a
      // TypeError instead of producing a `cant-call` error.
      const spForm = Object.prototype.hasOwnProperty.call(SPECIAL_FORMS, sym)
        ? SPECIAL_FORMS[sym]
        : undefined;

      if (spForm) {
        const [arity, fn] = spForm;

        if (arity !== 'any') {
          const arityOk =
            typeof arity === 'number'
              ? arity === args.length
              : arity.includes(args.length);

          if (!arityOk) {
            return err([{ t: 'invalid-arity' }, opMeta]);
          }
        }

        return fn(args, ctx);
      }

      if (isFunctionIdentifier(sym)) {
        return callFn([sym, opMeta], args, ctx);
      }

      return err([{ t: 'cant-call' }, opMeta]);
    }
    case 'call':
    case 'literal':
    case 'ident':
    case 'assoc':
    case 'list':
      return err([{ t: 'cant-call' }, opMeta]);
    default:
      return Assert.unreachable(opToken);
  }
};

/**
 * Compiles a call to a named function into a `function-call` expression.
 *
 * Every argument must compile to an expression via `expectExpr`. The first
 * argument that fails aborts compilation and its error is returned.
 */
const callFn = (
  [fnName, _fnMeta]: WithMeta<FunctionIdentifier>,
  argsAst: WithMeta<EraQlAst>[],
  ctx: ParseCtx
): ParseRes<CompileOutput> => {
  const compiledArgs: AST.ExprIR[] = [];

  for (const argAst of argsAst) {
    const compiled = expectExpr(argAst, ctx);
    if (compiled.isErr()) return err(compiled.error);
    compiledArgs.push(compiled.value);
  }

  const call: AST._FunctionCall<AST.IRChildren> = {
    t: 'function-call',
    op: fnName,
    args: compiledArgs,
  };

  return ok({ t: 'expr', expr: call });
};

/**
 * Compiles a node and requires the result to be usable as an expression.
 *
 * - `expr` results are returned as-is.
 * - `ty` results become a constant holding the type's display string
 *   (`Ty.displayTy`).
 * - `rel` results are rejected with an `expected: 'expression'` error at the
 *   node's location.
 *
 * Compilation errors from `compileAst` are propagated unchanged.
 */
export const expectExpr = (
  [arg, argMeta]: WithMeta<EraQlAst>,
  ctx: ParseCtx
): ParseRes<AST.ExprIR> => {
  const compiled = compileAst([arg, argMeta], ctx);
  if (compiled.isErr()) {
    return err(compiled.error);
  }

  const output = compiled.value;
  switch (output.t) {
    case 'expr':
      return ok(output.expr);
    case 'ty':
      return ok(Constant(Ty.displayTy(output.ty)).ir());
    case 'rel':
      return err([{ t: 'expected', expected: 'expression' }, argMeta]);
  }
};

/**
 * Compiles a node and requires the result to be a type.
 *
 * Returns the compiled type on success. Expression or relation results are
 * rejected with an `expected: 'type'` error at the node's location.
 * Compilation errors from `compileAst` are propagated unchanged.
 */
export const expectTy = (
  [arg, argMeta]: WithMeta<EraQlAst>,
  ctx: ParseCtx
): ParseRes<Ty.ExtendedAttributeType> => {
  const compiled = compileAst([arg, argMeta], ctx);
  if (compiled.isErr()) {
    return err(compiled.error);
  }

  const output = compiled.value;
  if (output.t === 'ty') {
    return ok(output.ty);
  }

  // The only remaining outputs ('expr' and 'rel') are not types.
  return err([{ t: 'expected', expected: 'type' }, argMeta]);
};

/**
 * Applies a tag to exactly one operand.
 *
 * - Type operand: returns the type wrapped with `Ty.tag(ty, [tag])`.
 * - Expression operand: returns a `tag` function call whose arguments are the
 *   operand and the tag as a constant.
 * - Relation operand: `cant-tag` error.
 *
 * Any argument count other than one yields `invalid-arity`. Arity and
 * `cant-tag` errors are reported at the tag token's location.
 */
const callTag = (
  [{ tag }, tagMeta]: WithMeta<TagToken>,
  args: WithMeta<EraQlAst>[],
  ctx: ParseCtx
): ParseRes<CompileOutput> => {
  const [operand] = args;
  if (operand === undefined || args.length !== 1) {
    return err([{ t: 'invalid-arity' }, tagMeta]);
  }

  const operandRes = compileAst(operand, ctx);
  if (operandRes.isErr()) {
    return err(operandRes.error);
  }

  const compiled = operandRes.value;
  switch (compiled.t) {
    case 'rel':
      return err([{ t: 'cant-tag' }, tagMeta]);
    case 'ty':
      return ok({ t: 'ty', ty: Ty.tag(compiled.ty, [tag]) });
    case 'expr': {
      const tagged: AST._FunctionCall<AST.IRChildren> = {
        t: 'function-call',
        op: 'tag',
        args: [compiled.expr, Constant(tag).ir()],
      };
      return ok({ t: 'expr', expr: tagged });
    }
    default:
      return Assert.unreachable(compiled);
  }
};

/**
 * Resolves an identifier against `ctx.attributes` and compiles it to that
 * attribute's IR.
 *
 * If there is no attributes map, or the name is not in it, the result is an
 * `attribute-not-found` error carrying the name.
 */
const compileIdent = (
  [{ name }, meta]: WithMeta<IdentToken>,
  ctx: ParseCtx
): ParseRes<CompileOutput> => {
  // NOTE(review): this is a plain bracket lookup. If `attributes` is an
  // ordinary object, inherited keys (e.g. `toString`) would not be
  // `undefined` here. Confirm its type/prototype before relying on user input.
  const attribute = ctx.attributes?.[name];

  if (attribute !== undefined) {
    return ok({ t: 'expr', expr: attribute.ir() });
  }

  return err([{ t: 'attribute-not-found', name }, meta]);
};

/** Compiles a literal token into a constant expression built by `Constant`. */
const compileLiteral = ([
  literal,
  _meta,
]: WithMeta<LiteralToken>): ParseRes<CompileOutput> => {
  const constant = Constant(literal.val);
  return ok({ t: 'expr', expr: constant.ir() });
};

/**
 * Compiles a bare symbol that names a scalar type.
 *
 * A plain name (e.g. `int`) yields `Ty.ty(name)`. The same name with a
 * trailing `!` (e.g. `int!`) yields `Ty.nn(name)`. Any other symbol produces
 * an `unexpected-token` error at the symbol's location.
 */
const compileSym = ([
  { sym },
  meta,
]: WithMeta<SymbolToken>): ParseRes<CompileOutput> => {
  switch (sym) {
    // Non-null variants: strip the trailing '!' to recover the base name.
    case 'boolean!':
    case 'day!':
    case 'float!':
    case 'int!':
    case 'month!':
    case 'string!':
    case 'super!':
    case 'timestamp!':
    case 'year!': {
      const baseName = sym.slice(0, -1);
      return ok({ t: 'ty', ty: Ty.nn(baseName as any) });
    }
    case 'boolean':
    case 'day':
    case 'float':
    case 'int':
    case 'month':
    case 'string':
    case 'super':
    case 'timestamp':
    case 'year':
      return ok({ t: 'ty', ty: Ty.ty(sym) });
    default:
      return err([{ t: 'unexpected-token' }, meta]);
  }
};
