import { useContext, useQuery } from '@invisible/trpc/client'
import { Button } from '@invisible/ui/button'
import { TextArea } from '@invisible/ui/form'
import { Input } from '@invisible/ui/input'
import { MultiSelect } from '@invisible/ui/multi-select'
import { Wizard as WizardSchemas } from '@invisible/ultron/zod'
import { flatten, groupBy } from 'lodash/fp'
import pMap from 'p-map'
import { ChangeEvent, FC, useMemo, useState } from 'react'

import { useBaseRunCreate } from '../../hooks/useBaseRunCreate'
import { useBaseRunCreateMany } from '../../hooks/useBaseRunCreateManyWizardAction'
import { TBaseRunQueryData } from '../../hooks/useGetBaseRuns'
import { BASES, MODELS } from './constants'
import { PromptResponse } from './PromptResponse'

/** One base-run row from the `useGetBaseRuns` query result. */
type TBaseRun = TBaseRunQueryData['items'][number]
/** One step run nested inside a base-run row. */
type TStepRun = TBaseRun['stepRuns'][number]

/**
 * Props for the RLHF operate wizard action: the wizard-action config schema
 * fields, plus the base run and step run being operated on and a flag that
 * puts every input into read-only mode.
 */
type IProps = WizardSchemas.WACConfig.TSchema & {
  isReadOnly: boolean
  baseRun: TBaseRun
  stepRun: TStepRun
}

/**
 * RLHF "operate" wizard action.
 *
 * Sends the operator's prompt to every selected model and stores the results.
 * The prompt is saved as a child base run of `baseRun` in the PROMPTS base.
 * Each model response is saved as a child base run of that prompt in the
 * RESPONSES base. Responses for the open prompt are then rendered grouped by
 * model via `PromptResponse`.
 *
 * Only one prompt can be open at a time. While any prompt has no accepted
 * response, the prompt input and the submit button stay disabled.
 */
// eslint-disable-next-line @typescript-eslint/ban-types
const RlhfOperateWAC: FC<IProps> = ({ baseRun, stepRun, isReadOnly }) => {
  const reactQueryContext = useContext()
  const [promptText, setPromptText] = useState('')
  // 0 means "invalid input"; the submit button is disabled in that case.
  const [numberOfResponses, setNumberOfResponses] = useState(1)
  const [activeModels, setActiveModels] = useState(Object.keys(MODELS) as (keyof typeof MODELS)[])
  const [getResponseLoading, setGetResponseLoading] = useState(false)

  const { data: prompts } = useQuery([
    'baseRun.findChildBaseRuns',
    {
      baseId: BASES.PROMPTS.ID,
      parentBaseRunId: baseRun.id,
    },
  ])

  // Parses prompt base runs into typed objects built from their base run variables,
  // ordered by their stored prompt index.
  const normalizedPrompts = useMemo(
    () =>
      (prompts ?? [])
        .map((prompt) => ({
          id: prompt.id,
          text: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === BASES.PROMPTS.BASE_VARIABLES.PROMPT_TEXT
          )?.value as string,
          index: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === BASES.PROMPTS.BASE_VARIABLES.PROMPT_INDEX
          )?.value as number,
          acceptedResponse: prompt.baseRunVariables.find(
            (variable) =>
              variable.baseVariable.id === BASES.PROMPTS.BASE_VARIABLES.ACCEPTED_RESPONSE
          )?.value as string,
          acceptedModel: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === BASES.PROMPTS.BASE_VARIABLES.ACCEPTED_MODEL
          )?.value as string,
        }))
        .sort((a, b) => a.index - b.index),
    [prompts]
  )

  // The first prompt that has no accepted response is the "open" one being graded.
  const unsubmittedPrompt = normalizedPrompts.find((prompt) => !prompt.acceptedResponse)

  const { data: activePromptResponses } = useQuery(
    [
      'baseRun.findChildBaseRuns',
      {
        baseId: BASES.RESPONSES.ID,
        parentBaseRunId: unsubmittedPrompt?.id as string,
      },
    ],
    {
      // The id cast above is only safe because the query is disabled without a prompt.
      enabled: !!unsubmittedPrompt,
    }
  )

  // Parses the open prompt's responses, sorted by index, and groups them by model name.
  const normalizedActivePromptResponses = useMemo(() => {
    const sortedResponses = (activePromptResponses ?? [])
      .map((response) => ({
        id: response.id,
        text: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === BASES.RESPONSES.BASE_VARIABLES.TEXT
        )?.value as string,
        index: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === BASES.RESPONSES.BASE_VARIABLES.INDEX
        )?.value as number,
        category: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === BASES.RESPONSES.BASE_VARIABLES.CATEGORY
        )?.value as string,
        score: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === BASES.RESPONSES.BASE_VARIABLES.SCORE
        )?.value as number,
        // Stored from `MODELS[model].name` in handlePromptSubmission, so this is a string.
        model: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === BASES.RESPONSES.BASE_VARIABLES.MODEL
        )?.value as string,
        rationale: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === BASES.RESPONSES.BASE_VARIABLES.RATIONALE
        )?.value as string,
      }))
      .sort((a, b) => a.index - b.index)

    const groupedResponses = groupBy('model', sortedResponses)
    return { flattened: sortedResponses, grouped: groupedResponses }
  }, [activePromptResponses])

  // Every response must have a category, score and rationale before one can be selected.
  // NOTE(review): `score` is checked by truthiness, so a score of 0 counts as missing —
  // confirm 0 is not a valid score.
  const canSubmit = normalizedActivePromptResponses.flattened.every(
    (response) => response.category && response.score && response.rationale
  )

  const { mutateAsync: createBaseRun, isLoading: createBaseRunLoading } = useBaseRunCreate()
  const { mutateAsync: createBaseRuns, isLoading: createBaseRunsLoading } = useBaseRunCreateMany({
    onSettled: () => {
      reactQueryContext.invalidateQueries('baseRun.findChildBaseRuns')
    },
  })

  const isBusy = getResponseLoading || createBaseRunLoading || createBaseRunsLoading
  const hasPromptText = promptText.trim().length > 0
  // Without at least one model and one response per model, the created prompt would
  // have no responses and would block the prompt input indefinitely.
  const canRequestResponses = activeModels.length > 0 && numberOfResponses > 0

  // Accept only positive integers; anything else (empty, NaN, 0, negative, fractional)
  // is stored as 0 so the submit button is disabled instead of sending a bad count.
  const handleNumberOfResponsesChange = (e: ChangeEvent<HTMLInputElement>) => {
    const value = Number(e.target.value)
    setNumberOfResponses(Number.isInteger(value) && value > 0 ? value : 0)
  }

  // Queries every selected model and flattens the results into a single list.
  // The loading flag is cleared in `finally` so a failing executor cannot leave the
  // UI permanently showing "Waiting for response...".
  const fetchModelResponses = async () => {
    setGetResponseLoading(true)
    try {
      const perModelResponses = await pMap(activeModels, async (model) => {
        const responses = await MODELS[model].executor({
          prompt: promptText,
          response_count: numberOfResponses,
        })
        return responses.map((response, index) => ({
          message: response,
          model: MODELS[model].name,
          index: index + 1,
        }))
      })
      return flatten(perModelResponses)
    } finally {
      setGetResponseLoading(false)
    }
  }

  const handlePromptSubmission = async () => {
    if (!hasPromptText || !canRequestResponses) return

    const modelResponses = await fetchModelResponses()
    // A prompt with no responses would stay "unsubmitted" with nothing rendered,
    // keeping the prompt input disabled. Bail out and keep the text so it can be retried.
    if (modelResponses.length === 0) return

    const prompt = await createBaseRun({
      baseId: BASES.PROMPTS.ID,
      stepRunId: stepRun.id,
      parentBaseRunId: baseRun.id,
      initialValues: [
        {
          baseVariableId: BASES.PROMPTS.BASE_VARIABLES.PROMPT_TEXT,
          value: promptText,
        },
        {
          baseVariableId: BASES.PROMPTS.BASE_VARIABLES.PROMPT_INDEX,
          value: (prompts?.length ?? 0) + 1,
        },
      ],
    })

    // Create one response base run per model response, parented to the new prompt.
    await createBaseRuns({
      baseId: BASES.RESPONSES.ID,
      parentBaseRunId: prompt.id,
      initialValuesArray: modelResponses.map((response) => [
        {
          baseVariableId: BASES.RESPONSES.BASE_VARIABLES.TEXT,
          value: response.message,
        },
        {
          baseVariableId: BASES.RESPONSES.BASE_VARIABLES.INDEX,
          value: response.index,
        },
        {
          baseVariableId: BASES.RESPONSES.BASE_VARIABLES.MODEL,
          value: response.model,
        },
      ]),
      sourceStepRunId: stepRun.id,
    })
    setPromptText('')
  }

  return (
    <div className='box-border h-full w-full bg-white p-2'>
      <div className='text-header pt-2 text-xl font-bold'>Grade Model Responses</div>
      <div className='text-paragraphs py-2 text-sm'>
        Assign a score to model responses based on accuracy and completeness.
      </div>
      <div className='relative box-border flex max-h-[calc(100%-100px)] w-full flex-col justify-between gap-4 rounded-md border border-solid border-gray-300 p-3'>
        <div className='flex items-center justify-end gap-2'>
          <MultiSelect
            name='Select models to compare'
            width='500px'
            onAdd={(selected) =>
              setActiveModels((prev) => [...prev, selected.key as keyof typeof MODELS])
            }
            onRemove={(removed) =>
              setActiveModels((prev) => prev.filter((model) => model !== removed.key))
            }
            options={Object.values(MODELS).map((model) => ({ key: model.id, value: model.name }))}
            defaultKeys={activeModels}
            disabled={isReadOnly}
          />

          <Input
            type='number'
            placeholder='Responses per model'
            width='200px'
            defaultValue={numberOfResponses}
            disabled={isReadOnly}
            onChange={handleNumberOfResponsesChange}
          />
        </div>

        <div className='flex h-full w-full flex-col gap-2 overflow-auto'>
          {normalizedPrompts.map((prompt) => (
            <div key={prompt.id}>
              <div className='relative  w-fit max-w-[50%] rounded-xl bg-indigo-100 py-2 px-4 text-sm shadow'>
                {prompt.text}
              </div>
              {prompt.acceptedResponse ? (
                <div className='max-w-1/2 relative ml-auto mt-2 w-fit max-w-[50%] whitespace-pre-line rounded-xl bg-pink-100 py-2 px-4 text-sm shadow'>
                  {prompt.acceptedResponse.trim()}
                </div>
              ) : null}
            </div>
          ))}

          {unsubmittedPrompt && normalizedActivePromptResponses.flattened.length > 0 ? (
            <div className='mt-4 flex flex-wrap gap-5 gap-y-10'>
              {Object.keys(normalizedActivePromptResponses.grouped).map((model) => (
                <div key={model}>
                  <div className='font-bold'>{model}</div>

                  <div className='mt-3 flex w-[600px] flex-col gap-2'>
                    {(normalizedActivePromptResponses.grouped[model] ?? []).map((response) => (
                      <PromptResponse
                        stepRunId={stepRun.id}
                        response={response}
                        activePromptId={unsubmittedPrompt.id}
                        key={model + response.id}
                        canSubmit={canSubmit}
                        wizardIsReadOnly={isReadOnly}
                      />
                    ))}
                  </div>
                </div>
              ))}
            </div>
          ) : null}

          {isBusy ? (
            <div className='text-paragraphs bg-weak-3 border-main mx-auto mt-4 w-fit rounded-md border border-solid p-2 text-xs'>
              Waiting for response...
            </div>
          ) : null}
        </div>
        <div className='flex w-full items-center gap-3'>
          <TextArea
            placeholder='Enter your prompt here...'
            value={promptText}
            onChange={(e) => setPromptText(e.target.value)}
            className='!w-[900px] resize-none !rounded'
            disabled={Boolean(unsubmittedPrompt) || isBusy || isReadOnly}
          />
          <Button
            icon='RocketFilledIcon'
            size='md'
            variant='primary'
            shape='square'
            onClick={handlePromptSubmission}
            disabled={
              isBusy ||
              !hasPromptText ||
              !canRequestResponses ||
              Boolean(unsubmittedPrompt) ||
              isReadOnly
            }
          />
        </div>
      </div>
    </div>
  )
}

export { RlhfOperateWAC }
