import {z} from 'zod';
import {ScoringMode, SignalsWithWeights} from './market';
import {findSignalGroup} from './signalGroups';
import {
  ScoringSignal,
  isKeyplayScoringSignal,
  ScoringSignalSchema,
} from './signals';
import {assertNever} from './util';
import {AccountTierSchema} from './scoredAccounts';
import {SignalDefinition} from './signalDefinition';

/**
 * Schema for the per-account rows of a model test: each row carries the
 * account `domain`, a list of `groups` (strings), the account `tier`, and a
 * numeric `score`.
 */
export const ModelTestResultsSchema = z.array(
  z.object({
    domain: z.string(),
    groups: z.array(z.string()),
    tier: AccountTierSchema,
    score: z.number(),
  }),
);
export type ModelTestResults = z.infer<typeof ModelTestResultsSchema>;

// https://github.com/colinhacks/zod/issues/2623
// NOTE(review): the linked zod issue presumably explains why the tier keys are
// spelled out one by one here instead of being derived from a schema (e.g. an
// enum-keyed z.record) — confirm and expand this comment.
/** Schema for a numeric count under each of the tier keys `A`, `B`, `C`, `D`. */
export const TierCountsSchema = z.object({
  A: z.number(),
  B: z.number(),
  C: z.number(),
  D: z.number(),
});
export type TierCounts = z.infer<typeof TierCountsSchema>;

/**
 * Schema for per-group model-test rows: each row pairs a `group` name with
 * the `tierCounts` recorded for it.
 */
export const ModelTestGroupResultsSchema = z.array(
  z.object({
    group: z.string(),
    tierCounts: TierCountsSchema,
  }),
);
export type ModelTestGroupResults = z.infer<typeof ModelTestGroupResultsSchema>;

/**
 * Schema for per-signal model-test rows: each row pairs a scoring `signal`
 * with `groupCounts`, a string-keyed map of numbers (presumably keyed by group
 * name — confirm against the producer).
 */
export const ModelTestSignalCountsSchema = z.array(
  z.object({
    signal: ScoringSignalSchema,
    groupCounts: z.record(z.string(), z.number()),
  }),
);

export type ModelTestSignalCounts = z.infer<typeof ModelTestSignalCountsSchema>;

/** Ceiling applied to bonus-adjusted fit scores. */
const MaxScore = 100;

/**
 * Blends an account's signal score and similarity score into a single fit
 * value, chosen by `scoringMode.type`:
 * - `highest_value`: the larger of the two scores.
 * - `only_signal_score` / `only_similarity`: that score alone.
 * - `signal_score_plus_bonus_points` / `similarity_plus_bonus_points`: the
 *   primary score plus 20 bonus points when the other score is at least 65,
 *   capped at `MaxScore`.
 *
 * Only the bonus modes are capped; the other modes return the inputs as-is.
 */
export function getOverallFit({
  signalScore,
  similarityScore,
  scoringMode,
}: {
  signalScore: number;
  similarityScore: number;
  scoringMode: ScoringMode;
}) {
  // A secondary score at or above the threshold earns a flat bonus.
  const bonusThreshold = 65;
  const bonusPoints = 20;
  const withBonus = (primary: number, secondary: number) =>
    Math.min(
      MaxScore,
      primary + (secondary >= bonusThreshold ? bonusPoints : 0),
    );

  const {type} = scoringMode;
  switch (type) {
    case 'highest_value':
      return Math.max(signalScore, similarityScore);
    case 'only_signal_score':
      return signalScore;
    case 'only_similarity':
      return similarityScore;
    case 'signal_score_plus_bonus_points':
      return withBonus(signalScore, similarityScore);
    case 'similarity_plus_bonus_points':
      return withBonus(similarityScore, signalScore);
    default:
      assertNever(type);
  }
}

/**
 * Computes the maximum number of points an account could earn under the given
 * signal weights. `getAccountScore` divides by this value to normalize scores.
 *
 * Rules:
 * - Zero and negative weights never contribute.
 * - A non-Keyplay signal whose definition is `bonusPointsOnly` is skipped,
 *   unless it belongs to a signal group (the group takes priority).
 * - For grouped signals, only the highest weight in each group is counted.
 *
 * `signalsWithWeights` is not mutated; a sorted copy is used internally.
 *
 * @returns The sum of the counted weights (0 when nothing qualifies).
 */
export function getScoringModelMaxPoints({
  signalDefinitions,
  signalsWithWeights,
}: {
  signalsWithWeights: SignalsWithWeights;
  signalDefinitions: SignalDefinition[];
}): number {
  let maxPoints = 0;
  const seenGroups = new Set<string>();
  // Sort a copy: Array.prototype.sort works in place, and reordering the
  // caller's array is an unexpected side effect. Descending order guarantees
  // the first member seen for a group is its highest-weighted one.
  const sortedSignals = [...signalsWithWeights].sort(
    (a, b) => b.weight - a.weight,
  );

  for (const {signal, weight} of sortedSignals) {
    // zero/negative weights don't impact max score
    if (weight <= 0) {
      continue;
    }

    const signalGroup = findSignalGroup({signal, signalDefinitions});

    if (!isKeyplayScoringSignal(signal)) {
      const signalDefinition = signalDefinitions.find((s) => s.id === signal);

      // there's an edge case where a custom signal can be bonus points only,
      // and also be grouped -- if that happens the group takes priority
      if (signalDefinition?.bonusPointsOnly && !signalGroup) {
        continue;
      }
    }

    // only count the highest weight of each group
    if (signalGroup) {
      if (seenGroups.has(signalGroup)) {
        continue;
      }
      seenGroups.add(signalGroup);
    }

    maxPoints += weight;
  }

  return maxPoints;
}

/**
 * Scores an account, given the signals it has, against a model's signal
 * weights.
 *
 * Model signals are visited from highest to lowest weight.
 * - Every model signal the account has is added to `matchedSignals`.
 * - Signals with zero/positive weight that share a signal group are merged
 *   into one `scoreBreakdown` entry. Only the first (highest) weight in that
 *   group is added to the points.
 * - Negative-weight signals are never grouped, and each one always counts.
 *
 * @returns
 * - `matchedSignals`: the matched signals, in descending weight order.
 * - `scoreBreakdown`: one entry per ungrouped signal or per group. Each entry's
 *   `weight` is its share of the model's max points, as a percentage (0 when
 *   the model has no max points).
 * - `rawTotalScore`: points / max points * 100, unclamped. Signals excluded
 *   from the max points (e.g. ungrouped bonus-only signals) can push it past
 *   100.
 * - `totalScore`: `rawTotalScore` clamped to [0, 100].
 *
 * Both totals are 0 when the model's max points is not positive. The input
 * `signalsWithWeights` array is not mutated.
 */
export function getAccountScore({
  signalsWithWeights,
  signals,
  signalDefinitions,
}: {
  signalsWithWeights: SignalsWithWeights;
  signals: ScoringSignal[];
  signalDefinitions: SignalDefinition[];
}) {
  const scoreBreakdown: {
    signals: ScoringSignal[];
    weight: number;
    group: string | null;
  }[] = [];
  // Breakdown entries indexed by group, so grouped signals merge in O(1).
  const breakdownByGroup = new Map<string, (typeof scoreBreakdown)[number]>();
  const matchedSignals: ScoringSignal[] = [];
  // Set#has uses the same SameValueZero equality as Array#includes.
  const accountSignals = new Set(signals);
  const maxPoints = getScoringModelMaxPoints({
    signalsWithWeights,
    signalDefinitions,
  });

  let points = 0;
  // Sort a copy so the caller's array is left in its original order.
  const sortedSignals = [...signalsWithWeights].sort(
    (a, b) => b.weight - a.weight,
  );
  for (const {signal, weight} of sortedSignals) {
    if (!accountSignals.has(signal)) {
      continue;
    }

    matchedSignals.push(signal);

    // for zero/positive signals, group signals together and only count the highest
    // weight of each group
    const signalGroup =
      weight >= 0 ? findSignalGroup({signal, signalDefinitions}) : null;

    if (signalGroup) {
      const existingGroup = breakdownByGroup.get(signalGroup);
      if (existingGroup) {
        existingGroup.signals.push(signal);
        continue;
      }
    }

    const entry: (typeof scoreBreakdown)[number] = {
      signals: [signal],
      // Avoid NaN/±Infinity from dividing by a zero max-points total.
      weight: maxPoints > 0 ? (weight / maxPoints) * 100 : 0,
      group: signalGroup,
    };
    scoreBreakdown.push(entry);
    if (signalGroup) {
      breakdownByGroup.set(signalGroup, entry);
    }

    points += weight;
  }

  if (maxPoints <= 0) {
    return {matchedSignals, scoreBreakdown, totalScore: 0, rawTotalScore: 0};
  }

  const rawTotalScore = (points / maxPoints) * 100;
  const totalScore = Math.max(0, Math.min(100, rawTotalScore));

  return {matchedSignals, scoreBreakdown, totalScore, rawTotalScore};
}
