import { Interpreter } from '../interpreter';
import { AST } from '../ast';
import { z } from 'zod';
import { Vars } from './base';
import { expandExpr } from './expand-expr';
import _ from 'lodash';
import { Assert } from '@cotera/utilities';
import deepEquals from 'fast-deep-equal';
import { TC, TyStackTrace } from '../type-checker';
import { buildCacheEntryFn } from './expansion-cache';

/**
 * Cache-entry factory used by `expandRel` to memoize relation expansion.
 *
 * `openScopes` runs the type checker on the relation and returns the set of
 * keys of the checker result's `vars`. It asserts that type checking succeeded
 * (i.e. `TC.checkRel` did not return a `TyStackTrace`).
 * NOTE(review): presumably the cache uses these scope names to decide which
 * parts of `vars` a cached result depends on — confirm against
 * `buildCacheEntryFn`.
 */
const cacheFor = buildCacheEntryFn<AST.Rel, AST.RelIR>({
  openScopes(rel) {
    const checked = TC.checkRel(rel);
    Assert.assert(!(checked instanceof TyStackTrace));
    const scopeNames = Object.keys(checked.vars);
    return new Set(scopeNames);
  },
});

/**
 * Expands a relation AST into its IR form.
 *
 * Relation variables are resolved from `vars` (falling back to their default),
 * macro case expressions are evaluated to pick a single branch, variable
 * bindings from `macro-apply-vars-to-rel` are layered onto `vars`, and every
 * embedded expression is expanded with `expandExpr`.
 *
 * Results are memoized through `cacheFor`. When the expanded tree is
 * deep-equal to the input, the original `rel` object is returned so unchanged
 * subtrees keep their object identity.
 *
 * @param rel - The relation to expand.
 * @param vars - Variable bindings keyed by scope; `vars[scope].rels[name]`
 *   supplies the replacement for a `rel-var` node.
 * @returns The expanded relation IR.
 * @throws Error when a `rel-var` has no binding in `vars` and no default.
 * @throws If a macro case condition does not evaluate to a boolean or null
 *   (the zod parse fails).
 */
export const expandRel = (rel: AST.Rel, vars: Vars): AST.RelIR => {
  const { t } = rel;

  const cacheEntry = cacheFor(rel, vars);

  if (cacheEntry.t === 'existing') {
    return cacheEntry.val;
  }

  let ir: AST.RelIR;

  switch (t) {
    // Leaf relations: nothing inside them needs expanding.
    case 'file':
    case 'table':
    case 'values':
    case 'information-schema':
      ir = rel;
      break;
    case 'generate-series':
      ir = {
        ...rel,
        start: expandExpr(rel.start, vars),
        stop: expandExpr(rel.stop, vars),
      };
      break;
    case 'rel-var': {
      const variable = vars[rel.scope]?.rels?.[rel.name];
      if (variable === undefined) {
        if (rel.default === null) {
          throw new Error(
            `Error expanding found relation "${rel.scope}"."${rel.name}" that was unreplaced`
          );
        } else {
          ir = expandRel(rel.default, vars);
        }
      } else {
        // The bound relation may itself contain variables, so expand it too.
        ir = expandRel(variable, vars);
      }
      break;
    }
    case 'macro-rel-case': {
      // Choose the branch first, then expand only that branch. Expanding
      // `else` up front would do wasted work, and would throw when the unused
      // `else` contains an unbound rel-var even though an earlier case matched.
      const condSchema = z.boolean().nullable();
      let chosen: AST.Rel = rel.else;

      for (const { when, then } of rel.cases) {
        const cond = condSchema.parse(
          Interpreter.evalExprIR(expandExpr(when, vars))
        );

        // A `null` condition is treated like `false`.
        if (cond === true) {
          chosen = then;
          break;
        }
      }

      ir = expandRel(chosen, vars);
      break;
    }
    case 'macro-apply-vars-to-rel': {
      // Bindings for `rel.scope` replace any outer bindings for that scope.
      const combinedVars = {
        ...vars,
        [rel.scope]: rel.vars,
      };
      ir = expandRel(rel.sources.from, combinedVars);
      break;
    }
    case 'select':
      ir = {
        ...rel,
        condition:
          rel.condition !== null ? expandExpr(rel.condition, vars) : null,
        selection: _.mapValues(rel.selection, (expr) => expandExpr(expr, vars)),
        orderBys: rel.orderBys.map(({ expr, direction }) => ({
          expr: expandExpr(expr, vars),
          direction,
        })),
        sources: { from: expandRel(rel.sources.from, vars) },
      };
      break;
    case 'aggregate':
      ir = {
        ...rel,
        selection: _.mapValues(rel.selection, (expr) => expandExpr(expr, vars)),
        sources: { from: expandRel(rel.sources.from, vars) },
      };
      break;
    case 'join':
      ir = {
        ...rel,
        selection: _.mapValues(rel.selection, (expr) => expandExpr(expr, vars)),
        condition: expandExpr(rel.condition, vars),
        sources: {
          left: expandRel(rel.sources.left, vars),
          right: expandRel(rel.sources.right, vars),
        },
      };
      break;
    case 'union':
      ir = {
        ...rel,
        sources: {
          left: expandRel(rel.sources.left, vars),
          right: expandRel(rel.sources.right, vars),
        },
      };
      break;
    default:
      return Assert.unreachable(t);
  }

  // If expansion changed nothing, return the original object so callers see a
  // stable identity for unchanged relations.
  const res = deepEquals(ir, rel) ? (rel as AST.RelIR) : ir;
  cacheEntry.set(res);
  return res;
};
