import { useEffect, useState } from "react";
import {
  GridRowId,
  GridRowSelectionModel,
  useGridApiRef,
} from "@mui/x-data-grid-pro";
import { GridStatePro } from "@mui/x-data-grid-pro/models/gridStatePro";

/**
 * Tracks row selection for a MUI X Data Grid Pro instance and keeps an
 * `allSelected` flag in sync with it.
 *
 * The returned `gridApiRef` is meant to be passed to the grid's `apiRef` prop,
 * and `handleSelectRow` to its row-selection change callback.
 *
 * @returns The `allSelected` flag and its setter, the grid api ref, helpers to
 *   clear or add selected rows, the selection change handler, and the currently
 *   selected row ids together with their count.
 */
const useGridActionToolbar = () => {
  const [allSelected, setAllSelected] = useState(false);
  const [selectedRows, setSelectedRows] = useState<GridRowId[]>([]);
  const gridApiRef = useGridApiRef();

  /** Mirrors the grid's selection model into local React state. */
  const handleSelectRow = (rowSelectionModel: GridRowSelectionModel) => {
    setSelectedRows(rowSelectionModel);
  };

  /** Deselects every row, both locally and inside the grid. */
  const clearRows = () => {
    setSelectedRows([]);
    gridApiRef.current.setState((prevState: GridStatePro) => {
      return { ...prevState, rowSelection: [] };
    });
  };

  /**
   * Adds `newRowIds` to the grid's current selection.
   *
   * Ids that are already selected (or repeated within `newRowIds`) are added
   * only once, so the selection model never contains duplicates that would
   * inflate the selected-row count.
   */
  const checkRows = (newRowIds: GridRowId[]) => {
    gridApiRef.current.setState((prevState: GridStatePro) => {
      // A Set keeps insertion order: existing selection first, then new ids.
      const merged = new Set<GridRowId>(prevState.rowSelection);
      newRowIds.forEach((id) => merged.add(id));

      // Nothing new to select: keep the same state object so the grid does
      // not report a selection change.
      if (merged.size === prevState.rowSelection.length) {
        return prevState;
      }

      return { ...prevState, rowSelection: Array.from(merged) };
    });
  };

  useEffect(() => {
    if (!allSelected) {
      return;
    }
    // Sync internal grid state with allSelected state.
    // dataRowIds is an array of all FETCHED row ids (not necessarily every row
    // that exists on the server). Read it once so local state and grid state
    // are set from the same value.
    const loadedRowIds = gridApiRef.current.state.rows.dataRowIds;
    setSelectedRows(loadedRowIds);
    gridApiRef.current.setState((prevState: GridStatePro) => {
      return {
        ...prevState,
        rowSelection: loadedRowIds,
      };
    });
  }, [allSelected, gridApiRef]);

  useEffect(() => {
    const { totalRowCount = 0, dataRowIds = [] } =
      gridApiRef.current.state?.rows ?? {};

    const checkedRowsCount = selectedRows.length;
    const loadedRowsCount = dataRowIds.length;

    if (checkedRowsCount > 0 && checkedRowsCount === totalRowCount) {
      // Every row in the grid is checked: covers the case where a user
      // manually checks all rows. This takes precedence over the branch below.
      setAllSelected(true);
    } else if (checkedRowsCount !== loadedRowsCount) {
      // Not every loaded row is checked: covers the case where allSelected
      // was true and the user then unselected a row.
      setAllSelected(false);
    }
  }, [selectedRows, gridApiRef]);

  return {
    allSelected,
    gridApiRef,
    clearRows,
    checkRows,
    handleSelectRow,
    setAllSelected,
    selectedCount: selectedRows.length,
    selectedRows,
  };
};

export { useGridActionToolbar };
