import { QueryKey } from "@tanstack/react-query";
import { sql } from "kysely";

import { executeSqlV2, queryBuilder } from "~/api/materialize";

/**
 * Builds a query listing every name each cluster replica has had over time.
 *
 * Three sources are combined with `UNION` (so identical rows collapse):
 * 1. `alter` audit events on cluster replicas; `old_name` and `new_name` are
 *    read from the event details.
 * 2. `create` audit events on cluster replicas; the initial name is read from
 *    `replica_name` in the event details and `old_name` is `NULL`.
 * 3. Current replicas whose ID starts with `s` (built-in system replicas).
 *    These have no audit events, so they are reported as created at the Unix
 *    epoch with a `NULL` `old_name`.
 *
 * Output columns:
 * - occurred_at: when the name took effect.
 * - id: the replica ID.
 * - old_name: the previous name, or null for creations and system replicas.
 * - new_name: the name in effect from `occurred_at` onward.
 * - cluster_id: the ID of the cluster the replica belongs to.
 */
export function buildReplicaNameHistoryQuery() {
  // Shared starting point for both audit-event branches. Kysely builders are
  // immutable, so each branch below derives its own query from this one.
  const clusterReplicaEvents = queryBuilder
    .selectFrom("mz_audit_events as audit_events")
    .where("object_type", "=", "cluster-replica");

  const renames = clusterReplicaEvents
    .select([
      "occurred_at",
      sql<string>`audit_events.details->>'replica_id'`.as("id"),
      sql<string | null>`audit_events.details->>'old_name'`.as("old_name"),
      sql<string>`audit_events.details->>'new_name'`.as("new_name"),
      sql<string>`audit_events.details->>'cluster_id'`.as("cluster_id"),
    ])
    .where("audit_events.event_type", "=", "alter");

  const creations = clusterReplicaEvents
    .select([
      "occurred_at",
      sql<string>`audit_events.details->>'replica_id'`.as("id"),
      sql<null>`NULL`.as("old_name"),
      sql<string>`audit_events.details->>'replica_name'`.as("new_name"),
      sql<string>`audit_events.details->>'cluster_id'`.as("cluster_id"),
    ])
    .where("audit_events.event_type", "=", "create");

  // Built-in system replicas never produce audit events, so synthesize a
  // "creation" row for each of them, dated at the very beginning of time.
  const builtInSystemReplicas = queryBuilder
    .selectFrom("mz_cluster_replicas")
    .select([
      sql<Date>`TIMESTAMP '1970-01-01'`.as("occurred_at"),
      "id",
      sql<null>`NULL`.as("old_name"),
      "name as new_name",
      "cluster_id",
    ])
    .where("id", "like", "s%");

  return renames.union(creations).union(builtInSystemReplicas);
}

/** Options accepted by {@link buildReplicaUtilizationHistoryQuery}. */
export type ReplicaUtilizationHistoryParameters = {
  /** When set, only buckets for replicas in this cluster are returned. */
  clusterId?: string;
  /** When set, only buckets for this replica are returned. */
  replicaId?: string;
  /**
   * Start of the requested range. Results begin with the bucket that contains
   * this instant. The string is used both as a SQL `TIMESTAMP` literal and with
   * `new Date(...)`, so it must be parseable by both.
   */
  startDate: string;
  /** Optional end of the requested range; when omitted, no upper bound applies. */
  endDate?: string;
  /** Width of each time bucket, in milliseconds. */
  bucketSizeMs: number;
};

/**
 * Builds a query returning per-bucket peak utilization for cluster replicas.
 *
 * Samples from `mz_cluster_replica_metrics_history` are assigned to
 * fixed-width time buckets with `date_bin`, aligned to 1970-01-01. For each
 * bucket and replica the query reports:
 * - the sample with the highest memory, disk and CPU usage, expressed as a
 *   percentage of the replica size's capacity (plus raw bytes for memory and
 *   disk), with the time each peak occurred;
 * - the offline status events recorded for process 0 in that bucket;
 * - the replica's size;
 * - the replica name and cluster ID in effect at the end of the bucket.
 *
 * @param params - Filters and bucketing options; see
 *   {@link ReplicaUtilizationHistoryParameters}.
 * @returns An unexecuted Kysely select query ordered by `bucketStart`.
 */
export function buildReplicaUtilizationHistoryQuery({
  clusterId,
  replicaId,
  startDate,
  endDate,
  bucketSizeMs,
}: ReplicaUtilizationHistoryParameters) {
  // Spliced into quoted interval strings such as '60000 MILLISECONDS'.
  const bucketSizeMsSqlStr = sql.raw(`${bucketSizeMs}`);
  const startDateLit = sql.lit(startDate);

  // Fixed origin for date_bin, so bucket boundaries do not depend on startDate.
  const dateBinOrigin = sql.lit("1970-01-01");

  let query = queryBuilder
    .with("replica_history", (qb) =>
      qb
        .selectFrom("mz_cluster_replica_history")
        .select(["replica_id", "size"])
        // We need to union the current set of cluster replicas since mz_cluster_replica_history doesn't account for system clusters
        .union(
          qb
            .selectFrom("mz_cluster_replicas")
            .select(["id as replica_id", "size"]),
        ),
    )
    // Every metrics sample, joined with its replica's size so usage can be
    // expressed as a percentage of capacity, and tagged with its time bucket.
    .with("replica_utilization_history_binned", (qb) =>
      qb
        .selectFrom("replica_history as r")
        .innerJoin("mz_cluster_replica_sizes as s", "r.size", "s.size")
        .innerJoin(
          "mz_cluster_replica_metrics_history as m",
          "m.replica_id",
          "r.replica_id",
        )
        .select((eb) => [
          "m.occurred_at",
          "m.replica_id",
          "m.process_id",
          sql<number>`(m.cpu_nano_cores::float8 / s.cpu_nano_cores) * 100`.as(
            "cpu_percent",
          ),
          sql<number>`(m.memory_bytes::float8 / s.memory_bytes) * 100`.as(
            "memory_percent",
          ),
          sql<number>`(m.disk_bytes::float8 / s.disk_bytes) * 100`.as(
            "disk_percent",
          ),
          eb.ref("m.disk_bytes").$castTo<bigint>().as("disk_bytes"),
          eb.ref("m.memory_bytes").$castTo<bigint>().as("memory_bytes"),
          "r.size",
          sql<Date>`date_bin(
              '${bucketSizeMsSqlStr} MILLISECONDS', 
              occurred_at,
              TIMESTAMP ${dateBinOrigin}
            )`.as("bucket_start"),
        ]),
    )
    // For each (bucket, replica, process), keep the single sample with the
    // highest memory usage (NULL usage is treated as 0).
    .with("max_memory", (qb) =>
      /**
       * This is a TOP k optimization using a LATERAL subquery and limit https://materialize.com/docs/transform-data/patterns/top-k/.
       * Because Kysely doesn't support lateral cross joins, we use innerJoinLateral and join.onTrue() to get the same behavior.
       */
      qb
        .selectFrom(
          qb
            .selectFrom("replica_utilization_history_binned")
            .distinctOn(["bucket_start", "replica_id", "process_id"])
            .select(["bucket_start", "replica_id", "process_id"])
            .as("grp"),
        )
        .innerJoinLateral(
          (eb) =>
            eb
              .selectFrom("replica_utilization_history_binned")
              .select(["memory_percent", "memory_bytes", "occurred_at"])
              .whereRef("bucket_start", "=", "grp.bucket_start")
              .whereRef("replica_id", "=", "grp.replica_id")
              .whereRef("process_id", "=", "grp.process_id")
              .orderBy(
                (oeb) => sql`COALESCE(${oeb.ref("memory_bytes")}, 0)`,
                "desc",
              )
              .limit(1)
              .as("topk"),
          (join) => join.onTrue(),
        )
        .select([
          "bucket_start",
          "replica_id",
          "memory_percent",
          "memory_bytes",
          "occurred_at",
        ]),
    )
    // For each (bucket, replica, process), keep the single sample with the
    // highest disk usage (NULL usage is treated as 0).
    .with("max_disk", (qb) =>
      qb
        .selectFrom(
          qb
            .selectFrom("replica_utilization_history_binned")
            .distinctOn(["bucket_start", "replica_id", "process_id"])
            .select(["bucket_start", "replica_id", "process_id"])
            .as("grp"),
        )
        .innerJoinLateral(
          (eb) =>
            eb
              .selectFrom("replica_utilization_history_binned")
              .select(["disk_percent", "disk_bytes", "occurred_at"])
              .whereRef("bucket_start", "=", "grp.bucket_start")
              .whereRef("replica_id", "=", "grp.replica_id")
              .whereRef("process_id", "=", "grp.process_id")
              .orderBy(
                (oeb) => sql`COALESCE(${oeb.ref("disk_bytes")}, 0)`,
                "desc",
              )
              .limit(1)
              .as("topk"),
          (join) => join.onTrue(),
        )
        .select([
          "bucket_start",
          "replica_id",
          "disk_percent",
          "disk_bytes",
          "occurred_at",
        ]),
    )
    // For each (bucket, replica, process), keep the single sample with the
    // highest CPU usage (NULL usage is treated as 0).
    .with("max_cpu", (qb) =>
      qb
        .selectFrom(
          qb
            .selectFrom("replica_utilization_history_binned")
            .distinctOn(["bucket_start", "replica_id", "process_id"])
            .select(["bucket_start", "replica_id", "process_id"])
            .as("grp"),
        )
        .innerJoinLateral(
          (eb) =>
            eb
              .selectFrom("replica_utilization_history_binned")
              .select(["cpu_percent", "occurred_at"])
              .whereRef("bucket_start", "=", "grp.bucket_start")
              .whereRef("replica_id", "=", "grp.replica_id")
              .whereRef("process_id", "=", "grp.process_id")
              .orderBy(
                (oeb) => sql`COALESCE(${oeb.ref("cpu_percent")}, 0)`,
                "desc",
              )
              .limit(1)
              .as("topk"),
          (join) => join.onTrue(),
        )
        .select(["bucket_start", "cpu_percent", "replica_id", "occurred_at"]),
    )
    // Get the history of replica name changes
    .with("replica_name_history", (qb) =>
      qb
        .selectFrom(
          buildReplicaNameHistoryQuery().as("replica_name_history_inner"),
        )
        .selectAll(),
    )
    // For each (bucket, replica), collect process 0's offline status events
    // into a JSON array.
    .with("replica_offline_event_history", (qb) =>
      qb
        .selectFrom("mz_cluster_replica_status_history as rsh")
        .select([
          sql<Date>`date_bin(
              '${bucketSizeMsSqlStr} MILLISECONDS', 
              occurred_at,
              TIMESTAMP ${dateBinOrigin}
            )`.as("bucket_start"),
          "replica_id",
          sql<
            {
              replicaId: string;
              occurredAt: string;
              status: string;
              reason: string;
            }[]
          >`jsonb_agg(
              jsonb_build_object(
                'replicaId', rsh.replica_id, 
                'occurredAt', rsh.occurred_at, 
                'status', rsh.status, 
                'reason', rsh.reason
              )
          )`.as("offline_events"),
        ])
        .where("process_id", "=", "0")
        .where("status", "=", "offline")
        .groupBy(["bucket_start", "replica_id"]),
    )
    // Every distinct (bucket, replica) pair that has either metrics or offline
    // events; this is the driving set of output rows.
    .with("bucket_intervals", (qb) =>
      qb
        .selectFrom("max_memory")
        .select(["bucket_start", "replica_id"])
        .union(qb.selectFrom("max_disk").select(["bucket_start", "replica_id"]))
        .union(qb.selectFrom("max_cpu").select(["bucket_start", "replica_id"]))
        .union(
          qb
            .selectFrom("replica_offline_event_history")
            .select(["bucket_start", "replica_id"]),
        ),
    )
    .selectFrom("bucket_intervals")
    // NOTE(review): max_memory / max_disk / max_cpu hold one row per
    // (bucket, replica, process) but are joined on (bucket, replica) only, so a
    // replica with several processes produces one output row per combination of
    // processes. Confirm this fan-out is intended.
    .leftJoin("max_memory", (join) =>
      join
        .onRef("bucket_intervals.bucket_start", "=", "max_memory.bucket_start")
        .onRef("bucket_intervals.replica_id", "=", "max_memory.replica_id"),
    )
    .leftJoin("max_disk", (join) =>
      join
        .onRef("bucket_intervals.bucket_start", "=", "max_disk.bucket_start")
        .onRef("bucket_intervals.replica_id", "=", "max_disk.replica_id"),
    )
    .leftJoin("max_cpu", (join) =>
      join
        .onRef("bucket_intervals.bucket_start", "=", "max_cpu.bucket_start")
        .onRef("bucket_intervals.replica_id", "=", "max_cpu.replica_id"),
    )
    .leftJoin("replica_offline_event_history", (join) =>
      join
        .onRef(
          "bucket_intervals.bucket_start",
          "=",
          "replica_offline_event_history.bucket_start",
        )
        .onRef(
          "bucket_intervals.replica_id",
          "=",
          "replica_offline_event_history.replica_id",
        ),
    )
    .innerJoin("replica_history", (join) =>
      join.onRef(
        "bucket_intervals.replica_id",
        "=",
        "replica_history.replica_id",
      ),
    )
    // Name and cluster in effect at the end of each bucket: the latest name
    // event at or before bucket_start + bucket size. Buckets with no such
    // event produce no row from this inner lateral join and are dropped.
    .innerJoinLateral(
      (qb) =>
        qb
          .selectFrom("replica_name_history")
          .selectAll()
          .whereRef(
            "bucket_intervals.replica_id",
            "=",
            "replica_name_history.id",
          )
          // Before the end of each bucket, get the closest name to the end
          .whereRef(
            sql<Date>`bucket_intervals.bucket_start + INTERVAL '${bucketSizeMsSqlStr} MILLISECONDS'`,
            ">=",
            "replica_name_history.occurred_at",
          )
          .orderBy("replica_name_history.occurred_at", "desc")
          .limit(1)
          .as("replica_name_history"),
      (join) => join.onTrue(),
    )
    .select([
      "bucket_intervals.bucket_start as bucketStart",
      "bucket_intervals.replica_id as replicaId",
      "max_memory.memory_percent as memoryPercent",
      "max_memory.memory_bytes as memoryBytes",
      "max_memory.occurred_at as maxMemoryAt",
      "max_disk.disk_percent as diskPercent",
      "max_disk.disk_bytes as diskBytes",
      "max_disk.occurred_at as maxDiskAt",
      "max_cpu.cpu_percent as cpuPercent",
      "max_cpu.occurred_at as maxCpuAt",
      "replica_offline_event_history.offline_events as offlineEvents",
      sql<Date>`bucket_intervals.bucket_start + INTERVAL '${bucketSizeMsSqlStr} MILLISECONDS'`.as(
        "bucketEnd",
      ),
      "replica_name_history.new_name as name",
      "replica_name_history.cluster_id as clusterId",
      "replica_history.size",
    ])
    // Start from the bucket that contains startDate.
    .where(
      "bucket_intervals.bucket_start",
      ">=",
      sql<Date>`
            date_bin(
              '${bucketSizeMsSqlStr} MILLISECONDS', 
              TIMESTAMP ${startDateLit},
              TIMESTAMP ${dateBinOrigin}
            )`,
    )

    .orderBy("bucketStart");

  if (clusterId) {
    query = query.where("replica_name_history.cluster_id", "=", clusterId);
  }

  if (replicaId) {
    query = query.where("bucket_intervals.replica_id", "=", replicaId);
  }

  if (endDate) {
    // Bind the END date here; building this literal from startDate made the
    // upper bound collapse onto the lower bound.
    const endDateLit = sql.lit(endDate);
    // Keep only buckets that end at or before the start of the bucket
    // containing endDate.
    query = query.where(
      sql<Date>`bucket_intervals.bucket_start + INTERVAL '${bucketSizeMsSqlStr} MILLISECONDS'`,
      "<=",
      sql<Date>`
            date_bin(
              '${bucketSizeMsSqlStr} MILLISECONDS', 
              TIMESTAMP ${endDateLit},
              TIMESTAMP ${dateBinOrigin}
            )`,
    );
  }

  return query;
}

/**
 * Executes the replica utilization history query and groups the returned
 * buckets by replica ID.
 *
 * @param params - Filters and bucketing options forwarded to
 *   `buildReplicaUtilizationHistoryQuery`.
 * @param queryKey - Query key forwarded to `executeSqlV2`.
 * @param requestOptions - Optional request options forwarded to `executeSqlV2`.
 * @returns `startDate`/`endDate`: the requested range (end defaults to now),
 *   widened so that it covers every returned bucket; `bucketsByReplicaId`: the
 *   returned rows grouped by `replicaId`, each list in result order.
 */
export async function fetchReplicaUtilizationHistory({
  params,
  queryKey,
  requestOptions,
}: {
  params: ReplicaUtilizationHistoryParameters;
  queryKey: QueryKey;
  requestOptions?: RequestInit;
}) {
  const { rows } = await executeSqlV2({
    queries: buildReplicaUtilizationHistoryQuery(params).compile(),
    queryKey,
    requestOptions,
    sessionVariables: {
      // Serializable (not strict serializable) isolation is sufficient here
      // and gives more consistent query performance.
      transaction_isolation: "serializable",
    },
  });

  type Bucket = (typeof rows)[0];

  const bucketsByReplicaId: Record<string, Bucket[]> = {};

  // Seed the overall range from the request, then widen it as buckets arrive.
  let minBucketStartMs = new Date(params.startDate).getTime();
  const requestedEnd = params.endDate ? new Date(params.endDate) : new Date();
  let maxBucketEndMs = requestedEnd.getTime();

  for (const bucket of rows) {
    minBucketStartMs = Math.min(minBucketStartMs, bucket.bucketStart.getTime());
    maxBucketEndMs = Math.max(maxBucketEndMs, bucket.bucketEnd.getTime());

    (bucketsByReplicaId[bucket.replicaId] ??= []).push(bucket);
  }

  return {
    startDate: new Date(minBucketStartMs),
    endDate: new Date(maxBucketEndMs),
    bucketsByReplicaId,
  };
}
