import type { ArrayElement } from '@lib/types/arrayElement'
import type { TrialAssociation } from '@modules/trials/types/TrialAssociation'
import type { ArmGroup, ArmGroupType, Intervention } from '@prisma/client'

/**
 * A single displayable treatment entry.
 *
 * In this file items are built either from an intervention (all fields set)
 * or from an arm group standing in for a treatment (only `description` and
 * `name` set).
 */
export type TreatmentItem = {
  /** Free-text description; source `null` values are normalised to `undefined`. */
  description?: string
  /** The originating intervention record, present only for intervention-based items. */
  intervention?: Intervention
  /** Display name: the intervention's `name` or the arm group's `label`. */
  name: string
  /** Populated from the intervention's `type` for intervention-based items. */
  tag?: string
}

/**
 * An arm group together with the treatment items shown under it.
 *
 * `armGroupToTreatmentGroup` in this file returns an object with these same
 * four fields, although it does not annotate its return value with this type.
 */
export type TreatmentGroup = {
  /** Arm group description; kept as `null` when absent (not normalised). */
  description: string | null
  /** Arm group label. */
  name: string
  /** Treatment items attached to the arm group; may be empty. */
  treatments: TreatmentItem[]
  /** Arm group type as declared by the Prisma `ArmGroupType`. */
  type: ArmGroupType
}

/**
 * The minimal slice of an `ArmGroup` that the helpers in this file read:
 * its description, label and type.
 */
type ArmGroupNarrow = Pick<ArmGroup, 'description' | 'label' | 'type'>

/**
 * Represents an arm group as a treatment item in its own right.
 *
 * @param armGroup - Arm group whose label and description are copied over.
 * @returns An item named after the arm group's label. The `description` key
 *   is always present; a `null` description becomes `undefined`.
 */
function armGroupToTreatmentItem(armGroup: ArmGroupNarrow): TreatmentItem {
  const { description, label } = armGroup

  return { description: description ?? undefined, name: label }
}

/**
 * Converts one of a trial's interventions into a treatment item.
 *
 * @param intervention - Intervention record taken from a trial association.
 * @returns An item carrying the intervention's name, its `type` as `tag`,
 *   its description (`null` becomes `undefined`) and a reference to the
 *   original intervention object.
 */
export function interventionToTreatmentItem(
  intervention: ArrayElement<TrialAssociation['interventions']>,
): TreatmentItem {
  const { description, name, type } = intervention

  return {
    description: description ?? undefined,
    intervention,
    name,
    tag: type,
  }
}

/**
 * Tells whether an arm group is treated as experimental by this module.
 *
 * @param armGroup - Arm group to classify.
 * @returns `true` when the type is `'Experimental'` or `'Other'`, otherwise `false`.
 */
export function isExperimentalArmGroup(armGroup: ArmGroupNarrow) {
  switch (armGroup.type) {
    case 'Experimental':
    case 'Other':
      return true
    default:
      return false
  }
}

/**
 * Tells whether an arm group is treated as a control arm by this module.
 *
 * @param armGroup - Arm group to classify.
 * @returns `true` for the comparator types (`'PlaceboComparator'`,
 *   `'ShamComparator'`, `'ActiveComparator'`) and for `'NoIntervention'`;
 *   `false` for anything else.
 */
export function isControlArmGroup(armGroup: ArmGroupNarrow) {
  switch (armGroup.type) {
    case 'PlaceboComparator':
    case 'ShamComparator':
    case 'ActiveComparator':
    case 'NoIntervention':
      return true
    default:
      return false
  }
}

/**
 * Picks the trial's primary treatment and returns it as a treatment item.
 *
 * Interventions take precedence: the first intervention with a truthy
 * `isPrimaryTreatment` wins. Only when none is flagged does the lookup fall
 * back to the first arm group with a truthy `isPrimaryTreatment`.
 *
 * NOTE(review): the function name suggests the flag is set by an NLP
 * classifier; that is not visible from this file.
 *
 * @param trial - Object exposing the trial's `interventions` and `armGroups`.
 * @returns The matching treatment item, or `undefined` when nothing is flagged.
 */
export function getNlpClassifiedPrimaryTreatmentItem(
  trial: Pick<TrialAssociation, 'interventions' | 'armGroups'>,
) {
  const flaggedIntervention = getPrimaryInterventionFromTrial(trial)
  if (flaggedIntervention) {
    return interventionToTreatmentItem(flaggedIntervention)
  }

  const flaggedArmGroup = trial.armGroups.find(
    (candidate) => candidate.isPrimaryTreatment,
  )
  if (!flaggedArmGroup) {
    return undefined
  }

  return armGroupToTreatmentItem(flaggedArmGroup)
}

/**
 * Finds the first intervention of a trial flagged as the primary treatment.
 *
 * @param trial - Object exposing the trial's `interventions`.
 * @returns The first intervention (in array order) whose `isPrimaryTreatment`
 *   is truthy, or `undefined` when there is none.
 */
export function getPrimaryInterventionFromTrial(
  trial: Pick<TrialAssociation, 'interventions'>,
) {
  for (const candidate of trial.interventions) {
    if (candidate.isPrimaryTreatment) {
      return candidate
    }
  }

  return undefined
}

/**
 * Reports whether an intervention is linked to the arm group with the given
 * label.
 *
 * The check walks the intervention's `armGroupsJoin` entries and compares
 * each joined arm group's `label` with `armGroupLabel` using strict equality.
 *
 * @param intervention - Intervention whose arm-group links are inspected.
 * @param armGroupLabel - Label of the arm group to look for.
 * @returns `true` when at least one linked arm group carries that label,
 *   `false` otherwise (including when there are no links at all).
 */
const isInterventionInArmGroup = (
  intervention: ArrayElement<TrialAssociation['interventions']>,
  armGroupLabel: string,
): boolean => {
  // `some` (rather than `find`) so this `is…` predicate yields a strict
  // boolean instead of exposing the matched join row or `undefined`.
  return intervention.armGroupsJoin.some(
    (interventionGroup) => interventionGroup.armGroup.label === armGroupLabel,
  )
}

/**
 * Builds the treatment group for one arm group.
 *
 * The group's treatments are the interventions linked to the arm group
 * (matched through the arm group's label). When no intervention is linked:
 * - an experimental arm group (see `isExperimentalArmGroup`) gets an empty
 *   treatment list;
 * - any other arm group is listed as its own single treatment, so that it
 *   is still shown.
 *
 * @param armGroup - Arm group to convert.
 * @param interventions - All interventions of the trial; only those linked
 *   to `armGroup` are used.
 * @returns An object with the arm group's `description`, its label as
 *   `name`, the resolved `treatments` and the arm group's `type`.
 */
export function armGroupToTreatmentGroup({
  armGroup,
  interventions,
}: {
  armGroup: ArmGroupNarrow
  interventions: TrialAssociation['interventions']
}) {
  const { description, label, type } = armGroup

  const linkedTreatments = interventions
    .filter((candidate) => isInterventionInArmGroup(candidate, label))
    .map((candidate) => interventionToTreatmentItem(candidate))

  let treatments: TreatmentItem[] = linkedTreatments
  if (linkedTreatments.length === 0) {
    // Nothing linked: experimental arms stay empty, every other arm falls
    // back to representing itself as its only treatment.
    treatments = isExperimentalArmGroup(armGroup)
      ? []
      : [armGroupToTreatmentItem(armGroup)]
  }

  return {
    description,
    name: label,
    treatments,
    type,
  }
}

/**
 * Collects the treatment groups for every experimental arm group of a trial.
 *
 * @param trial - Trial whose `armGroups` and `interventions` are read.
 * @returns One treatment group per arm group accepted by
 *   `isExperimentalArmGroup`, in the original arm-group order.
 */
export function getExperimentalTreatmentGroups(trial: TrialAssociation) {
  return trial.armGroups
    .filter(isExperimentalArmGroup)
    .map((armGroup) =>
      armGroupToTreatmentGroup({ armGroup, interventions: trial.interventions }),
    )
}

/**
 * Collects the treatment groups for every control arm group of a trial.
 *
 * @param trial - Trial whose `armGroups` and `interventions` are read.
 * @returns One treatment group per arm group accepted by
 *   `isControlArmGroup`, in the original arm-group order.
 */
export function getControlTreatmentGroups({
  trial,
}: {
  trial: TrialAssociation
}) {
  return trial.armGroups
    .filter(isControlArmGroup)
    .map((armGroup) =>
      armGroupToTreatmentGroup({ armGroup, interventions: trial.interventions }),
    )
}

/**
 * Splits a trial's arm groups into experimental and control treatment groups.
 *
 * @param trial - Trial to parse.
 * @returns An object with `controlTreatments` and `experimentalTreatments`,
 *   each an array of treatment groups.
 */
export default function parseTreatments(trial: TrialAssociation) {
  const experimentalTreatments = getExperimentalTreatmentGroups(trial)

  return {
    controlTreatments: getControlTreatmentGroups({ trial }),
    experimentalTreatments,
  }
}
