import { create } from 'zustand';
import { immer } from 'zustand/middleware/immer';
import { applyNodeChanges, applyEdgeChanges } from '@xyflow/react';
import { createNode, removeNode } from '@/api/nodes';
import { disconnectInputFromOutput, connectInputToOutput } from '@/api/pathways';
import {
  parseNodes,
  parseEdges,
  getLayoutedNodes,
  parseNode,
  getNodesById,
  serializeEdge,
  applyNodeInfo,
} from '@/utils/pathwayUtils';
import useCaseStore, { caseActions } from './caseStore';
import { keyBy } from '@/utils/miscUtils';
import { pathwaysCacheActions } from './pathwaysCacheStore';
import { patchNode } from '@/api/nodes';
import { NODE_TYPES } from '@/consts';

// Initial (and post-`clear`) state of the pathway store. Every key the actions
// write is declared here, so reads such as `store.edgesBySource[name]` are safe
// before `init` has run and after `clear()` resets the store to this object.
const initialValues = {
  id: null,
  case_id: null,
  nodes: [],
  edges: [],
  // Edge lookups keyed by source/target handle; populated by init/update/removeNode.
  edgesBySource: {},
  edgesByTarget: {},
  analysis_id: null,
};

// Pathway store wrapped in the immer middleware. The actions below call `set`
// both with partial state objects and with recipes that mutate an immer draft.
const usePathwayStore = create()(immer(() => initialValues));
const set = usePathwayStore.setState;
const get = usePathwayStore.getState;

// actions
/**
 * Builds the two edge lookup tables stored alongside `edges`: one keyed by each
 * edge's `sourceHandle`, one keyed by its `targetHandle`.
 * @param {Array<object>} edges - parsed edges.
 * @returns {{ edgesBySource: object, edgesByTarget: object }}
 */
const indexEdges = edges => ({
  edgesBySource: keyBy(edges, 'sourceHandle'),
  edgesByTarget: keyBy(edges, 'targetHandle'),
});

export const pathwayActions = {
  /**
   * Loads a pathway into the store. It parses nodes and edges, applies the case's
   * `node_info` to the nodes, and lays them out (passing any cached positions for
   * this pathway). It then stores everything together with the edge indexes.
   * @param {object} pathway - payload with id, label, case_id, name, connections, nodes.
   */
  init: async pathway => {
    const { id, label, case_id, name, connections, nodes } = pathway;
    const nodesById = getNodesById(nodes);
    const edges = parseEdges(connections, nodesById);
    const { node_info } = useCaseStore.getState();

    const initialNodes = parseNodes(applyNodeInfo(nodes, node_info));
    const cachedNodes = pathwaysCacheActions.getCachedNodePositions(id);
    const nodesWithPositions = await getLayoutedNodes(initialNodes, cachedNodes, edges);

    set({ id, case_id, label, name, nodes: nodesWithPositions, edges, ...indexEdges(edges) });
  },

  /** Resets the store to `initialValues`, replacing (not merging) the current state. */
  clear: () => set(initialValues, true),

  /**
   * React Flow `onNodesChange` handler. It caches the new position of every moved
   * node, then applies all changes to the stored nodes.
   * @param {Array<object>} changes - React Flow node changes.
   */
  onNodesChange: changes => {
    const pathwayId = get().id;

    // When several selected nodes are dragged together, React Flow emits one
    // position change per node. Cache all of them, not just the first.
    for (const change of changes ?? []) {
      if (change?.type === 'position' && change.position) {
        const {
          id: nodeId,
          position: { x, y },
        } = change;

        pathwaysCacheActions.cacheNodePosition(pathwayId, nodeId, { x, y });
      }
    }

    set({ nodes: applyNodeChanges(changes, get().nodes) });
  },

  /**
   * React Flow `onEdgesChange` handler. Applies the changes to the stored edges.
   * @param {Array<object>} changes - React Flow edge changes.
   */
  onEdgesChange: changes => {
    set({ edges: applyEdgeChanges(changes, get().edges) });
  },

  /**
   * Creates a node on the server from a JSON-encoded node type and appends it at `position`.
   * For system nodes, it also patches the node's `params.system_id`.
   * @param {string} nodeTypeStr - JSON string with at least `name` and `type`
   *   (optionally `label` and `data`).
   * @param {{x: number, y: number}} position - canvas position for the new node.
   */
  addNode: async (nodeTypeStr, position) => {
    const { id: pathwayId, case_id: caseId } = get();

    const nodeType = JSON.parse(nodeTypeStr);
    const { name, type } = nodeType;
    const { data } = await createNode(pathwayId, { name });
    // Prefer the label supplied with the node type; fall back to the server's.
    const label = nodeType.label || data.label;

    if (type === NODE_TYPES.system) {
      const systemId = nodeType?.data?.system?.id;
      await patchNode(caseId, data.id, { params: { system_id: systemId } });
    }

    const node = parseNode({ ...data, position, name, type, label, data: nodeType?.data });

    set(state => {
      state.nodes.push(node);
    });

    caseActions.markUnbalanced();
  },

  /**
   * Syncs the store with a pathway returned by the server. It refreshes each node's
   * inputs/outputs and replaces the edges and their indexes.
   * @param {object} pathway - payload with `nodes` and `connections`.
   */
  update: pathway => {
    const nodesById = getNodesById(pathway.nodes);
    const edges = parseEdges(pathway.connections, nodesById);
    const { edgesBySource, edgesByTarget } = indexEdges(edges);

    set(state => {
      state.nodes.forEach(node => {
        const serverNode = nodesById[node.id];
        // A node may be on the canvas but absent from the server payload. Skip it
        // instead of throwing on `undefined.inputs`.
        if (!serverNode) return;
        node.data.inputs = serverNode.inputs;
        node.data.outputs = serverNode.outputs;
      });
      state.edges = edges;
      state.edgesBySource = edgesBySource;
      state.edgesByTarget = edgesByTarget;
    });

    caseActions.markUnbalanced();
  },

  /**
   * Removes a connection on the server, then syncs the store with the returned pathway.
   * @param {object} edge - React Flow edge to disconnect.
   */
  disconnectInputFromOutput: async edge => {
    const { id: pathwayId } = get();
    const serializedEdge = serializeEdge(edge);
    const { data: pathway } = await disconnectInputFromOutput(pathwayId, serializedEdge);

    pathwayActions.update(pathway);
  },

  /**
   * Creates a connection on the server, then syncs the store with the returned pathway.
   * @param {object} edge - React Flow edge/connection to create.
   */
  connectInputToOutput: async edge => {
    const { id: pathwayId } = get();
    const serializedEdge = serializeEdge(edge);
    const { data: pathway } = await connectInputToOutput(pathwayId, serializedEdge);

    pathwayActions.update(pathway);
  },

  /**
   * Deletes a node on the server and removes it from the store, along with every edge
   * that starts or ends at it. It also clears the node's cached position.
   * @param {string} nodeId - id of the node to remove.
   */
  removeNode: async nodeId => {
    const { id } = get();

    await removeNode(id, nodeId);

    // Re-read edges after the await so changes made during the request are kept.
    const { edges } = get();

    const updatedEdges = edges.filter(edge => edge.source !== nodeId && edge.target !== nodeId);
    const { edgesBySource, edgesByTarget } = indexEdges(updatedEdges);

    set(state => {
      state.edges = updatedEdges;
      state.nodes = state.nodes.filter(node => node.id !== nodeId);
      state.edgesBySource = edgesBySource;
      state.edgesByTarget = edgesByTarget;
    });

    pathwaysCacheActions.clearNodeCache(id, nodeId);
    caseActions.markUnbalanced();
  },
};

// selectors
export const useNodes = () => usePathwayStore(store => store.nodes);
export const useEdges = () => usePathwayStore(store => store.edges);
export const usePathwayId = () => usePathwayStore(store => store.id);
export const useCaseId = () => usePathwayStore(store => store.case_id);
export const useNodeCount = () => usePathwayStore(store => store?.nodes?.length);

/**
 * Returns the edge attached to the given input/output handle (truthy when connected),
 * or undefined when neither index has an entry for it.
 * Optional chaining keeps this from throwing if the edge indexes have not been
 * populated yet (e.g. before `init` has run).
 * @param {string} name - handle id, looked up as a source handle and as a target handle.
 */
export const useIsIOConnected = name =>
  usePathwayStore(store => store.edgesBySource?.[name] || store.edgesByTarget?.[name]);

export default usePathwayStore;
