import { err, ok } from 'neverthrow';
import { AST } from '../../../ast';
import { IRChildren } from '../../../ast/base';
import { Assert } from '../../../utils';
import { expectExpr, expectTy } from '../compile';
import { SpecialForm } from './types';
import { WithMeta } from '../../lex/types';
import { EraQlAst, ParseCtx, ParseRes } from '../ast';
import { Ty } from '../../../ty';

/**
 * Builds a parser for a key/value struct (`assoc`) whose values are all
 * expressions, e.g. `{ when: <expr>, then: <expr> }`.
 *
 * The returned parser:
 * - fails with `expected: 'assoc'` if the argument is not a struct;
 * - in `strict` mode, fails with `unexpected-key` for any key not in `keys`
 *   (non-strict mode skips such keys without parsing their values);
 * - fails with `duplicate-key` if one of `keys` appears more than once;
 * - fails with `missing-key` if one of `keys` is absent, so every property of
 *   the returned record is actually populated;
 * - propagates the first error produced while parsing a value expression.
 *
 * @param keys the keys the struct must contain.
 * @param opts.strict whether keys outside `keys` are rejected.
 */
export const makeExpectStruct =
  <Key extends string>(keys: Key[], opts: { strict: boolean }) =>
  (
    [arg, argMeta]: WithMeta<EraQlAst>,
    ctx: ParseCtx
  ): ParseRes<Record<Key, AST.ExprIR>> => {
    if (arg.t !== 'assoc') {
      return err([{ t: 'expected', expected: 'assoc' }, argMeta]);
    }

    const known: ReadonlySet<string> = new Set(keys);
    const seen = new Set<string>();
    const res: Record<string, AST.ExprIR> = {};

    for (const [{ key, val }] of arg.kvs) {
      const [name, keyMeta] = key;

      if (!known.has(name)) {
        if (opts.strict) {
          return err([{ t: 'unexpected-key', key: name }, keyMeta]);
        }
        // Non-strict: extra keys are tolerated and left unparsed.
        continue;
      }

      // A repeated key is ambiguous; report it instead of silently keeping
      // one of the occurrences.
      if (seen.has(name)) {
        return err([{ t: 'duplicate-key', key: name }, keyMeta]);
      }
      seen.add(name);

      const valRes = expectExpr(val, ctx);

      if (valRes.isErr()) {
        return err(valRes.error);
      }

      res[name] = valRes.value;
    }

    // Without this check the result would be typed as Record<Key, ExprIR>
    // while some keys are actually undefined.
    for (const k of keys) {
      if (!seen.has(k)) {
        return err([{ t: 'missing-key', key: k }, argMeta]);
      }
    }

    return ok(res);
  };

/**
 * `match` special form: `match(target, { a: <expr>, b: <expr>, ... })`.
 *
 * Lowers to a `case` expression with one arm per key, in source order. Each
 * arm's condition is an infix `eq` call comparing `target` with the key as a
 * `Ty.nn('string')` scalar.
 *
 * Fails with `expected` if the arms are not a key/value struct, with
 * `duplicate-key` if the same key appears twice, and propagates any error from
 * parsing `target` or an arm's value.
 */
const match: SpecialForm = ([target, arms], ctx) => {
  Assert.assert(
    target !== undefined && arms !== undefined,
    'arity already checked'
  );
  const targetRes = expectExpr(target, ctx);

  if (targetRes.isErr()) {
    return err(targetRes.error);
  }

  const [armsAst, armsMeta] = arms;

  if (armsAst.t !== 'assoc') {
    return err([{ t: 'expected', expected: 'key value pairs' }, armsMeta]);
  }

  // Collect arms in an array rather than a plain object. This keeps source
  // order even for integer-like keys, and keeps keys such as `__proto__` from
  // being swallowed by the object prototype setter.
  const seen = new Set<string>();
  const parsedArms: [string, AST.ExprIR][] = [];

  for (const [{ key, val }] of armsAst.kvs) {
    const [name, keyMeta] = key;

    if (seen.has(name)) {
      return err([{ t: 'duplicate-key', key: name }, keyMeta]);
    }
    seen.add(name);

    const valRes = expectExpr(val, ctx);
    if (valRes.isErr()) {
      return err(valRes.error);
    }

    parsedArms.push([name, valRes.value]);
  }

  const expr: AST._Case<AST.IRChildren> = {
    t: 'case',
    cases: parsedArms.map(([key, then]) => {
      const when: AST._FunctionCall<AST.IRChildren> = {
        t: 'function-call',
        op: 'eq',
        args: [targetRes.value, { t: 'scalar', ty: Ty.nn('string'), val: key }],
        display: 'infix',
      };

      return { when, then };
    }),
  };

  return ok({ item: { t: 'expr', expr }, bindings: {} });
};

// Strict struct parsers used by the `case` special form:
//   - each branch is a `{ when: <expr>, then: <expr> }` struct;
//   - the optional trailing clause is an `{ else: <expr> }` struct.
// Strict mode means keys outside these lists are rejected.
const expectElseStruct = makeExpectStruct(['else'], { strict: true });

const expectWhenThenStruct = makeExpectStruct(
  ['when', 'then'],
  { strict: true }
);

/**
 * `case` special form.
 *
 * The first argument must be a list whose elements are `{ when, then }`
 * structs. The optional second argument is an `{ else }` struct. Branches are
 * parsed in list order, then the else clause, and the first error encountered
 * is returned.
 */
const case_: SpecialForm = ([branches, elseClause], ctx) => {
  Assert.assert(branches !== undefined);
  const [branchListAst, branchListMeta] = branches;

  if (branchListAst.t !== 'list') {
    return err([{ t: 'expected', expected: 'list' }, branchListMeta]);
  }

  const whenThens: { when: AST.ExprIR; then: AST.ExprIR }[] = [];

  for (const branchAst of branchListAst.elems) {
    const parsedBranch = expectWhenThenStruct(branchAst, ctx);
    if (parsedBranch.isErr()) return err(parsedBranch.error);
    whenThens.push(parsedBranch.value);
  }

  let elseExpr: AST.ExprIR | undefined = undefined;

  if (elseClause !== undefined) {
    const parsedElse = expectElseStruct(elseClause, ctx);
    if (parsedElse.isErr()) return err(parsedElse.error);
    elseExpr = parsedElse.value.else;
  }

  const expr: AST._Case<AST.IRChildren> = {
    t: 'case',
    cases: whenThens,
    else: elseExpr,
  };

  return ok({ item: { t: 'expr', expr }, bindings: {} });
};

/**
 * `cast(target, type)` special form.
 *
 * Parses `target` as an expression and then `tyRepr` as a type, in that order,
 * returning the first error. On success it builds a `cast` node whose
 * `targetTy` is taken from the parsed type item.
 */
const cast: SpecialForm = ([target, tyRepr], ctx) => {
  Assert.assert(
    target !== undefined && tyRepr !== undefined,
    'already checked arity'
  );

  const parsedTarget = expectExpr(target, ctx);
  if (parsedTarget.isErr()) return err(parsedTarget.error);

  const parsedTy = expectTy(tyRepr, ctx);
  if (parsedTy.isErr()) return err(parsedTy.error);

  const { ty: targetTy } = parsedTy.value.item.ty;

  const expr: AST._Cast<IRChildren> = {
    t: 'cast',
    targetTy,
    expr: parsedTarget.value,
  };

  return ok({ item: { t: 'expr', expr }, bindings: {} });
};

/**
 * `if` special form: `if(cond, { then: <expr>, else: <expr> })`.
 *
 * Lowers to a `case` expression with a single `when cond then <then>` arm and
 * an optional `else`. The second argument must be a key/value struct whose
 * only keys are `then` (required, at most once) and `else` (optional, at most
 * once). Any other key is rejected with `unexpected-key`, so a misspelled
 * `else` is reported instead of silently dropped.
 */
const if_: SpecialForm = ([cond, branches], ctx) => {
  Assert.assert(
    cond !== undefined && branches !== undefined,
    'arity already checked'
  );

  const condRes = expectExpr(cond, ctx);

  if (condRes.isErr()) {
    return err(condRes.error);
  }

  const [bs, bsMeta] = branches;

  if (bs.t !== 'assoc') {
    // The condition is the first argument; the struct only holds then/else.
    return err([
      { t: 'expected', expected: '`then:` / `else:` struct' },
      bsMeta,
    ]);
  }

  const unexpected = bs.kvs.find(
    ([{ key }]) => key[0] !== 'then' && key[0] !== 'else'
  );

  if (unexpected !== undefined) {
    const [{ key }] = unexpected;
    return err([{ t: 'unexpected-key', key: key[0] }, key[1]]);
  }

  const [then, duplicateThen] = bs.kvs.filter((kv) => kv[0].key[0] === 'then');

  if (duplicateThen !== undefined) {
    return err([{ t: 'duplicate-key', key: 'then' }, duplicateThen[0].key[1]]);
  }

  if (then === undefined) {
    return err([{ t: 'missing-key', key: 'then' }, bsMeta]);
  }

  const [else_, duplicateElse] = bs.kvs.filter((kv) => kv[0].key[0] === 'else');

  if (duplicateElse !== undefined) {
    return err([{ t: 'duplicate-key', key: 'else' }, duplicateElse[0].key[1]]);
  }

  const thenRes = expectExpr(then[0].val, ctx);

  if (thenRes.isErr()) {
    return err(thenRes.error);
  }

  let elseClause: AST.ExprIR | undefined;

  if (else_ !== undefined) {
    const elseRes = expectExpr(else_[0].val, ctx);
    if (elseRes.isErr()) {
      return err(elseRes.error);
    }

    elseClause = elseRes.value;
  }

  const expr: AST._Case<IRChildren> = {
    t: 'case',
    cases: [{ when: condRes.value, then: thenRes.value }],
    else: elseClause,
  };

  return ok({ item: { t: 'expr', expr }, bindings: {} });
};

/**
 * Expression-level special forms, keyed by form name.
 *
 * Each entry is `[arity, fn]`. `arity` is an exact argument count, a list of
 * allowed counts, or `'any'`. The forms assert that arity was already checked,
 * so presumably the caller validates it before invoking `fn`.
 */
export const EXPR_SPECIAL_FORMS: Record<
  string,
  [arity: number | number[] | 'any', fn: SpecialForm]
> = {
  match: [2, match],
  // `case` takes the branch list plus an optional `{ else }` struct.
  case: [[1, 2], case_],
  cast: [2, cast],
  if: [2, if_],
};
