import { useWizardState } from '@invisible/common/components/providers/active-wizard-provider'
import { useContext, useQuery } from '@invisible/trpc/client'
import { Button } from '@invisible/ui/button'
import { TextArea } from '@invisible/ui/form'
import { Wizard as WizardSchemas } from '@invisible/ultron/zod'
import { flatten, groupBy } from 'lodash/fp'
import { compact } from 'lodash/fp'
import pMap from 'p-map'
import { FC, useMemo, useState } from 'react'

import { useBaseRunCreate } from '../../hooks/useBaseRunCreate'
import { useBaseRunCreateMany } from '../../hooks/useBaseRunCreateManyWizardAction'
import { useBaseRunDeleteWithStepRunReference } from '../../hooks/useBaseRunDeleteWithStepRunReference'
import { useBaseRunVariableFindManyByBaseRunId } from '../../hooks/useBaseRunVariableFindManyByBaseRunId'
import { useFirstManualStepForBaseRun } from '../../hooks/useFirstManualStepForBaseRun'
import { TBaseRunQueryData } from '../../hooks/useGetBaseRuns'
import { DEFAULT_MODEL, queryModel } from './constants'
import { PromptEdit } from './PromptEdit'
import { PromptResponse } from './PromptResponse'

type TBaseRunItems = TBaseRunQueryData['items']
/** A single base run row from the base-runs query result. */
type TBaseRun = TBaseRunItems[number]
/** A single step run attached to a {@link TBaseRun}. */
type TStepRun = TBaseRun['stepRuns'][number]

/**
 * Props for the `CohereOperateWAC` component: the wizard WAC config schema
 * (which carries the `cohereOperate` settings) plus the runs it operates on.
 */
interface IProps extends WizardSchemas.WACConfig.TSchema {
  /** Step run used as the source/owner of the base runs this component creates or deletes. */
  stepRun: TStepRun
  /** Parent base run; prompts are stored as its child base runs. */
  baseRun: TBaseRun
}

/**
 * Conversation wizard action driven by the `cohereOperate` WAC config.
 *
 * - Prompts are child base runs of `baseRun` in `config.promptsBaseId`. Their text,
 *   index, accepted response, accepted model and accepted response id are read
 *   from base run variables.
 * - Each prompt's model responses are child base runs of that prompt in
 *   `config.responsesBaseId`.
 * - The first prompt with no accepted response, or whose accepted response is
 *   being edited, is the "unsubmitted" prompt. Its responses are loaded and
 *   rendered, and new prompts are blocked until it is resolved.
 * - Prompts created after the first manual step run of `baseRun` can be edited or
 *   deleted.
 */
// eslint-disable-next-line @typescript-eslint/ban-types
const CohereOperateWAC: FC<IProps> = ({ baseRun, stepRun, cohereOperate: config }) => {
  const reactQueryContext = useContext()
  const [promptText, setPromptText] = useState('')
  const [isLastResponseEmpty, setIsLastResponseEmpty] = useState(false)
  const [getResponseLoading, setGetResponseLoading] = useState(false)
  const [editPromptInfo, setEditPromptInfo] = useState<Record<string, boolean>>({})
  const [editLastResponse, setEditLastResponse] = useState<Record<string, boolean>>({})
  const { dispatch } = useWizardState()
  const { data: prompts } = useQuery([
    'baseRun.findChildBaseRuns',
    {
      baseId: config?.promptsBaseId as string,
      parentBaseRunId: baseRun.id,
    },
  ])
  const { data: firstManualStepRun } = useFirstManualStepForBaseRun({
    baseRunId: baseRun.id,
  })

  // Edit/Delete controls are only shown for prompts created after this timestamp.
  const firstManualStepRunCreatedAt = useMemo(
    () => firstManualStepRun?.createdAt ?? '',
    [firstManualStepRun]
  )

  // Parses prompt base runs into typed objects built from their base run variables,
  // ordered by their stored index.
  const normalizedPrompts = useMemo(
    () =>
      (prompts ?? [])
        .map((prompt) => ({
          id: prompt.id,
          text: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === config?.promptTextBaseVariableId
          )?.value as string,
          index: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === config?.promptIndexBaseVariableId
          )?.value as number,
          acceptedResponse: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === config?.promptResponseBaseVariableId
          )?.value as string,
          acceptedModel: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === config?.promptAcceptedModelBaseVariableId
          )?.value as string,
          responseId: prompt.baseRunVariables.find(
            (variable) => variable.baseVariable.id === config?.responseIdBaseVariableId
          )?.value as string,
          createdAt: prompt.createdAt,
        }))
        .sort((a, b) => a.index - b.index),
    [prompts, config]
  )

  // The first prompt that still needs an accepted response, or whose accepted
  // response is currently being edited.
  const unsubmittedPrompt = normalizedPrompts.find(
    (prompt) => !prompt.acceptedResponse || editLastResponse[prompt.responseId]
  )

  const { data: activePromptResponses } = useQuery(
    [
      'baseRun.findChildBaseRuns',
      {
        baseId: config?.responsesBaseId as string,
        parentBaseRunId: unsubmittedPrompt?.id as string,
      },
    ],
    {
      enabled: !!unsubmittedPrompt,
    }
  )

  // Responses of the unsubmitted prompt: blank ones are dropped, the rest are
  // sorted by stored index and also grouped by model.
  const normalizedActivePromptResponses = useMemo(() => {
    const sortedResponses = (activePromptResponses ?? [])
      .map((response) => ({
        id: response.id,
        text: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === config?.responseTextBaseVariableId
        )?.value as string,
        index: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === config?.responseIdBaseVariableId
        )?.value as number,
        model: response.baseRunVariables.find(
          (variable) => variable.baseVariable.id === config?.responseModelBaseVariableId
        )?.value as string,
      }))
      .filter((response) => response.text)
      .sort((a, b) => a.index - b.index)

    const groupedResponses = groupBy('model', sortedResponses)
    return { flattened: sortedResponses, grouped: groupedResponses }
  }, [activePromptResponses, config])

  const activeModels = [config?.model ?? DEFAULT_MODEL]

  const { mutateAsync: createBaseRun, isLoading: createBaseRunLoading } = useBaseRunCreate()

  const { mutateAsync: deleteBaseRuns, isLoading: deletingBaseRuns } =
    useBaseRunDeleteWithStepRunReference({
      onSettled: () => {
        reactQueryContext.invalidateQueries('baseRun.findChildBaseRuns')
      },
    })

  const { mutateAsync: createBaseRuns, isLoading: createBaseRunsLoading } = useBaseRunCreateMany({
    onSettled: () => {
      reactQueryContext.invalidateQueries('baseRun.findChildBaseRuns')
    },
  })

  // True while models are being queried or prompt/response base runs are being created.
  const isGenerating = getResponseLoading || createBaseRunLoading || createBaseRunsLoading

  // True when any accepted response is open in edit mode.
  const isEditingAnyResponse = Object.values(editLastResponse).some((value) => value)

  const responsesBaseRunVariables = useBaseRunVariableFindManyByBaseRunId({
    baseRunIds: compact(activePromptResponses?.map((res) => res.id) ?? []),
  })

  const promptBaseRunVariables = useBaseRunVariableFindManyByBaseRunId({
    baseRunIds: compact(normalizedPrompts.map((prompt) => prompt.id)),
  })

  // Flags this wizard step as not ready for submission while a new response is pending.
  const markNotReadyForSubmit = () => {
    dispatch({
      type: 'setReadyForSubmit',
      key: 'cohereOperate',
      value: false,
    })
  }

  /**
   * Queries every active model with `prompt` and returns all responses flattened.
   * Each response is tagged with its model and a 1-based index within that
   * model's calls.
   *
   * The loading flag is reset in `finally`, so a failed request cannot leave the
   * input permanently disabled. The error still propagates to the caller.
   */
  const queryActiveModels = async (prompt: string, chatHistory: typeof normalizedPrompts) => {
    setGetResponseLoading(true)
    setIsLastResponseEmpty(false)
    try {
      const modelResponses = await pMap(activeModels, async (model) => {
        const responses = await queryModel({
          model,
          prompt,
          chatHistory,
          callCount: config?.numOfModelCalls,
        })
        return responses.map((response: { text: string }, index: number) => ({
          message: response.text,
          model: model,
          index: index + 1,
        }))
      })
      return flatten(modelResponses)
    } finally {
      setGetResponseLoading(false)
    }
  }

  /**
   * Stores model responses as child base runs of `parentBaseRunId`. If nothing
   * usable came back (no responses, or any response is blank), it sets the
   * blank-response warning instead.
   */
  const persistResponses = async (
    parentBaseRunId: string,
    responses: { message: string; model: string; index: number }[]
  ) => {
    if (responses.length === 0 || !responses.every((response) => response.message)) {
      setIsLastResponseEmpty(true)
      return
    }

    await createBaseRuns({
      baseId: config?.responsesBaseId as string,
      parentBaseRunId,
      initialValuesArray: responses.map((response) => [
        {
          baseVariableId: config?.responseTextBaseVariableId as string,
          value: response.message,
        },
        {
          // normalizedActivePromptResponses sorts by this value, so every response must carry it.
          baseVariableId: config?.responseIdBaseVariableId as string,
          value: response.index,
        },
        {
          baseVariableId: config?.responseModelBaseVariableId as string,
          value: response.model,
        },
      ]),
      sourceStepRunId: stepRun.id,
    })
  }

  /** Re-queries the models for a stored prompt that has no usable responses. */
  const resubmitPrompt = async (failedPrompt: { id: string; text: string }) => {
    markNotReadyForSubmit()
    const responses = await queryActiveModels(
      failedPrompt.text,
      // Exclude the failed prompt itself from the chat history.
      normalizedPrompts.filter((prompt) => prompt.id !== failedPrompt.id)
    )
    await persistResponses(failedPrompt.id, responses)
  }

  /** Queries the models with the typed prompt, then stores the prompt and its responses. */
  const handlePromptSubmission = async () => {
    markNotReadyForSubmit()
    const responses = await queryActiveModels(promptText, normalizedPrompts)

    // The prompt is stored even when the responses are blank so it can be retried
    // with the redo button.
    const prompt = await createBaseRun({
      baseId: config?.promptsBaseId as string,
      stepRunId: stepRun.id,
      parentBaseRunId: baseRun.id,
      initialValues: [
        {
          baseVariableId: config?.promptTextBaseVariableId as string,
          value: promptText,
        },
        {
          baseVariableId: config?.promptIndexBaseVariableId as string,
          value: (prompts?.length ?? 0) + 1,
        },
      ],
    })

    await persistResponses(prompt.id, responses)
    setPromptText('')
  }

  /** After confirmation, deletes a prompt together with all of its response base runs. */
  const handleDeletePrompt = async (prompt: { id: string; responseId: string }) => {
    if (!window.confirm('You are about to delete this prompt and its response. Are you sure?'))
      return

    const responses = await reactQueryContext.fetchQuery([
      'baseRun.findChildBaseRuns',
      {
        baseId: config?.responsesBaseId as string,
        parentBaseRunId: prompt.id,
      },
    ])
    await deleteBaseRuns({
      stepRunId: stepRun.id,
      baseRunIds: [prompt.id, ...responses.map((r) => r.id)],
    })
  }

  return (
    <div className='box-border h-full w-full bg-white p-2'>
      <div className='text-paragraphs py-2 text-sm'>Your conversations</div>
      <div className='relative box-border flex max-h-[calc(100%-100px)] w-full flex-col justify-between gap-4 rounded-md border border-solid border-gray-300 p-3'>
        <div className='flex h-full w-full flex-col gap-2 overflow-auto'>
          {/* Rows are keyed by prompt id (not array index) so per-row state follows the prompt after a delete. */}
          {normalizedPrompts.map((prompt) => (
            <div key={prompt.id}>
              {editPromptInfo[prompt.id] ? (
                <PromptEdit
                  stepRunId={stepRun.id}
                  parentBaseRunId={baseRun.id}
                  baseRunVariables={(promptBaseRunVariables ?? []).filter(
                    (b) => b.baseRunId === prompt.id
                  )}
                  editMode
                  prompt={prompt}
                  wacConfig={config}
                  handleClose={() =>
                    setEditPromptInfo({
                      [prompt.id]: !editPromptInfo[prompt.id],
                    })
                  }
                />
              ) : (
                <div className='flex items-center'>
                  <div className='relative   w-fit max-w-[50%] rounded-xl bg-indigo-100 py-2 px-4 text-sm shadow'>
                    <span>{prompt.text}</span>{' '}
                    {!prompt.acceptedResponse &&
                      normalizedActivePromptResponses.flattened.length === 0 && (
                        <span className='flex'>
                          <Button
                            icon='RedoIcon'
                            size='md'
                            variant='subtle'
                            shape='square'
                            color='theme'
                            disabled={isGenerating}
                            onClick={() => resubmitPrompt(prompt)}
                          />
                        </span>
                      )}
                  </div>
                  {prompt.createdAt > firstManualStepRunCreatedAt && (
                    <>
                      <span
                        className='ml-2 cursor-pointer'
                        onClick={() =>
                          setEditPromptInfo({
                            [prompt.id]: !editPromptInfo[prompt.id],
                          })
                        }>
                        Edit
                      </span>
                      <span
                        className='ml-2 cursor-pointer'
                        onClick={() => handleDeletePrompt(prompt)}>
                        {deletingBaseRuns ? 'Deleting..' : 'Delete'}
                      </span>
                    </>
                  )}
                </div>
              )}

              {prompt.acceptedResponse ? (
                <div className='flex items-center'>
                  {editLastResponse[prompt.responseId] ? (
                    <div className='mt-3 flex w-[600px] flex-col gap-2'>
                      {Object.keys(normalizedActivePromptResponses.grouped).map((model) => (
                        <div key={model}>
                          <div className='mt-3 flex w-[600px] flex-col gap-2'>
                            {(normalizedActivePromptResponses.grouped[model] ?? []).map(
                              (response, index) => (
                                <PromptResponse
                                  response={{
                                    ...response,
                                    index,
                                    text:
                                      response.id === prompt.responseId
                                        ? prompt.acceptedResponse
                                        : response.text,
                                  }}
                                  activePromptId={prompt.id}
                                  editMode={response.id === prompt.responseId}
                                  key={model + response.id}
                                  wacConfig={config}
                                  baseRunVariables={(responsesBaseRunVariables ?? []).filter(
                                    (b) => b.baseRunId === response.id
                                  )}
                                  stepRunId={stepRun.id}
                                  handleClose={() =>
                                    setEditLastResponse({
                                      [prompt.responseId]: !editLastResponse[prompt.responseId],
                                    })
                                  }
                                />
                              )
                            )}
                          </div>
                        </div>
                      ))}
                    </div>
                  ) : (
                    <div className='max-w-1/2 relative ml-auto mt-2 w-fit max-w-[50%] whitespace-pre-line rounded-xl bg-pink-100 py-2 px-4 text-sm shadow'>
                      {prompt.acceptedResponse.trim()}{' '}
                    </div>
                  )}
                  {!editLastResponse[prompt.responseId] &&
                  prompt.createdAt > firstManualStepRunCreatedAt ? (
                    <span
                      className='ml-2 cursor-pointer'
                      onClick={() =>
                        setEditLastResponse({
                          [prompt.responseId]: !editLastResponse[prompt.responseId],
                        })
                      }>
                      Edit
                    </span>
                  ) : null}
                </div>
              ) : null}
            </div>
          ))}

          {unsubmittedPrompt &&
          !isEditingAnyResponse &&
          normalizedActivePromptResponses.flattened.length > 0 ? (
            <div className='mt-4 flex flex-wrap gap-5 gap-y-10'>
              {Object.keys(normalizedActivePromptResponses.grouped).map((model) => (
                <div key={model}>
                  <div className='mt-3 flex w-[600px] flex-col gap-2'>
                    {(normalizedActivePromptResponses.grouped[model] ?? []).map(
                      (response, index) => (
                        <PromptResponse
                          response={{ ...response, index }}
                          activePromptId={unsubmittedPrompt.id}
                          key={model + response.id}
                          wacConfig={config}
                          stepRunId={stepRun.id}
                          baseRunVariables={(responsesBaseRunVariables ?? []).filter(
                            (b) => b.baseRunId === response.id
                          )}
                        />
                      )
                    )}
                  </div>
                </div>
              ))}
            </div>
          ) : null}

          {isGenerating ? (
            <div className='text-paragraphs bg-weak-3 border-main mx-auto mt-4 w-fit rounded-md border border-solid p-2 text-xs'>
              Waiting for response...
            </div>
          ) : null}
        </div>
        {isLastResponseEmpty ? (
          <div className='text-red-strong text-paragraphs py-2'>
            Blank response from Model. You can refresh the page & retry or End Conversation.
          </div>
        ) : null}
        <div className='flex w-full items-center gap-3'>
          <TextArea
            placeholder='Enter your prompt here...'
            value={promptText}
            onChange={(e) => setPromptText(e.target.value)}
            className='!w-[900px] resize-none !rounded sm:!w-[200px] md:!w-[500px]'
            disabled={Boolean(unsubmittedPrompt) || isGenerating}
          />
          <Button
            icon='RocketFilledIcon'
            size='md'
            variant='primary'
            shape='square'
            onClick={handlePromptSubmission}
            disabled={isGenerating || !promptText || Boolean(unsubmittedPrompt)}
          />
        </div>
      </div>
    </div>
  )
}

// The component is exposed as a named export.
export { CohereOperateWAC }
