import { useWizardState } from '@invisible/common/components/providers/active-wizard-provider'
import { useContext, useMutation, useQuery } from '@invisible/trpc/client'
import { Button as IconButton } from '@invisible/ui/button'
import { theme } from '@invisible/ui/mui-theme-v2'
import { inferQueryOutput } from '@invisible/ultron/trpc/server'
import { Wizard as WizardSchemas } from '@invisible/ultron/zod'
import AccountCircleIcon from '@mui/icons-material/AccountCircle'
import DeleteIcon from '@mui/icons-material/Delete'
import SmartToyOutlinedIcon from '@mui/icons-material/SmartToyOutlined'
import Button from '@mui/material/Button'
import CircularProgress from '@mui/material/CircularProgress'
import { ThemeProvider } from '@mui/material/styles'
import TextField from '@mui/material/TextField'
import { flatten, groupBy } from 'lodash/fp'
import pMap from 'p-map'
import pTimes from 'p-times'
import { useEffect, useMemo, useState } from 'react'
import Markdown from 'react-markdown'

import { useBaseRunDeleteWithStepRunReference } from '../../hooks/useBaseRunDeleteWithStepRunReference'
import { TBaseRunQueryData } from '../../hooks/useGetBaseRuns'
import { MouseOverPopover } from './MouseOverPopover'

/** One base run row as returned by the base-runs query hook. */
type TBaseRun = TBaseRunQueryData['items'][number]
/** One step run attached to a {@link TBaseRun}. */
type TStepRun = TBaseRun['stepRuns'][number]
/** Payload of the `baseRun.findChildBaseRuns` tRPC query with `null`/`undefined` removed. */
type TFindChildBaseRunsData = Exclude<
  inferQueryOutput<'baseRun.findChildBaseRuns'>,
  null | undefined
>

/** Props of the RAG RLHF wizard action card: the shared WAC config plus the RagRLHF settings. */
interface IProps extends WizardSchemas.WACConfig.TSchema {
  /** When true, every control that would create or change data is disabled. */
  isReadOnly: boolean
  /** RagRLHF configuration: base ids, base variable ids and the models to query. */
  ragRLHF: WizardSchemas.RagRLHF.TSchema
  stepRun: TStepRun
  /** Parent base run under which the conversation's prompts are stored. */
  baseRun: TBaseRun
}

/**
 * Returns the value stored on `run` for the base variable `baseVariableId`, or `undefined`
 * when the run has no such variable. The result is cast to `T` without any runtime check,
 * so callers must know which type the variable holds.
 */
function getBaseRunVariableValue<T>(
  run: { baseRunVariables: ReadonlyArray<{ baseVariable: { id: string }; value: unknown }> },
  baseVariableId: string | null | undefined
): T {
  return run.baseRunVariables.find((variable) => variable.baseVariable.id === baseVariableId)
    ?.value as T
}

/**
 * Wizard action card for RAG-based RLHF conversations.
 *
 * Prompts are child base runs of `baseRun` in `ragRLHF.promptsBaseId`. The replies for the prompt
 * that has no accepted response yet are child base runs of that prompt in `ragRLHF.responsesBaseId`.
 * Sending a prompt queries every model in `ragRLHF.models` through `/api/wacs/rlhf/query-model`,
 * stores the first reply as the prompt's accepted response, and stores all replies as response
 * base runs. The wizard's `RagRLHF` ready-for-submit flag follows whether any prompt still lacks
 * an accepted response.
 */
const RagRLHFWAC = ({ baseRun, ragRLHF: config, stepRun, isReadOnly }: IProps) => {
  const reactQueryContext = useContext()
  const [promptText, setPromptText] = useState('')
  const [isLastResponseEmpty, setIsLastResponseEmpty] = useState(false)
  const [getResponseLoading, setGetResponseLoading] = useState(false)

  const { dispatch } = useWizardState()

  const { data: prompts } = useQuery([
    'baseRun.findChildBaseRuns',
    {
      baseId: config?.promptsBaseId as string,
      parentBaseRunId: baseRun.id,
    },
  ])

  // Parses prompt base runs into typed objects built from their base run variables, ordered by prompt index.
  const normalizedPrompts = useMemo(
    () =>
      (prompts ?? [])
        .map((prompt) => ({
          id: prompt.id,
          text: getBaseRunVariableValue<string>(prompt, config?.promptTextBaseVariableId),
          index: getBaseRunVariableValue<number>(prompt, config?.promptIndexBaseVariableId),
          acceptedResponse: getBaseRunVariableValue<string>(
            prompt,
            config?.promptResponseBaseVariableId
          ),
          acceptedModel: getBaseRunVariableValue<string>(
            prompt,
            config?.promptAcceptedModelBaseVariableId
          ),
          responseId: getBaseRunVariableValue<string>(prompt, config?.responseIdBaseVariableId),
          context: getBaseRunVariableValue<{ id: string; citation: string; chunk: string }[]>(
            prompt,
            config?.promptResponseContextBaseVariableId
          ),
          createdAt: prompt.createdAt,
        }))
        .sort((a, b) => a.index - b.index),
    [prompts, config]
  )

  // First prompt (by index) without an accepted response; while one exists, new prompts are blocked.
  const unsubmittedPrompt = normalizedPrompts.find((prompt) => !prompt.acceptedResponse)

  // Replies already stored for the unsubmitted prompt, if any.
  const { data: activePromptResponses } = useQuery(
    [
      'baseRun.findChildBaseRuns',
      {
        baseId: config?.responsesBaseId as string,
        parentBaseRunId: unsubmittedPrompt?.id as string,
      },
    ],
    {
      enabled: !!unsubmittedPrompt,
    }
  )

  // If there are no unsubmitted prompts, set RagRLHF's readyForSubmit to true so the Wizard Submit button can be activated.
  useEffect(() => {
    dispatch({
      type: 'setReadyForSubmit',
      key: 'RagRLHF',
      value: !unsubmittedPrompt,
    })
  }, [unsubmittedPrompt])

  const normalizedActivePromptResponses = useMemo(() => {
    const sortedResponses = (activePromptResponses ?? [])
      .map((response) => ({
        id: response.id,
        text: getBaseRunVariableValue<string>(response, config?.responseTextBaseVariableId),
        originalText: getBaseRunVariableValue<string>(
          response,
          config?.responseOriginalTextBaseVariableId
        ),
        // Responses are written with their position under responseIndexBaseVariableId (see runPrompt).
        index: getBaseRunVariableValue<number>(response, config?.responseIndexBaseVariableId),
        model: getBaseRunVariableValue<string>(response, config?.responseModelBaseVariableId),
      }))
      .filter((response) => response.text)
      .sort((a, b) => a.index - b.index)

    return { flattened: sortedResponses, grouped: groupBy('model', sortedResponses) }
  }, [activePromptResponses, config])

  const { mutateAsync: createBaseRun, isLoading: createBaseRunLoading } =
    useMutation('baseRun.create')

  const { mutateAsync: createBaseRuns, isLoading: createBaseRunsLoading } = useMutation(
    'baseRun.createMany',
    {
      onSettled: () => {
        reactQueryContext.invalidateQueries('baseRun.findChildBaseRuns')
      },
    }
  )

  const { mutateAsync: deleteBaseRuns, isLoading: deletingBaseRuns } =
    useBaseRunDeleteWithStepRunReference({
      onSettled: () => {
        reactQueryContext.invalidateQueries('baseRun.findChildBaseRuns')
      },
    })

  /**
   * POSTs one query to `/api/wacs/rlhf/query-model` `perModelCallCount` times (at least once) and
   * concatenates the results. Each call is expected to return an array of replies.
   * @throws Error when the endpoint answers with a non-2xx status.
   */
  const queryModel = async (input: {
    query: string
    model: string
    chatHistory: { id: string; text: string; acceptedResponse: string }[]
    perModelCallCount?: number
    meta?: Record<string, any>
  }) => {
    const count = input?.perModelCallCount ?? 1
    const results = await pTimes(count < 1 ? 1 : count, async () => {
      const res = await fetch('/api/wacs/rlhf/query-model', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          chatHistory: input.chatHistory,
          query: input.query,
          model: input.model,
          meta: input.meta,
        }),
      })
      // Fail with a clear message instead of trying to spread an error body as a reply list.
      if (!res.ok) throw new Error(`Model query failed with status ${res.status}`)
      return res.json()
    })
    return results.reduce((acc, curr) => [...acc, ...curr], [])
  }

  /**
   * Sends `query` with `chatHistory` to every model in `config.models` (via `pMap`) and returns all
   * replies as one flat list. Each entry holds the reply text (`message`), the model name, its
   * 1-based position among that model's replies (`index`) and the meta fields configured in
   * `responseMetaParams`. `getResponseLoading` is cleared even if a query throws.
   */
  const queryAllModels = async (
    query: string,
    chatHistory: { id: string; text: string; acceptedResponse: string }[]
  ) => {
    setGetResponseLoading(true)
    try {
      const modelResponses = await pMap(config.models ?? [], async (model) => {
        const metaParams = model.responseMetaParams ?? []
        const responses = await queryModel({
          query,
          model: model.name,
          chatHistory,
          perModelCallCount: model?.numOfCalls,
          meta: model.params.reduce((acc, param) => ({ ...acc, [param.name]: param.value }), {}),
        })
        return responses.map(
          (response: { text: string; model?: string; meta?: any }, index: number) => ({
            message: response.text,
            model: model.name,
            index: index + 1,
            metaFields: metaParams
              .map((p) => ({
                value: response.meta?.[p.key],
                baseVariableId: p.baseVariableId,
              }))
              .filter((p) => p.baseVariableId),
          })
        )
      })
      return flatten(modelResponses)
    } finally {
      setGetResponseLoading(false)
    }
  }

  /**
   * Shared flow behind "Send" and "Redo":
   * 1. flags the step ready for submit and hides the blank-response banner;
   * 2. queries every configured model;
   * 3. creates a prompt base run under `baseRun` holding `promptValues` plus the first reply as the
   *    accepted response (with its model and the context meta field);
   * 4. if at least one reply came back and every reply has text, stores all replies as response
   *    base runs under `responsesParentId` (default: the prompt created in step 3). Otherwise it
   *    shows the blank-response banner.
   */
  const runPrompt = async ({
    query,
    chatHistory,
    promptValues,
    responsesParentId,
  }: {
    query: string
    chatHistory: { id: string; text: string; acceptedResponse: string }[]
    promptValues: { baseVariableId: string; value: string | number }[]
    responsesParentId?: string
  }) => {
    // Re-evaluated by the effect above whenever `unsubmittedPrompt` changes.
    dispatch({
      type: 'setReadyForSubmit',
      key: 'RagRLHF',
      value: true,
    })
    setIsLastResponseEmpty(false)

    const responses = await queryAllModels(query, chatHistory)
    // May be undefined when no replies came back (e.g. no models configured).
    const firstResponse = responses[0]

    const prompt = await createBaseRun({
      baseId: config?.promptsBaseId as string,
      parentBaseRunId: baseRun.id,
      initialValues: [
        ...promptValues,
        {
          baseVariableId: config?.promptResponseBaseVariableId as string,
          value: firstResponse?.message,
        },
        {
          baseVariableId: config?.promptAcceptedModelBaseVariableId as string,
          value: firstResponse?.model,
        },
        {
          // NOTE(review): replies built in queryAllModels carry no `id`, so this value is always
          // undefined — confirm what should be stored under responseIdBaseVariableId.
          baseVariableId: config?.responseIdBaseVariableId as string,
          value: firstResponse?.id,
        },
        ...(firstResponse?.metaFields ?? []).filter(
          (p: { baseVariableId: string; value: string }) =>
            p.baseVariableId === config?.promptResponseContextBaseVariableId
        ),
      ],
    })

    const everyResponseHasText =
      responses.length > 0 && responses.every((response) => response.message)

    if (everyResponseHasText) {
      await createBaseRuns({
        baseId: config?.responsesBaseId as string,
        parentBaseRunId: responsesParentId ?? prompt.id,
        initialValuesArray: responses.map((response) => [
          {
            baseVariableId: config?.responseTextBaseVariableId as string,
            value: response.message,
          },
          {
            baseVariableId: config?.responseOriginalTextBaseVariableId as string,
            value: response.message,
          },
          {
            baseVariableId: config?.responseIndexBaseVariableId as string,
            value: response.index,
          },
          {
            baseVariableId: config?.responseModelBaseVariableId as string,
            value: response.model,
          },
          ...response.metaFields,
        ]),
      })
    } else {
      setIsLastResponseEmpty(true)
    }
  }

  /**
   * Re-queries the models for a prompt that has no accepted response.
   * NOTE(review): this creates a *new* prompt base run (without text or index) for the accepted
   * reply while attaching the replies to `failedPrompt`. It does not update `failedPrompt` itself,
   * which looks unintended — confirm against the backend API.
   */
  const resubmitPrompt = (failedPrompt: { id: string; text: string }) =>
    runPrompt({
      query: failedPrompt.text,
      // Exclude the failed prompt from the history sent to the models.
      chatHistory: normalizedPrompts.filter((prompt) => prompt.id !== failedPrompt.id),
      promptValues: [],
      responsesParentId: failedPrompt.id,
    })

  /** Sends the text in the input box as the next prompt and clears the input on success. */
  const handlePromptSubmission = async () => {
    await runPrompt({
      query: promptText,
      chatHistory: normalizedPrompts,
      promptValues: [
        {
          baseVariableId: config?.promptTextBaseVariableId as string,
          value: promptText,
        },
        {
          baseVariableId: config?.promptIndexBaseVariableId as string,
          value: (prompts?.length ?? 0) + 1,
        },
      ],
    })

    setPromptText('')
  }

  /**
   * After confirmation, deletes `prompt`, every prompt created after it, and all of their response
   * base runs. It then removes those prompts from the cached prompt list.
   */
  const handleDeletePrompt = async (prompt: { id: string; responseId: string }) => {
    // Deleting mutates data: ignore clicks in read-only mode and while a delete is already running.
    if (isReadOnly || deletingBaseRuns) return
    if (
      !window.confirm(
        "You are about to delete this prompt and it's response. \nThis will also delete subsequent prompts and responses if any. \nAre you sure?"
      )
    )
      return

    const currentPromptCreatedAt = prompts?.find((p) => p.id === prompt.id)?.createdAt ?? ''
    const subsequentPrompts = prompts?.filter(
      (p) => p.createdAt > currentPromptCreatedAt && p.id !== prompt.id
    )
    const promptsResponses = await Promise.all(
      [...(subsequentPrompts ?? []), prompt].map((p) =>
        reactQueryContext.fetchQuery([
          'baseRun.findChildBaseRuns',
          {
            baseId: config?.responsesBaseId as string,
            parentBaseRunId: p.id as string,
          },
        ])
      )
    )
    const baseRunsIdsToDelete = [
      prompt.id,
      ...(subsequentPrompts ?? []).map((p) => p.id),
      ...flatten(promptsResponses).map((r) => r.id),
    ]
    await deleteBaseRuns({
      baseRunIds: baseRunsIdsToDelete,
      stepRunId: stepRun.id,
    })
    reactQueryContext.queryClient.setQueryData<TFindChildBaseRunsData | undefined>(
      [
        'baseRun.findChildBaseRuns',
        {
          baseId: config?.promptsBaseId as string,
          parentBaseRunId: baseRun.id,
        },
      ],
      (prevData) => {
        if (!prevData) return
        return prevData.filter((baseRun) => !baseRunsIdsToDelete.includes(baseRun.id))
      }
    )
  }

  return (
    <div className='box-border h-full w-full bg-white p-2'>
      <ThemeProvider theme={theme}>
        <div className='text-paragraphs py-2 text-sm'>Your conversations</div>
        <div className='relative box-border flex max-h-[calc(100%-100px)] w-full flex-col items-center justify-between gap-4 rounded-md border border-solid border-gray-300 p-3'>
          <div className='flex h-full max-w-[750px] flex-col gap-2 overflow-auto'>
            {normalizedPrompts.map((prompt) => (
              <div key={prompt.id}>
                <div className='flex  items-center gap-3'>
                  <div className='flex min-w-[20px] '>
                    <AccountCircleIcon color='action' fontSize='medium' />
                  </div>
                  <div className={`valid relative w-fit rounded bg-[#E0E7FF] py-1 px-4 text-sm `}>
                    <Markdown
                      className='overflow-auto'
                      components={{
                        p: ({ children }) => <p className='m-0 whitespace-pre-wrap'>{children}</p>,
                      }}>
                      {prompt.text}
                    </Markdown>{' '}
                    {!prompt.acceptedResponse &&
                      !(normalizedActivePromptResponses.flattened.length > 0) && (
                        <span className='flex'>
                          <IconButton
                            icon='RedoIcon'
                            size='md'
                            variant='subtle'
                            shape='square'
                            color='theme'
                            disabled={
                              getResponseLoading ||
                              createBaseRunsLoading ||
                              createBaseRunLoading ||
                              isReadOnly
                            }
                            onClick={() => resubmitPrompt(prompt)}
                          />
                        </span>
                      )}
                  </div>
                  {config.allowDeletePrompt && (
                    <span
                      className='ml-2 cursor-pointer'
                      onClick={() => handleDeletePrompt(prompt)}>
                      {deletingBaseRuns ? (
                        <CircularProgress key={prompt.id} />
                      ) : (
                        <DeleteIcon color='action' />
                      )}
                    </span>
                  )}
                </div>

                {prompt.acceptedResponse ? (
                  <div className='flex items-center gap-3'>
                    <div className='flex min-w-[20px]'>
                      <SmartToyOutlinedIcon color='action' fontSize='medium' />
                    </div>
                    <div
                      className={`max-w-1/2 relative  mt-4 mb-2 w-fit  overflow-auto rounded bg-[#F5F5F7] py-2 px-4 text-sm`}>
                      <Markdown
                        className='overflow-auto'
                        components={{
                          p: ({ children }) => (
                            <p className='m-0 whitespace-pre-wrap'>{children}</p>
                          ),
                        }}>
                        {prompt.acceptedResponse.trim() as string}
                      </Markdown>{' '}
                      <div className='flex gap-1'>
                        {(prompt.context ?? []).map((c, i: number) => (
                          <MouseOverPopover key={c.id} title={`[${i + 1}]`}>
                            <div className='mb-2'>{c.citation}</div>
                            <Markdown
                              className='overflow-auto'
                              components={{
                                p: ({ children }) => (
                                  <p className='m-0 whitespace-pre-wrap'>{children}</p>
                                ),
                              }}>
                              {c.chunk}
                            </Markdown>{' '}
                          </MouseOverPopover>
                        ))}
                      </div>
                    </div>
                  </div>
                ) : null}
              </div>
            ))}

            {getResponseLoading || createBaseRunLoading || createBaseRunsLoading ? (
              <div className='text-paragraphs bg-weak-3 border-main mx-auto mt-4 w-fit rounded-md border border-solid p-2 text-xs'>
                Waiting for response...
              </div>
            ) : null}
          </div>
          {isLastResponseEmpty ? (
            <div className='text-red-strong py-2'>
              Blank response from Model. You can refresh the page & retry or End Conversation.
            </div>
          ) : null}

          <div className='grid w-full max-w-[750px] grid-cols-[20px_85%_10%] items-center gap-3'>
            <AccountCircleIcon color='action' fontSize='medium' className='justify-self-center' />
            <TextField
              placeholder='Enter your prompt here...'
              value={promptText}
              onChange={(e) => setPromptText(e.target.value)}
              multiline
              fullWidth
              maxRows={4}
              minRows={1}
              disabled={
                Boolean(unsubmittedPrompt) ||
                createBaseRunLoading ||
                getResponseLoading ||
                createBaseRunsLoading ||
                isReadOnly
              }
            />

            <Button
              variant='contained'
              onClick={handlePromptSubmission}
              disabled={
                createBaseRunLoading ||
                getResponseLoading ||
                createBaseRunsLoading ||
                !promptText ||
                Boolean(unsubmittedPrompt) ||
                isReadOnly
              }>
              Send
            </Button>
          </div>
        </div>
      </ThemeProvider>
    </div>
  )
}

export { RagRLHFWAC }
