import {StoreApi, createStore, useStore} from 'zustand';
import {devtools} from 'zustand/middleware';
import produce from 'immer';
import _ from 'lodash';
import {
  ScoringSignal,
  CoreScoringSignals,
  OptionalScoringSignal,
  CustomSignal,
} from '../../../shared/signals';
import {SignalsWithWeights} from '../../../shared/market';
import {StrictExclude, entries} from '../../../shared/util';
import {track} from '../../../analytics';
import {useMetadata} from '../../../context/MetadataContext';
import {useEffect} from 'react';
import {useMarketFromContext} from '../MarketProvider';

/** The bucket a signal is shown in; 'unassigned' means it is in neither scoring bucket. */
export type BucketType = 'positive' | 'negative' | 'unassigned';
/** Sparse map from scoring signal to its (non-negative, client-side) weight. */
export type SignalWeights = Partial<Record<ScoringSignal, number>>;

/*
 * Manages the client-side state for the drag-and-drop scoring model builder.
 *
 * We have a separate query that polls the server, getting the latest model, and informing our client-side
 * state using updateWithLatestServerModel. This state keeps track of client-side modifications, and
 * the slightly interesting logic it has is that it retains the older scoring model it had from the server
 * at the time those modifications were made. In the code below, rawLatestServerResponse is just the latest
 * response from the server, and serverState is the baseline that we actually use. You can also see functions
 * like hasBucketChanges, hasWeightChanges and usingOutOfDateServerModel that can inform the user of the
 * current state through the UI.
 */
interface SignalState {
  // ---- Data received from the server ----

  /** False until the first server model has been ingested. */
  hasInitialized: boolean;
  /** Ingests a fresh server model along with the customer's optional and custom signals. */
  updateWithLatestServerModel: (
    signalsWithWeights: SignalsWithWeights,
    optionalSignals: OptionalScoringSignal[],
    customSignals: CustomSignal[]
  ) => void;
  /** The most recent model seen from the server, even if client edits are pending. */
  rawLatestServerResponse: SignalsWithWeights | null;

  // ---- Customer-specific signals (null until the first server update) ----

  customSignals: CustomSignal[] | null;
  optionalSignals: OptionalScoringSignal[] | null;

  // ---- Baseline model and the client's edits on top of it ----

  /** The server model the current client edits are based on. */
  serverState: SignalsWithWeights | null;
  /** Bucket per assigned signal; unassigned signals simply have no entry. */
  clientBuckets: Partial<
    Record<ScoringSignal, StrictExclude<BucketType, 'unassigned'>>
  >;
  /** Client-side weight per signal. */
  clientWeights: SignalWeights;

  // ---- Mutations ----

  assignBucket: (signal: ScoringSignal, bucket: BucketType) => void;
  assignWeight: (signal: ScoringSignal, weight: number) => void;

  // ---- Derived views ----

  getEverySignalBucketSorted: () => Record<ScoringSignal, BucketType>;
  getSignalWeights: () => Record<'positive' | 'negative', SignalWeights>;
  getSignalWeight: (signal: ScoringSignal) => number;
  getSignalBucket: (signal: ScoringSignal) => BucketType;
  getServerFormat: () => SignalsWithWeights;

  // ---- Change tracking (drives UI state) ----

  hasWeightChanges: (comparisonState?: SignalWeights) => boolean;
  hasBucketChanges: () => boolean;
  usingOutOfDateServerModel: () => boolean;

  // ---- Persisting / discarding edits ----

  persistAllChanges: (signals: SignalsWithWeights) => void;
  resetBucketChanges: () => void;
  resetWeightChanges: (resetWeight?: SignalWeights) => void;
}

/**
 * Creates a fresh, uninitialized zustand store for one market's scoring model
 * builder. The devtools middleware is only enabled when `import.meta.env.DEV`
 * is truthy.
 */
const createSignalState = () =>
  createStore<SignalState>()(
    devtools<SignalState>(
      (set, get) => {
        // Every signal the builder can display, in display order.
        const getSortedSignals = () => {
          return [
            // Display custom signals above all others
            ...(get().customSignals ?? []),
            ...CoreScoringSignals,
            ...(get().optionalSignals ?? []),
          ];
        };

        return {
          // do we have initial data from the server yet?
          hasInitialized: false,

          // called whenever we get fresh data from the server. We always
          // store that response in rawLatestServerResponse; the state that
          // actually determines what we show to the user is serverState (the
          // baseline) plus clientBuckets / clientWeights (the two respective
          // screens). We only replace those when there are no pending client
          // modifications, to prevent us blowing away their changes.
          updateWithLatestServerModel: (
            signalsWithWeights: SignalsWithWeights,
            optionalSignals: OptionalScoringSignal[],
            customSignals: CustomSignal[]
          ) =>
            set(
              produce<SignalState>((state) => {
                state.rawLatestServerResponse = signalsWithWeights;

                // don't clobber pending changes (get() reads the pre-update state)
                if (!get().hasBucketChanges() && !get().hasWeightChanges()) {
                  state.serverState = signalsWithWeights;
                  const {buckets, weights} =
                    serverToClientAssignmentsFormat(signalsWithWeights);
                  state.clientBuckets = buckets;
                  state.clientWeights = weights;
                }

                state.customSignals = customSignals;
                state.optionalSignals = optionalSignals;
                state.hasInitialized = true;
              })
            ),

          rawLatestServerResponse: null,
          customSignals: null,
          optionalSignals: null,

          serverState: [],
          clientBuckets: {},
          clientWeights: {},
          // sets clientBuckets; 'unassigned' removes the signal's entry.
          // Ignored until the first server model has arrived.
          assignBucket: (signal, bucket) =>
            set(
              produce<SignalState>((state) => {
                if (!state.hasInitialized) {
                  return;
                }
                if (bucket === 'unassigned') {
                  delete state.clientBuckets[signal];
                } else {
                  state.clientBuckets[signal] = bucket;
                }
              })
            ),
          // sets clientWeights. Ignored until the first server model has arrived.
          assignWeight: (signal, weight) =>
            set(
              produce<SignalState>((state) => {
                if (!state.hasInitialized) {
                  return;
                }

                state.clientWeights[signal] = weight;
              })
            ),
          // does the client have any bucket changes? This compares against the
          // buckets derived from serverState, so if a user drags a tag to a
          // bucket and back, clientBuckets equals the server buckets again and
          // this is false.
          hasBucketChanges: () => {
            if (!get().hasInitialized) {
              return false;
            }

            const {buckets: serverBuckets} = serverToClientAssignmentsFormat(
              get().serverState
            );

            return !_.isEqual(get().clientBuckets, serverBuckets);
          },
          // Are the client's effective weights different from a baseline?
          // The baseline is `comparisonState` when given, otherwise the weights
          // derived from serverState. A signal with no explicit weight counts as
          // the default weight of 1 on either side (matching getSignalWeight),
          // so there is a change if any of:
          //  1) the client has a non-default weight the baseline lacks,
          //  2) the client has changed a baseline weight value, OR
          //  3) the baseline has a non-default weight the client lacks (e.g.
          //     after resetWeightChanges was given a sparser weight map).
          hasWeightChanges: (comparisonState) => {
            if (!get().hasInitialized) {
              return false;
            }

            const {weights: serverWeights} = serverToClientAssignmentsFormat(
              get().serverState
            );

            const baselineWeights = comparisonState ?? serverWeights;
            const currentWeights = get().clientWeights;
            const effectiveWeightDiffers = (signal: ScoringSignal) =>
              (currentWeights[signal] ?? 1) !== (baselineWeights[signal] ?? 1);

            // check the union of keys from both sides
            for (const [signal] of entries(currentWeights)) {
              if (effectiveWeightDiffers(signal)) {
                return true;
              }
            }
            for (const [signal] of entries(baselineWeights)) {
              if (effectiveWeightDiffers(signal)) {
                return true;
              }
            }

            return false;
          },

          // powers the UI: every component fetches this and renders its tags, in order.
          // Every displayable signal starts as 'unassigned'; client buckets override.
          getEverySignalBucketSorted: () => {
            // _.mapValues doesn't preserve record typing, explicitly add type
            const assignments: Record<ScoringSignal, BucketType> = _.mapValues(
              _.keyBy(getSortedSignals()),
              () => 'unassigned'
            ) as Record<ScoringSignal, BucketType>;

            for (const [signal, bucket] of entries(get().clientBuckets)) {
              if (bucket) {
                assignments[signal] = bucket;
              }
            }
            return assignments;
          },
          // Splits the weights of bucketed signals (in display order) into the
          // positive and negative buckets; unassigned signals are omitted.
          getSignalWeights: () => {
            const positive: {[signal in ScoringSignal]?: number} = {};
            const negative: {[signal in ScoringSignal]?: number} = {};

            const sortedSignals = getSortedSignals();
            for (const signal of sortedSignals) {
              const bucket = get().clientBuckets[signal];
              if (!bucket) {
                continue;
              }

              const weight = get().getSignalWeight(signal);
              if (bucket === 'positive') {
                positive[signal] = weight;
              } else if (bucket === 'negative') {
                negative[signal] = weight;
              }
            }
            return {positive, negative};
          },
          // client weight for a signal, defaulting to 1
          getSignalWeight(signal: ScoringSignal) {
            return get().clientWeights[signal] ?? 1;
          },
          getSignalBucket(signal: ScoringSignal) {
            return get().clientBuckets[signal] ?? 'unassigned';
          },
          // Converts client buckets + weights back to the server's signed-weight
          // list: positive bucket => +weight, negative bucket => -weight.
          getServerFormat: () => {
            const signalsWithWeights: SignalsWithWeights = [];
            for (const [signal, bucket] of entries(get().clientBuckets)) {
              const multiplier = bucket === 'positive' ? 1 : -1;
              const weight = (get().clientWeights[signal] ?? 1) * multiplier;
              signalsWithWeights.push({signal, weight});
            }

            return signalsWithWeights;
          },

          // this is true when the latest server response differs from the
          // baseline the client's edits are built on (serverState is only
          // refreshed while there are no client changes)
          usingOutOfDateServerModel: () => {
            if (!get().hasInitialized) {
              return false;
            }
            return !_.isEqual(get().rawLatestServerResponse, get().serverState);
          },
          // Adopts `signalsWithWeights` as both the latest response and the new
          // baseline, replacing client buckets and weights with it.
          persistAllChanges: (signalsWithWeights) => {
            set(
              produce<SignalState>((state) => {
                state.rawLatestServerResponse = signalsWithWeights;
                state.serverState = signalsWithWeights;

                const {weights, buckets} =
                  serverToClientAssignmentsFormat(signalsWithWeights);
                state.clientBuckets = buckets;
                state.clientWeights = weights;
              })
            );
          },
          // Discards bucket edits, restoring the buckets from serverState.
          resetBucketChanges: () =>
            set(
              produce<SignalState>((state) => {
                track('scoringModelReset', {properties: {type: 'buckets'}});
                const {buckets} = serverToClientAssignmentsFormat(
                  state.serverState
                );
                state.clientBuckets = buckets;
              })
            ),
          // Discards weight edits, restoring `resetWeight` if given, otherwise
          // the weights from serverState.
          resetWeightChanges: (resetWeight) =>
            set(
              produce<SignalState>((state) => {
                track('scoringModelReset', {properties: {type: 'weights'}});
                const {weights} = serverToClientAssignmentsFormat(
                  state.serverState
                );
                state.clientWeights = resetWeight ?? weights;
              })
            ),
        };
      },
      {
        enabled: import.meta.env.DEV,
      }
    )
  );

const signalStores: Record<string, StoreApi<SignalState>> = {};

// Returns the cached store for `key`, creating and caching one on first use.
const getOrCreateSignalStore = (key: string): StoreApi<SignalState> => {
  const cached = signalStores[key];
  if (cached) {
    return cached;
  }
  const created = createSignalState();
  signalStores[key] = created;
  return created;
};

/**
 * React hook returning the scoring-model signal state for the market in context.
 *
 * One store is kept per market id in a module-level cache (nothing here evicts
 * it). Whenever the market's scoring model or the customer's signal metadata
 * changes, the latest server model is pushed into the store.
 */
export const useScoringModelSignalStore = () => {
  const {id: marketId, scoringModel} = useMarketFromContext();
  const {customer} = useMetadata();

  const store = useStore(getOrCreateSignalStore(marketId.toString()));
  const {updateWithLatestServerModel} = store;

  const {signalDefinitions, signals: optionalSignals} = customer;
  useEffect(() => {
    if (scoringModel) {
      const customSignalIds = signalDefinitions?.map(({id}) => id) ?? [];

      updateWithLatestServerModel(
        scoringModel.signalsWithWeights,
        optionalSignals,
        customSignalIds
      );
    }
  }, [
    signalDefinitions,
    optionalSignals,
    scoringModel,
    updateWithLatestServerModel,
  ]);

  return store;
};

/**
 * Splits the server's signed-weight list into the client's two maps:
 * the bucket per signal (from the weight's sign, with zero counted as
 * positive) and the weight magnitude per signal.
 * A `null` input yields two empty maps.
 */
const serverToClientAssignmentsFormat = (
  signalsWithWeights: SignalsWithWeights | null
) => {
  const weights: Partial<Record<ScoringSignal, number>> = {};
  const buckets: Partial<
    Record<ScoringSignal, StrictExclude<BucketType, 'unassigned'>>
  > = {};

  for (const {signal, weight} of signalsWithWeights ?? []) {
    buckets[signal] = weight >= 0 ? 'positive' : 'negative';
    weights[signal] = Math.abs(weight);
  }

  return {weights, buckets};
};
