import { Result, err, ok } from 'neverthrow';
import {
  FunctionIdentifier,
  isFunctionIdentifier,
} from '../../ast/func-identifier';
import { Expression } from '../../builder';
import { Assert } from '../../utils';
import {
  ExprCtx,
  GroupingToken,
  NongroupingToken,
  OperatorToken,
  ParseError,
  ParseResult,
  SymbolToken,
  Token,
  TokenMeta,
  WithMeta,
} from '../ast';
import { OPERATOR_PRECENDENCE } from './reference-tables';
import {
  ArityMismatch,
  FunctionDoesNotExist,
  NoSuchAttribute,
} from '../../type-checker/type-check-error';
import _ from 'lodash';
import { FUNCTION_ARITY } from '../../type-checker/expr/func-arity';

/**
 * Parses an infix token stream into a single expression.
 *
 * The tokens are first reordered into reverse-Polish notation by `toRpn`.
 * The RPN stream is then folded into an expression by `evalRpn`, which
 * resolves identifiers against `ctx`.
 *
 * @param tokenStream infix tokens, each paired with its source metadata
 * @param ctx expression context used to resolve identifiers
 * @returns the parsed expression, or the first parse/type-check error
 */
export const parseExpr = (
  tokenStream: Iterable<WithMeta<Token>>,
  ctx: ExprCtx
): ParseResult => evalRpn(toRpn(tokenStream), ctx);

/**
 * Reorders an infix token stream into reverse-Polish notation (RPN) using the
 * shunting-yard algorithm.
 *
 * How each token kind is handled:
 *  - `ident`, `literal`, `raw-expression` and `white-space` tokens are yielded
 *    as soon as they are read.
 *  - Operators, `(` markers and function-name `symbol`s wait on an operator
 *    stack. They are yielded when precedence, a closing `)`, a `,`, or the end
 *    of input forces them out.
 *
 * The generator is lazy. On the first error it yields a single `err(...)`
 * item, paired with the relevant token's meta, and then returns. Errors
 * produced here:
 *  - `unclosed-deliminator`:
 *    - with `delim: ')'` for a `)` that has no matching `(`;
 *    - with the token's op for a `(` still on the stack at end of input.
 *  - `unexpected-token`:
 *    - for a `,` outside any parentheses;
 *    - for an `unknown` token;
 *    - when an operator arrives while a function `symbol` (not yet followed
 *      by `(`) is on top of the stack. This error points at the symbol.
 *
 * Square brackets are not supported yet; they cause a thrown `Error`.
 *
 * @param tokenStream infix tokens, each paired with its source metadata
 * @returns a generator of RPN tokens, possibly ending in one parse error
 */
export function* toRpn(
  tokenStream: Iterable<WithMeta<Token>>
): Generator<WithMeta<Result<NongroupingToken, ParseError>>> {
  // https://en.wikipedia.org/wiki/Shunting_yard_algorithm
  // Pending operators, `(` markers, and function symbols awaiting their `)`.
  const opStack: WithMeta<GroupingToken | OperatorToken | SymbolToken>[] = [];

  stream: for (const [token, meta] of tokenStream) {
    switch (token.t) {
      // Operands (and whitespace) go straight to the output.
      case 'ident':
      case 'literal':
      case 'raw-expression':
      case 'white-space': {
        yield [ok(token), meta];
        continue stream;
      }
      case 'grouping': {
        switch (token.op) {
          case '(': {
            // Marker that bounds the operators belonging to this group.
            opStack.push([token, meta]);
            break;
          }
          case ')': {
            // Flush operators until the matching `(`.
            while (opStack.length > 0) {
              const [topOp, topOpMeta] = opStack.pop()!;
              if (topOp.t === 'grouping') {
                switch (topOp.op) {
                  case '(': {
                    // If a function symbol sits directly below the `(`, this
                    // `)` closes its argument list. The symbol is emitted
                    // after its arguments, as RPN requires. Anything else is
                    // put back untouched.
                    const nextToken = opStack.pop();
                    if (nextToken) {
                      const [nextOp, nextOpMeta] = nextToken;
                      if (nextOp.t === 'symbol') {
                        yield [ok(nextOp), nextOpMeta];
                      } else {
                        opStack.push([nextOp, nextOpMeta]);
                      }
                    }
                    continue stream;
                  }
                  case '[':
                  case ']':
                    throw new Error('TODO BRACKET SUPPORT');
                  case ')':
                  case ',':
                    // Only `(` is ever pushed (see the `(` case above).
                    throw new Error('UNREACHABLE');
                  default:
                    return Assert.unreachable(topOp);
                }
              } else {
                yield [ok(topOp), topOpMeta];
              }
            }

            // Stack drained without finding `(`: this `)` is unmatched.
            yield [
              err({ t: 'unclosed-deliminator' as const, delim: ')' }),
              meta,
            ];
            return;
          }
          case ',': {
            // Argument separator: flush operators back to the enclosing `(`,
            // which stays on the stack for the following arguments.
            while (opStack.length > 0) {
              const [nextOp, nextOpMeta] = opStack.pop()!;
              if (nextOp.t === 'grouping') {
                switch (nextOp.op) {
                  case '(':
                    opStack.push([nextOp, nextOpMeta]);
                    continue stream;
                  case ')':
                  case ',':
                    throw new Error(
                      'UNREACHABLE WE DONT PUSH THESE TO THE STACK'
                    );
                  case '[':
                  case ']':
                    throw new Error('TODO BRACKET SUPPORT');
                }
              } else {
                yield [ok(nextOp), nextOpMeta];
              }
            }

            // No enclosing `(`: the comma is outside any group.
            yield [err({ t: 'unexpected-token', token: ',' }), meta];
            return;
          }
          case '[':
          case ']': {
            throw new Error('TODO BRACKET SUPPORT');
          }
        }
        continue stream;
      }
      case 'operator': {
        // Pop operators whose precedence is >= this one. Popping on equal
        // precedence makes every operator left-associative. Stop at a `(`.
        const precedence = OPERATOR_PRECENDENCE[token.op];
        checkTopOp: while (opStack.length > 0) {
          const [topOp, topOpMeta] = opStack.pop()!;
          // A symbol on top was never followed by `(`, so it is not a valid
          // operand for this operator.
          if (topOp.t === 'symbol') {
            yield [err({ t: 'unexpected-token', token: topOp.sym }), topOpMeta];
            return;
          }

          if (topOp.t === 'grouping') {
            if (topOp.op === '(') {
              opStack.push([topOp, topOpMeta]);
              break checkTopOp;
            }
            throw new Error('BRACKET SUPPORT');
          }

          const topOpPrecedence = OPERATOR_PRECENDENCE[topOp.op];

          if (topOpPrecedence >= precedence) {
            yield [ok(topOp), topOpMeta];
            continue checkTopOp;
          } else {
            opStack.push([topOp, topOpMeta]);
            break checkTopOp;
          }
        }

        opStack.push([token, meta]);
        continue stream;
      }
      case 'symbol': {
        // Function names wait on the stack until their closing `)`.
        opStack.push([token, meta]);
        continue stream;
      }
      case 'unknown': {
        yield [err({ t: 'unexpected-token', token: token.found }), meta];
        return;
      }
      default:
        return Assert.unreachable(token);
    }
  }

  // End of input: flush remaining operators. A leftover `(` is unclosed.
  // NOTE(review): a function symbol never followed by `(` also reaches this
  // point and is emitted as a call on whatever operands precede it. Confirm
  // that this is intended.
  while (opStack.length > 0) {
    const [token, meta] = opStack.pop()!;
    if (token.t === 'grouping') {
      yield [err({ t: 'unclosed-deliminator', delim: token.op }), meta];
      return;
    }
    yield [ok(token as OperatorToken | SymbolToken), meta];
  }
}

/**
 * Folds a reverse-Polish token stream (such as the output of `toRpn`) into a
 * single `Expression`.
 *
 * Evaluation rules:
 *  - Literals and raw expressions are wrapped and pushed onto an operand stack.
 *  - Identifiers are looked up in `ctx.attributes` and pushed.
 *  - Each operator or function symbol pops `FUNCTION_ARITY[op]` operands,
 *    keeping them in source order, and pushes the compiled call.
 *  - Whitespace is skipped.
 *
 * The first problem found is returned, paired with the metadata of the token
 * that caused it:
 *  - any error already present in the input stream;
 *  - a `NoSuchAttribute` type-check error for an unknown identifier;
 *  - a `FunctionDoesNotExist` type-check error for an unknown operator/symbol;
 *  - an `ArityMismatch` type-check error when too few operands are available;
 *  - any error from compiling the call;
 *  - `empty-expression` (at range `[0, 0]`) when nothing was produced.
 *
 * @param tokens RPN tokens, each paired with its source metadata
 * @param ctx expression context supplying the available attributes
 */
export const evalRpn = (
  tokens: Iterable<WithMeta<Result<NongroupingToken, ParseError>>>,
  ctx: ExprCtx
): ParseResult => {
  const operands: WithMeta<Expression>[] = [];

  for (const [entry, entryMeta] of tokens) {
    if (entry.isErr()) {
      return err([entry.error, entryMeta] as const);
    }

    const tok = entry.value;

    if (tok.t === 'white-space') {
      continue;
    }

    if (tok.t === 'literal' || tok.t === 'raw-expression') {
      operands.push([Expression.wrap(tok.val), entryMeta]);
    } else if (tok.t === 'ident') {
      const attribute = ctx.attributes[tok.name];
      if (!attribute) {
        return err([
          {
            t: 'type-check-error',
            trace: new NoSuchAttribute({
              attributeName: tok.name,
              relation: {
                attributes: _.mapValues(ctx.attributes, (a) => a.ty),
              },
            }).toStackTrace({}),
          },
          entryMeta,
        ]);
      }
      operands.push([attribute, entryMeta]);
    } else if (tok.t === 'symbol' || tok.t === 'operator') {
      const fnName: string = tok.t === 'operator' ? tok.op : tok.sym;
      if (!isFunctionIdentifier(fnName)) {
        return err([
          {
            t: 'type-check-error',
            trace: new FunctionDoesNotExist({ name: fnName }).toStackTrace({}),
          },
          entryMeta,
        ]);
      }

      const arity = FUNCTION_ARITY[fnName];
      const callArgs: WithMeta<Expression>[] = [];
      while (callArgs.length < arity) {
        const operand = operands.pop();
        if (operand === undefined) {
          return err([
            {
              t: 'type-check-error',
              trace: new ArityMismatch({
                name: fnName,
                expected: arity,
                attempted: callArgs.length,
              }).toStackTrace({}),
            },
            entryMeta,
          ]);
        }
        // Operands come off the stack last-first; unshifting restores the
        // order in which they appeared in the source.
        callArgs.unshift(operand);
      }

      const compiled = compileOp([fnName, entryMeta], callArgs);
      if (compiled.isErr()) {
        return err(compiled.error);
      }
      operands.push(compiled.value);
    }
  }

  // NOTE(review): user input can trigger this assertion, not only an internal
  // bug. For example, `1 2`, or a call given more arguments than its arity,
  // leaves extra operands here, and this line throws instead of returning a
  // ParseError. Consider reporting it as a parse error.
  Assert.assert(operands.length <= 1, 'Arity was calculated correctly');

  const result = operands.pop();
  if (!result) {
    return err([{ t: 'empty-expression' }, { range: [0, 0] }]);
  }
  return ok(result[0]);
};

/**
 * Builds a function-call expression for `op` applied to `args`.
 *
 * The returned metadata spans the operator and all of its arguments. On
 * success the compiled expression carries that span. On failure the
 * type-check trace is wrapped as a `type-check-error` with the same span.
 *
 * @param op the function identifier paired with its token metadata
 * @param args the already-compiled arguments, in source order
 */
const compileOp = (
  [op, opMeta]: WithMeta<FunctionIdentifier>,
  args: WithMeta<Expression>[]
): Result<WithMeta<Expression>, WithMeta<ParseError>> => {
  const spanMeta = combineMeta(opMeta, ...args.map(([, argMeta]) => argMeta));
  const built = Expression.tryfromAst({
    t: 'function-call',
    op,
    args: args.map(([argExpr]) => argExpr.ast),
  });

  if (built.isErr()) {
    const failure: WithMeta<ParseError> = [
      { trace: built.error, t: 'type-check-error' as const },
      spanMeta,
    ];
    return err(failure);
  }

  const success: WithMeta<Expression> = [built.value, spanMeta];
  return ok(success);
};

/**
 * Merges several token metadata records into one whose range covers all of
 * them. The result runs from the smallest start to the largest end.
 *
 * A single record is returned as-is. At least one record is required; with
 * none, `reduce` throws a `TypeError`.
 */
const combineMeta = (...tokenMeta: TokenMeta[]): TokenMeta =>
  tokenMeta.reduce(
    ({ range: [spanStart, spanEnd] }, { range: [start, end] }): TokenMeta => ({
      range: [Math.min(spanStart, start), Math.max(spanEnd, end)],
    })
  );
