import Heap from "heap"
import { createAtom, IAtom } from "mobx"
import { Box2, Vector2 } from "three"
import { neighbours, orthogonalNeighbours } from "../utils/tiles"
import * as bitmask from "../utils/bitmask"

/** Side length, in tiles, of the square chunks a TileMap is split into. */
const CHUNK_SIZE = 16

/**
 * Callback fired by `TileMap.set` after a single tile has been written.
 * Used both for map-wide listeners and for per-chunk listeners.
 *
 * @param index - flat (row-major) index of the tile that was written
 * @param value - the value that was written to that tile
 * @param map - the map that was modified
 */
type ChangeListener<T extends number> = (
  index: number,
  value: T,
  map: TileMap<T>
) => void
/**
 * A 2D grid of small integer tile values, stored row-major in a `Uint8Array`
 * (values are therefore truncated to the 0-255 range on write).
 *
 * `get` reports to a MobX atom and `set`/`copy` mark it changed, so MobX
 * observers re-run when the map is modified. The grid is also partitioned into
 * `CHUNK_SIZE` x `CHUNK_SIZE` chunks (clipped at the right/bottom edges) which
 * can have their own change listeners.
 */
export class TileMap<T extends number = number> {
  /** MobX atom observed by `get` and changed by `set` / `copy`. */
  atom: IAtom
  /** Incremented on every `set` and `copy`. */
  version = 0
  /** Map width (x) and height (y) in tiles. */
  readonly size: Vector2
  /** Tile values, row-major: index = y * size.x + x. */
  readonly data: Uint8Array
  private changeListeners: ChangeListener<T>[] = []
  /**
   * Bounds of each chunk, in the same row-major order that `getChunk`
   * indexes. `min` is the chunk's first tile; `max` is one past its last tile
   * on each axis.
   */
  readonly chunks: Box2[]
  private chunkChangeListeners: ChangeListener<T>[][] = []

  /**
   * @param size - map dimensions in tiles
   * @param data - initial row-major tile values; defaults to all zeros
   * @throws Error if `data.length !== size.x * size.y`
   */
  constructor(size: Vector2, data = new Uint8Array(size.x * size.y)) {
    this.atom = createAtom("tile map")
    if (data.length !== size.x * size.y) throw new Error("Data is wrong size")
    this.size = size
    this.data = data
    this.chunks = []
    // Chunks are laid out row by row so that a chunk's position in this array
    // matches the index computed by getChunk. Edge chunks are clipped to the
    // map size.
    for (let y = 0; y < size.y; y += CHUNK_SIZE) {
      for (let x = 0; x < size.x; x += CHUNK_SIZE) {
        const bounds = new Box2(
          new Vector2(x, y),
          new Vector2(
            x + Math.min(size.x - x, CHUNK_SIZE),
            y + Math.min(size.y - y, CHUNK_SIZE)
          )
        )
        this.chunks.push(bounds)
      }
    }
  }

  /** Registers a listener called after every `set`, for any tile. */
  addListener(fn: ChangeListener<T>) {
    this.changeListeners.push(fn)
  }

  /** Removes every registration of `fn` added via `addListener`. */
  removeListener(fn: ChangeListener<T>) {
    this.changeListeners = this.changeListeners.filter(f => f !== fn)
  }

  /** Registers a listener called after `set` writes a tile in `chunkIndex`. */
  addChunkListener(chunkIndex: number, fn: ChangeListener<T>) {
    if (!this.chunkChangeListeners[chunkIndex])
      this.chunkChangeListeners[chunkIndex] = []
    this.chunkChangeListeners[chunkIndex].push(fn)
  }

  /** Removes every registration of `fn` for `chunkIndex`; no-op if none. */
  removeChunkListener(chunkIndex: number, fn: ChangeListener<T>) {
    if (!this.chunkChangeListeners[chunkIndex]) return
    const listeners = this.chunkChangeListeners[chunkIndex]
    this.chunkChangeListeners[chunkIndex] = listeners.filter(f => f !== fn)
  }

  /**
   * Whether a coordinate or flat index lies inside the map.
   * A numeric index is in bounds when `0 <= index < data.length`.
   */
  isInBounds(coord: Vector2 | number) {
    // Check numeric indexes directly. Going through toCoord would throw for
    // out-of-range indexes instead of reporting `false`.
    if (typeof coord === "number") return coord >= 0 && coord < this.data.length
    const { x, y } = coord
    return x >= 0 && y >= 0 && x < this.size.x && y < this.size.y
  }

  /** Converts a coordinate to its row-major index. Does not bounds-check. */
  toIndex(coord: Vector2) {
    return coord.y * this.size.x + coord.x
  }

  /**
   * Converts a row-major index to a coordinate, written into `target`.
   * @throws Error if `index` is outside `[0, data.length)`
   */
  toCoord(index: number, target = new Vector2()) {
    if (index < 0 || index >= this.data.length) throw new Error("Out of bounds")
    return target.set(index % this.size.x, Math.floor(index / this.size.x))
  }

  /**
   * Reads a tile and reports the read to the MobX atom. Does not
   * bounds-check: an out-of-range index yields `undefined` cast to `T`.
   */
  get(coord: Vector2 | number) {
    const index = typeof coord === "number" ? coord : this.toIndex(coord)
    this.atom.reportObserved()
    return this.data[index] as T
  }

  /**
   * Writes a tile, bumps `version`, then notifies map-wide listeners,
   * listeners of the tile's chunk, and finally the MobX atom, in that order.
   */
  set(coord: Vector2 | number, value: T) {
    const index = typeof coord === "number" ? coord : this.toIndex(coord)
    this.data[index] = value
    this.version++
    this.changeListeners.forEach(fn => fn(index, value, this))
    const chunkListeners = this.chunkChangeListeners[this.getChunk(index)]
    if (chunkListeners) chunkListeners.forEach(fn => fn(index, value, this))
    this.atom.reportChanged()
  }

  /** Tests the tile's value against `mask` using `bitmask.get`. */
  is(index: number | Vector2, mask: number) {
    const value = this.get(index)
    return bitmask.get(value, mask)
  }

  /**
   * Overwrites the tile data from `data`, starting at index 0, then bumps
   * `version` and notifies the MobX atom.
   *
   * NOTE: change listeners and chunk listeners are NOT called. If `data` is
   * longer than the map, `Uint8Array.prototype.set` throws a RangeError.
   */
  copy(data: Uint8Array) {
    this.data.set(data)
    this.version++
    this.atom.reportChanged()
  }

  /** Returns the index into `chunks` of the chunk containing a tile. */
  getChunk(coord: number | Vector2) {
    const { x, y } = typeof coord === "number" ? this.toCoord(coord) : coord
    // Number of chunks per row, counting a partial chunk at the right edge.
    const chunkWidth = Math.ceil(this.size.x / CHUNK_SIZE)
    const chunkX = Math.floor(x / CHUNK_SIZE)
    const chunkY = Math.floor(y / CHUNK_SIZE)
    return chunkY * chunkWidth + chunkX
  }

  /**
   * Visits the tiles in `[min, max)`, row by row, skipping positions outside
   * the map. Iteration stops early if `fn` returns `false`.
   *
   * The `coord` passed to `fn` is a single vector reused between calls; clone
   * it if you need to keep it.
   */
  forEach(
    fn: (value: T, coord: Vector2) => void | false,
    min = new Vector2(),
    max = this.size
  ) {
    const coord = new Vector2()
    for (let y = min.y; y < max.y; y++) {
      for (let x = min.x; x < max.x; x++) {
        coord.set(x, y)
        if (!this.isInBounds(coord)) continue
        const result = fn(this.get(coord), coord)
        if (result === false) return
      }
    }
  }

  /**
   * Returns fresh coordinate vectors for the in-bounds neighbours of a tile,
   * using the offsets from `neighbours` or `orthogonalNeighbours`.
   */
  getNeighbours(_coord: Vector2 | number, includeDiagonal = true) {
    const coord = typeof _coord === "number" ? this.toCoord(_coord) : _coord
    const offsets = includeDiagonal ? neighbours : orthogonalNeighbours
    return offsets
      .map(offset => offset.clone().add(coord))
      .filter(coord => this.isInBounds(coord))
  }

  /**
   * Returns the flat indexes of a tile's in-bounds neighbours, in the order
   * y-1, (x+1, y-1), x+1, (x+1, y+1), y+1, (x-1, y+1), x-1, (x-1, y-1).
   * Diagonals are included only when `includeDiagonal` is true.
   */
  getNeighbourIndexes(_index: Vector2 | number, includeDiagonal = true) {
    const index = typeof _index === "number" ? _index : this.toIndex(_index)
    const { x, y } = typeof _index === "number" ? this.toCoord(_index) : _index
    const width = this.size.x
    // Which sides of the tile have a neighbouring row/column inside the map.
    const top = y > 0
    const right = x < width - 1
    const bottom = y < this.size.y - 1
    const left = x > 0
    const result: number[] = []
    if (top) result.push(index - width)
    if (includeDiagonal && top && right) result.push(index - width + 1)
    if (right) result.push(index + 1)
    if (includeDiagonal && right && bottom) result.push(index + width + 1)
    if (bottom) result.push(index + width)
    if (includeDiagonal && bottom && left) result.push(index + width - 1)
    if (left) result.push(index - 1)
    if (includeDiagonal && left && top) result.push(index - width - 1)
    return result
  }

  /**
   * Whether two tiles are orthogonally adjacent, i.e. they share an edge.
   * Diagonal tiles are not considered neighbours here.
   */
  areNeighbours(a: Vector2 | number, b: Vector2 | number) {
    const index1 = typeof a === "number" ? a : this.toIndex(a)
    const index2 = typeof b === "number" ? b : this.toIndex(b)
    const width = this.size.x
    const diff = Math.abs(index1 - index2)
    // Indexes exactly one row apart are vertical neighbours.
    if (diff === width) return true
    // Consecutive indexes are horizontal neighbours only when both are on the
    // same row. Otherwise one is the last tile of a row and the other is the
    // first tile of the next row.
    return (
      diff === 1 && Math.floor(index1 / width) === Math.floor(index2 / width)
    )
  }

  /**
   * Replaces each in-bounds tile in `[min, max)` with `fn(value, coord)`.
   *
   * Every tile is written through `set`, so listeners fire once per tile even
   * when the value is unchanged. Because `fn` follows the defaulted
   * parameters, callers must pass all three arguments.
   */
  map(
    min = new Vector2(),
    max = this.size,
    fn: (value: T, coord: Vector2) => T
  ) {
    this.forEach((value, coord) => this.set(coord, fn(value, coord)), min, max)
  }

  /** Sets every in-bounds tile in `[position, position + size)` to `value`. */
  fillRect(position: Vector2, size: Vector2, value: T) {
    this.map(position, position.clone().add(size), () => value)
  }

  /** Sets every tile in the map to `value`. */
  fill(value: T) {
    this.fillRect(new Vector2(), this.size, value)
  }

  /**
   * Sets every in-bounds tile whose coordinate is strictly closer than
   * `radius` to `center` to `value`. Other tiles keep their value.
   */
  fillCircle(center: Vector2, radius: number, value: T) {
    const halfSize = new Vector2(radius, radius)
    const min = center.clone().sub(halfSize)
    const max = center.clone().add(halfSize)
    // forEach (used by map) only passes in-bounds coordinates, so each tile's
    // distance can be computed directly.
    this.map(min, max, (prevValue, coord) =>
      coord.distanceTo(center) < radius ? value : prevValue
    )
  }

  /**
   * Lazily yields the indexes of the tiles reachable from `center` through
   * tiles for which `include` returns true. Connectivity is 8-way: diagonals
   * are included, via `getNeighbourIndexes`' default.
   *
   * The search is best-first: at each step it yields the frontier tile with
   * the smallest squared distance to `center`. Nothing is yielded if `center`
   * itself is not included.
   */
  *floodFill(
    center: Vector2 | number,
    include: (index: number) => boolean = () => true
  ) {
    const centerIndex =
      typeof center === "number" ? center : this.toIndex(center)
    const centerCoord =
      typeof center === "number" ? this.toCoord(center) : center
    if (!include(centerIndex)) return
    type SearchNode = {
      index: number
      distanceSq: number
      // True once the node has been pushed onto the heap.
      visited: boolean
      // True once the node has been popped and its neighbours expanded.
      closed: boolean
    }
    const v = new Vector2()
    // Sparse array keyed by tile index; nodes are created on first touch.
    const nodes: SearchNode[] = []
    const getNode = (index: number) => {
      let node = nodes[index]
      if (!node) {
        node = nodes[index] = {
          index,
          distanceSq: centerCoord.distanceToSquared(this.toCoord(index, v)),
          visited: false,
          closed: false,
        }
      }
      return node
    }
    const heap = new Heap<SearchNode>((a, b) => a.distanceSq - b.distanceSq)
    const centerNode = getNode(centerIndex)
    centerNode.visited = true
    heap.push(centerNode)
    while (heap.size() > 0) {
      const currentNode = heap.pop()!
      yield currentNode.index
      for (const neighbourIndex of this.getNeighbourIndexes(
        currentNode.index
      )) {
        if (!include(neighbourIndex)) continue
        const neighbour = getNode(neighbourIndex)
        if (neighbour.closed) continue
        if (!neighbour.visited) {
          neighbour.visited = true
          heap.push(neighbour)
        }
      }
      currentNode.closed = true
    }
  }

  /**
   * Builds a map from an image by mapping each pixel's 0xRRGGBB color through
   * `colorToValue`. Alpha is ignored. Needs a DOM, because it draws onto a
   * `<canvas>`. The image should already be loaded so that its
   * width/height are meaningful.
   *
   * Colors missing from `colorToValue` fall back to the entry for 0xffff00.
   * If that entry is also missing, the value is `undefined`, which the
   * `Uint8Array` stores as 0.
   *
   * @throws Error if a 2D canvas context cannot be obtained
   */
  static fromImage<T extends number>(
    image: HTMLImageElement,
    colorToValue: Record<number, T>
  ) {
    const canvas = document.createElement("canvas")
    const ctx = canvas.getContext("2d")
    if (!ctx) throw new Error("Couldn't get context")
    canvas.width = image.width
    canvas.height = image.height
    ctx.drawImage(image, 0, 0)
    // RGBA bytes, 4 per pixel, row-major: the same tile order as TileMap.data.
    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height).data
    const mapData = new Uint8Array(image.width * image.height)
    for (let i = 0; i < imageData.length; i += 4) {
      const r = imageData[i]
      const g = imageData[i + 1]
      const b = imageData[i + 2]
      let color = (r << 16) + (g << 8) + b
      // if (!(color in colorToValue))
      // throw new Error(`No value for color ${color.toString(16)}`)
      if (!(color in colorToValue)) color = 0xffff00
      const value = colorToValue[color]
      mapData[i / 4] = value
    }
    return new TileMap<T>(new Vector2(image.width, image.height), mapData)
  }
}
