import { IUniform, Material, Shader } from "three"

// https://stackoverflow.com/a/64493510/1193622
/**
 * Constructor type for any class whose instances are `T`.
 * The single `...args: any[]` rest parameter is the shape TypeScript demands
 * of a mixin base, which is what lets `class extends baseMaterial` compile.
 */
type Mixin<T> = { new (...args: any[]): T }

/**
 * Creates a subclass of a three.js material class that rewrites the shader
 * program in `onBeforeCompile`.
 *
 * @param baseMaterial - Material class to extend.
 * @param extensions - Optional overrides applied at compile time:
 *   - `vertexShader` / `fragmentShader`: replace the base shader source entirely.
 *   - `uniforms`: merged over the base material's uniforms. The object is
 *     referenced, not copied, so all instances of the returned class share the
 *     same uniform objects.
 *   - `defines`: emitted as `#define NAME VALUE` lines prepended to both shaders.
 * @returns A class extending `baseMaterial`. Its `uniforms` getter returns the
 *   compiled shader's uniforms once compiled, the initial uniforms before that,
 *   and throws when neither exists.
 */
export const extendMaterial = <
  Base extends Mixin<Material>,
  Uniforms extends Record<string, IUniform>
>(
  baseMaterial: Base,
  extensions: {
    fragmentShader?: string
    vertexShader?: string
    uniforms?: Uniforms
    defines?: Record<string, string>
  } = {}
) => {
  // The define block depends only on `extensions`, so build it once per class
  // rather than on every compile. An empty `defines` object still yields "\n".
  const definesPrefix = extensions.defines
    ? Object.entries(extensions.defines)
        .map(([name, value]) => `#define ${name} ${value}`)
        .join("\n") + "\n"
    : ""

  // three.js' default program cache key is `onBeforeCompile.toString()`. Every
  // class produced here shares the same method source, so two extended
  // materials with the same base would otherwise reuse one compiled program.
  // This key captures everything this factory changes in the generated GLSL.
  const programKey = JSON.stringify([
    extensions.vertexShader ?? null,
    extensions.fragmentShader ?? null,
    definesPrefix,
  ])

  return class extends baseMaterial {
    // NOTE(review): presumably set so the renderer honours `uniformsNeedUpdate`
    // for this material; confirm against the three.js version in use.
    isShaderMaterial = true
    uniformsNeedUpdate = true
    // The shader object from the most recent compile; undefined until then.
    shader?: Shader
    // Shared by reference with `extensions.uniforms` (not cloned per instance).
    initialUniforms = extensions.uniforms

    onBeforeCompile(shader: Shader) {
      if (extensions.vertexShader) shader.vertexShader = extensions.vertexShader
      if (extensions.fragmentShader)
        shader.fragmentShader = extensions.fragmentShader
      // Merge the same object the `uniforms` getter exposes before compilation,
      // so the two stay in sync even if `initialUniforms` is reassigned.
      if (this.initialUniforms)
        shader.uniforms = { ...shader.uniforms, ...this.initialUniforms }
      if (definesPrefix) {
        shader.vertexShader = definesPrefix + shader.vertexShader
        shader.fragmentShader = definesPrefix + shader.fragmentShader
      }
      this.shader = shader
    }

    /** Distinguishes programs built from different `extensions` (see `programKey`). */
    customProgramCacheKey() {
      return this.onBeforeCompile.toString() + programKey
    }

    /**
     * Live uniforms: the compiled shader's after the first compile, otherwise
     * the initial uniforms.
     * @throws Error when the material has neither.
     */
    get uniforms() {
      if (this.shader) return this.shader.uniforms as Uniforms
      if (this.initialUniforms) return this.initialUniforms
      throw new Error("Material has no uniforms")
    }
  }
}
