import { useBaseVariableByBaseId } from '@invisible/common/components/process-base'
import { SnackbarContext } from '@invisible/common/providers'
import {
  fromGlobalId,
  getErrorMessage,
  useBaseFineTuningModelsQuery,
  useLlmFineTuneBaseMutation,
} from '@invisible/concorde/gql-client'
import { toGlobalId } from '@invisible/concorde/gql-client'
import { useQuery } from '@invisible/trpc/client'
import { MUIThemeProvider } from '@invisible/ui/mui-theme-v2'
import LoadingButton from '@mui/lab/LoadingButton'
import Autocomplete from '@mui/material/Autocomplete'
import Box from '@mui/material/Box'
import Button from '@mui/material/Button'
import CircularProgress from '@mui/material/CircularProgress'
import Dialog from '@mui/material/Dialog'
import DialogActions from '@mui/material/DialogActions'
import DialogContent from '@mui/material/DialogContent'
import DialogTitle from '@mui/material/DialogTitle'
import Divider from '@mui/material/Divider'
import FormControl from '@mui/material/FormControl'
import InputLabel from '@mui/material/InputLabel'
import MenuItem from '@mui/material/MenuItem'
import Paper from '@mui/material/Paper'
import Select from '@mui/material/Select'
import Stack from '@mui/material/Stack'
import Tab from '@mui/material/Tab'
import Table from '@mui/material/Table'
import TableBody from '@mui/material/TableBody'
import TableCell from '@mui/material/TableCell'
import TableContainer from '@mui/material/TableContainer'
import TableHead from '@mui/material/TableHead'
import TablePagination from '@mui/material/TablePagination'
import TableRow from '@mui/material/TableRow'
import TableSortLabel from '@mui/material/TableSortLabel'
import Tabs from '@mui/material/Tabs'
import TextField from '@mui/material/TextField'
import { format, parseISO } from 'date-fns'
import { capitalize } from 'lodash/fp'
import * as React from 'react'
import { FormEvent, ReactNode, useContext, useMemo, useState } from 'react'
import { useQueryClient } from 'react-query'

/** Props accepted by `FineTuneDialog`. */
interface FineTuneDialogProps {
  /** Whether the dialog is currently shown. */
  open: boolean
  /** Invoked when the dialog requests to close (MUI `onClose` or the Cancel button). */
  handleClose: () => void
  /** Raw (non-global) id of the base whose variables and fine-tuning jobs are loaded. */
  baseId: string
}

/** Props accepted by `CustomTabPanel`. */
interface TabPanelProps {
  /** Panel content; only rendered while the panel is the active one. */
  children?: ReactNode
  /** Position of this panel among the tabs. */
  index: number
  /** Index of the currently active tab. */
  value: number
}

/** An option in the "Prompt Fields" autocomplete, built from a base variable. */
interface Prompt {
  /** Base variable id; sent to the mutation as a training field. */
  id: string
  /** Base variable name shown to the user. */
  label: string
}

/** State of the fine-tune submission form. */
interface FormData {
  /** `value` of the chosen entry in `MODELS`; sent as `modelName`. */
  selectedModel: string | null
  /** Base variables chosen as prompt fields; their ids are sent as `trainingFields`. */
  selectedPrompts: Prompt[]
  /** Base variable chosen as the completion field; its id is sent as `completionField`. */
  selectedCompletion: {
    id: string
    label: string
  } | null
  /** `value` of the chosen entry in `TASK_OPTIONS`; sent as `taskType`. */
  selectedTask: string | null
  /** Value of the "Number of Samples" input; sent as `nSamples`. */
  nSamples: number
}

/** Ordering applied to the "Created At" column of the history table. */
type SortDirection = 'asc' | 'desc'

/**
 * Models offered in the "Model" dropdown. `value` is submitted as the model
 * name and `provider` is looked up from the selected `value` at submit time.
 */
const MODELS = [
  { label: 'OpenAI GPT-3.5', provider: 'openai', value: 'gpt-3.5-turbo' },
  { label: 'OpenAI GPT-4', provider: 'openai', value: 'gpt-4' },
]

/**
 * Task types listed in the "Task Type" dropdown. The dialog renders the
 * `completion` entry as disabled.
 */
const TASK_OPTIONS = [
  { label: 'Classification', value: 'classification' },
  { label: 'Completion', value: 'completion' },
]

/** Empty form state; also used to reset the form after a successful submission. */
const initialFormData: FormData = {
  selectedModel: null,
  selectedPrompts: [],
  selectedCompletion: null,
  selectedTask: null,
  nSamples: 100,
}

/**
 * Builds the accessibility attributes for the tab at `index`.
 *
 * The generated `id` and `aria-controls` values must line up with the
 * `aria-labelledby` and `id` that `CustomTabPanel` renders
 * (`fine-tune-tab-N` / `fine-tune-tabpanel-N`); otherwise the tab and its
 * panel reference elements that do not exist.
 *
 * @param index - Zero-based position of the tab.
 * @returns Props to spread onto the MUI `Tab`.
 */
const getTabA11yProps = (index: number) => ({
  id: `fine-tune-tab-${index}`,
  'aria-controls': `fine-tune-tabpanel-${index}`,
})

/**
 * Renders an ISO timestamp string for display in the history table.
 *
 * @param timestamp - Timestamp string handed to `parseISO`.
 * @returns The formatted date, or `'-'` if parsing/formatting throws.
 */
const formatTimestamp = (timestamp: string) => {
  // NOTE(review): the 'UTC' suffix is a literal in the pattern; confirm the
  // zone `format` renders in actually matches it for the timestamps we receive.
  const DISPLAY_PATTERN = "MMM d, yyyy, h:mm a 'UTC'"

  try {
    return format(parseISO(timestamp), DISPLAY_PATTERN)
  } catch {
    return '-'
  }
}

/**
 * Container for one tab's content. The wrapper is always rendered (hidden
 * when inactive), but its children are only rendered for the active tab.
 */
const CustomTabPanel = ({ children, value, index, ...rest }: TabPanelProps) => {
  const isActive = value === index

  return (
    <Box
      role='tabpanel'
      hidden={!isActive}
      id={`fine-tune-tabpanel-${index}`}
      aria-labelledby={`fine-tune-tab-${index}`}
      {...rest}>
      {isActive ? children : null}
    </Box>
  )
}

/** `sx` fragment spread into every form control of the submission tab so they share one width. */
const sharedFormStyles = {
  width: '440px',
}

/**
 * Dialog with two tabs:
 *  - "Submission": a form that starts a fine-tuning job for the given base.
 *  - "History": a paginated table of the base's fine-tuning models, sortable
 *    by creation date.
 *
 * @param open - Whether the dialog is shown.
 * @param handleClose - Called when the dialog should close.
 * @param baseId - Raw id of the base; converted to a global id for the
 *   fine-tuning-models query and passed as-is to the mutation input.
 */
const FineTuneDialog = ({ open, handleClose, baseId }: FineTuneDialogProps) => {
  const [activeTab, setActiveTab] = useState(0)
  const [formData, setFormData] = useState(initialFormData)
  const [sortDirection, setSortDirection] = useState<SortDirection>('desc')
  const [page, setPage] = useState(0)
  const [rowsPerPage, setRowsPerPage] = useState(10)

  const { data: baseVariables } = useBaseVariableByBaseId({
    baseId,
    enabled: true,
  })

  const { data, isLoading: isLoadingFineTuningModels } = useBaseFineTuningModelsQuery({
    baseId: toGlobalId('BaseType', baseId),
  })

  const reactQueryClient = useQueryClient()

  // Flatten the connection edges and expose raw ids instead of global ids.
  const models = useMemo(
    () =>
      data?.baseFineTuningModels.edges.map((edge) => ({
        ...edge.node,
        id: fromGlobalId(edge.node.id),
      })) || [],
    [data]
  )

  const { showSnackbar } = useContext(SnackbarContext)

  // Options shared by the "Prompt Fields" and "Completion Field" autocompletes.
  const baseVariablesOptions = useMemo(
    () =>
      baseVariables
        ? baseVariables.map((baseVariable) => ({
            label: baseVariable.name,
            id: baseVariable.id,
          }))
        : [],
    [baseVariables]
  )

  const { mutateAsync: fineTuneBase, isLoading: isSubmittingFineTuneRequest } =
    useLlmFineTuneBaseMutation({
      onError: (e) => {
        showSnackbar({
          message: `Error starting fine-tune: ${getErrorMessage(e)}`,
          variant: 'error',
        })
      },
      onSuccess: () => {
        setFormData(initialFormData)
        showSnackbar({
          message: 'Fine-tune started',
          variant: 'success',
        })
      },

      onSettled: () => {
        reactQueryClient.invalidateQueries(['BaseFineTuningModels'])
      },
    })

  const sortedModels = useMemo(
    () =>
      // Sort a copy: Array.prototype.sort works in place and `models` is a
      // memoized value that must not be mutated.
      [...models].sort((a, b) => {
        const dateA = new Date(a.createdAt).getTime()
        const dateB = new Date(b.createdAt).getTime()
        const isInvalidA = isNaN(dateA)
        const isInvalidB = isNaN(dateB)

        // Rows with unparseable dates sink to the bottom in either direction.
        if (isInvalidA && isInvalidB) return 0
        if (isInvalidA) return 1
        if (isInvalidB) return -1

        return sortDirection === 'asc' ? dateA - dateB : dateB - dateA
      }),
    [models, sortDirection]
  )

  /** Merges a partial change into the form state. */
  const updateFormData = (patch: Partial<FormData>) => {
    setFormData((previous) => ({ ...previous, ...patch }))
  }

  const handleSort = () => {
    setSortDirection((current) => (current === 'asc' ? 'desc' : 'asc'))
  }

  const handleFineTuneBase = (e: FormEvent<HTMLFormElement>) => {
    e.preventDefault()

    const { selectedModel, selectedTask, selectedCompletion, selectedPrompts, nSamples } = formData
    const provider = MODELS.find((m) => m.value === selectedModel)?.provider

    // Guard instead of casting: never send null/NaN values to the mutation,
    // even if native form validation did not catch them.
    if (
      !selectedModel ||
      !provider ||
      !selectedTask ||
      !selectedCompletion ||
      selectedPrompts.length === 0 ||
      !Number.isInteger(nSamples) ||
      nSamples < 1
    ) {
      showSnackbar({
        message: 'Please fill in all required fields',
        variant: 'error',
      })
      return
    }

    const fineTuneData = {
      completionField: selectedCompletion.id,
      modelName: selectedModel,
      provider,
      taskType: selectedTask,
      trainingFields: selectedPrompts.map((p) => p.id),
      baseId,
      nSamples,
    }

    // `onError` above already reports failures; the catch only prevents the
    // promise returned by mutateAsync from becoming an unhandled rejection.
    fineTuneBase({ input: fineTuneData }).catch(() => undefined)
  }

  const handleChangeRowsPerPage = (event: React.ChangeEvent<HTMLInputElement>) => {
    setRowsPerPage(+event.target.value)
    setPage(0)
  }

  const renderHistoryTab = () => {
    if (isLoadingFineTuningModels) {
      return (
        <Box
          sx={{
            display: 'flex',
            justifyContent: 'center',
            alignItems: 'center',
            height: '390px',
          }}>
          <CircularProgress />
        </Box>
      )
    }

    const firstRowIndex = page * rowsPerPage

    return (
      <Paper sx={{ width: '100%', overflow: 'hidden' }}>
        <TableContainer sx={{ height: '440px' }}>
          <Table stickyHeader aria-label='history table'>
            <TableHead>
              <TableRow>
                <TableCell align='left'>Model ID</TableCell>
                <TableCell align='left'>
                  <TableSortLabel active={true} direction={sortDirection} onClick={handleSort}>
                    Created At
                  </TableSortLabel>
                </TableCell>
                <TableCell align='left'>Status</TableCell>
              </TableRow>
            </TableHead>
            <TableBody>
              {sortedModels.slice(firstRowIndex, firstRowIndex + rowsPerPage).map((m) => (
                <TableRow key={m.id} sx={{ '&:last-child td, &:last-child th': { border: 0 } }}>
                  <TableCell align='left'>{m.metadata?.model_id ?? '-'}</TableCell>
                  <TableCell align='left'>{formatTimestamp(m.createdAt)}</TableCell>
                  <TableCell align='left'>{capitalize(m.status)}</TableCell>
                </TableRow>
              ))}
            </TableBody>
          </Table>
        </TableContainer>
        <TablePagination
          rowsPerPageOptions={[10, 25, 100]}
          component='div'
          count={sortedModels.length}
          rowsPerPage={rowsPerPage}
          page={page}
          onPageChange={(_, nextPage) => setPage(nextPage)}
          onRowsPerPageChange={handleChangeRowsPerPage}
        />
      </Paper>
    )
  }

  const renderSubmissionContent = () => (
    <Stack
      gap={4}
      sx={{
        height: '440px',
      }}>
      <FormControl
        required
        sx={{
          ...sharedFormStyles,
        }}>
        <InputLabel id='select-llm-label'>Model</InputLabel>
        <Select
          labelId='select-llm-label'
          id='select-llm'
          // An empty string (not null) is how a controlled Select shows "nothing selected".
          value={formData.selectedModel ?? ''}
          label='Model *'
          required
          onChange={(e) => {
            updateFormData({ selectedModel: e.target.value })
          }}>
          {MODELS.map((model) => (
            <MenuItem key={model.value} value={model.value}>
              {model.label}
            </MenuItem>
          ))}
        </Select>
      </FormControl>

      <FormControl
        required
        sx={{
          ...sharedFormStyles,
        }}>
        <InputLabel id='tast-type-label'>Task Type</InputLabel>
        <Select
          labelId='tast-type-label'
          id='task-type'
          value={formData.selectedTask ?? ''}
          label='Task Type *'
          required
          onChange={(e) => updateFormData({ selectedTask: e.target.value })}>
          {TASK_OPTIONS.map((task) => (
            <MenuItem key={task.value} value={task.value} disabled={task.value === 'completion'}>
              {task.label}
            </MenuItem>
          ))}
        </Select>
      </FormControl>

      <Autocomplete
        multiple
        disablePortal
        id='multi-select-prompt'
        options={baseVariablesOptions}
        value={formData.selectedPrompts}
        isOptionEqualToValue={(option, value) => option.id === value.id}
        ListboxProps={{ style: { maxHeight: '130px' } }}
        sx={{
          ...sharedFormStyles,
        }}
        onChange={(_, newValues) => {
          updateFormData({ selectedPrompts: newValues })
        }}
        renderInput={(params) => (
          <TextField
            required
            {...params}
            label='Prompt Fields'
            inputProps={{
              ...params.inputProps,
              // The text input is empty once chips are selected, so only
              // require it while no prompt field has been chosen.
              required: formData.selectedPrompts.length === 0,
            }}
          />
        )}
      />

      <Autocomplete
        disablePortal
        id='select-completion'
        options={baseVariablesOptions}
        ListboxProps={{ style: { maxHeight: '130px' } }}
        sx={{
          ...sharedFormStyles,
        }}
        value={formData.selectedCompletion}
        onChange={(_, newValue) => {
          updateFormData({ selectedCompletion: newValue })
        }}
        renderInput={(params) => <TextField {...params} label='Completion Field' required />}
      />

      <TextField
        type='number'
        label='Number of Samples'
        // parseInt yields NaN for a cleared field; render that as empty so the
        // input stays controlled and `required` blocks submission.
        value={Number.isNaN(formData.nSamples) ? '' : formData.nSamples}
        sx={{
          ...sharedFormStyles,
        }}
        required
        onChange={(e) => {
          updateFormData({ nSamples: parseInt(e.target.value, 10) })
        }}
        inputProps={{ min: 1 }}
      />
    </Stack>
  )

  return (
    <MUIThemeProvider>
      <Dialog
        open={open}
        onClose={handleClose}
        aria-labelledby='fine-tune-llm'
        sx={{
          '& .MuiDialog-paper': {
            width: '600px',
            maxWidth: 'none',
            height: '650px',
            padding: '2px',
          },
        }}>
        <DialogTitle id='fine-tune-llm'>Base Fine-Tuning</DialogTitle>
        <Box
          sx={{
            marginLeft: '15px',
          }}>
          <Tabs
            value={activeTab}
            onChange={(_, v) => {
              setActiveTab(v)
            }}
            aria-label='fine-tune tabs'>
            <Tab label='Submission' {...getTabA11yProps(0)} />
            <Tab label='History' {...getTabA11yProps(1)} />
          </Tabs>
        </Box>

        <CustomTabPanel value={activeTab} index={0}>
          <form onSubmit={handleFineTuneBase}>
            <DialogContent dividers>{renderSubmissionContent()}</DialogContent>

            <DialogActions>
              <Stack
                direction='row'
                gap={2}
                sx={{
                  justifyContent: 'center',
                  alignContent: 'center',
                  alignItems: 'center',
                  height: '50px',
                }}>
                <Button
                  variant='outlined'
                  onClick={handleClose}
                  sx={{
                    fontWeight: 'normal',
                    '&:hover': {
                      backgroundColor: 'white',
                    },
                  }}>
                  Cancel
                </Button>

                <LoadingButton
                  variant='contained'
                  disabled={isLoadingFineTuningModels}
                  loading={isSubmittingFineTuneRequest}
                  type='submit'
                  sx={{
                    fontWeight: 'normal',
                  }}>
                  Fine-tune
                </LoadingButton>
              </Stack>
            </DialogActions>
          </form>
        </CustomTabPanel>

        <CustomTabPanel value={activeTab} index={1}>
          <Divider />
          <DialogContent>{renderHistoryTab()}</DialogContent>
        </CustomTabPanel>
      </Dialog>
    </MUIThemeProvider>
  )
}

export { FineTuneDialog }
