module Components.AdjacencyMatrix

open Shared.Tree

/// Identifier of a graph node. Nodes, cells (`Row` / `Column`) and `Msg` payloads all refer to nodes by this id.
type Id = string

/// A graph node as supplied to `createModel` (wrapped in a `Tree` for hierarchical rows/columns).
type Node = {
    /// Node id; matched against `Cell.Row` / `Cell.Column` when `createModel` counts edges.
    Id : Id
    /// Label rendered for the row / column.
    Name : string
    /// Selection state; copied unchanged into `MatrixNode.Selected` by `createModel`.
    Selected : bool
}

/// A row or column node as stored in `Model`; built from a `Node` by `createModel`.
type MatrixNode = {
    /// Id copied from the source `Node`; used to look up the node's display position.
    Id : Id
    /// Label rendered for the row / column.
    Name : string
    /// Depth in the node tree: 0 for roots, +1 per nesting level. `render` shrinks the
    /// label background by a fixed amount per level and adds a `level-N` CSS class.
    Level : int
    /// Always set to `true` by `createModel`.
    Visible  : bool
    /// Number of cells whose `Row` (for row nodes) or `Column` (for column nodes) equals
    /// this node's id. Labels with 0 get the `adjacency-matrix__label--empty` class.
    EdgeCount : int
    /// Whether the node is selected; toggles the `--selected` CSS modifier and decides
    /// whether clicking the label dispatches the *Selected or *Deselected message.
    Selected : bool
}

/// A graph edge: one matrix cell addressed by its row and column node ids.
type Cell = {
    /// Id of the row node.
    Row : Id
    /// Id of the column node.
    Column : Id
    /// Rendered as the cell's fill opacity (presumably expected in 0.0-1.0 — TODO confirm).
    Value: float
}

/// A duration in milliseconds.
type Millisecond = int

/// Rendering parameters for `render`.
type Config = {
    /// Side length of a cell (and thickness of a label band), in SVG user units.
    CellSize : int
    /// Per-position stagger: an element's transition starts after `TransitionDelay`
    /// multiplied by its display position (for cells, the average of row and column position).
    TransitionDelay : Millisecond
    /// Duration of the `transform` transition applied to labels and cells.
    TransitionDuration : Millisecond
}

// Model is a matrix of cells with hierarchical rows and columns.
// Rows and columns represent graph nodes; cells represent graph edges, addressed by the row and column node IDs.
// Each row/column node holds its ID and label plus the hierarchy level, visibility and edge count added by createModel.
// The matrix is usually sparse, so cells are stored as a list with one entry per edge rather than as a dense grid;
// this also keeps the number of individually animated cell elements small.

type Model = {
    /// Row node trees, in the order supplied to `createModel`.
    Rows : Tree<MatrixNode> list
    /// Column node trees, in the order supplied to `createModel`.
    Columns : Tree<MatrixNode> list
    /// One entry per edge.
    Cells : Cell list
}

/// Messages dispatched by `render` when a label is clicked: the *Selected case when the
/// clicked node is not selected, the *Deselected case when it is. The payload is the node id.
type Msg =
    | RowSelected of Id
    | RowDeselected of Id
    | ColumnSelected of Id
    | ColumnDeselected of Id

/// Builds the matrix model from row/column node trees and the list of cells (edges).
///
/// Every `Node` is converted to a `MatrixNode` carrying:
/// - `Level`: its depth in the tree (0 for the roots of `rows` / `columns`),
/// - `Visible = true`,
/// - `EdgeCount`: the number of cells whose `Row` (for row nodes) or `Column`
///   (for column nodes) equals the node's id; 0 when no cell references it.
/// Tree shape and child order are preserved; `cells` is stored unchanged.
let createModel (rows : List<Tree<Node>>) (columns : List<Tree<Node>>) (cells : List<Cell>) =

    // Returns a lookup giving, for a node id, how many cells have that id under `key`.
    let edgeCounter (key : Cell -> Id) : Id -> int =
        let counts =
            cells
            |> List.countBy key
            |> Map.ofList
        fun nodeId ->
            match Map.tryFind nodeId counts with
            | Some count -> count
            | None -> 0

    let getRowEdgeCount = edgeCounter (fun cell -> cell.Row)
    let getColumnEdgeCount = edgeCounter (fun cell -> cell.Column)

    let rec createTree getEdgeCount (level : int) (tree : Tree<Node>) : Tree<MatrixNode> =
        let toMatrixNode (node : Node) =
            { Id = node.Id ; Name = node.Name ; Level = level ; Visible = true ; EdgeCount = getEdgeCount node.Id ; Selected = node.Selected }
        match tree with
        | Leaf node ->
            Leaf (toMatrixNode node)
        | Branch (node, subTrees) ->
            // List.map preserves child order; the previous fold with List.append
            // rebuilt the accumulator on every child (quadratic in child count).
            Branch (toMatrixNode node, subTrees |> List.map (createTree getEdgeCount (level + 1)))

    let createTreeList getEdgeCount (trees : List<Tree<Node>>) : List<Tree<MatrixNode>> =
        trees |> List.map (createTree getEdgeCount 0)

    { Rows = createTreeList getRowEdgeCount rows
      Columns = createTreeList getColumnEdgeCount columns
      Cells = cells }

// --------------------------------------------------------------------------------------
// Renderers

open Fable.Helpers.React
open Fable.Helpers.React.Props

/// CSS `transition` value that animates only the `transform` property over
/// `duration` milliseconds with ease-in-out timing.
let private transition (duration : int) =
    "transform " + string duration + "ms ease-in-out"

/// Renders an SVG `rect` of size `width` x `height`, positioned at
/// (`xPosition`, `yPosition`) through a `translate` transform and filled with
/// `opacity`. The transform is animated using `transition transitionDuration`,
/// and the animation starts after `transitionDelay * delay` milliseconds.
/// `children` are passed through as the rect's children.
let private renderRect xPosition yPosition width height (opacity : float) (delay : int) children transitionDelay transitionDuration =
    let translation = sprintf "translate(%d, %d)" xPosition yPosition
    let startAfter = sprintf "%dms" (transitionDelay * delay)
    let animation = Style [ Transition (transition transitionDuration); TransitionDelay startAfter ]
    rect [ SVGAttr.Width width
           SVGAttr.Height height
           SVGAttr.FillOpacity opacity
           SVGAttr.Transform translation
           animation ]
         children

/// Renders the adjacency matrix as an SVG element.
///
/// Layout: row labels on the left, column labels (rotated text) on top, and one
/// rect per cell in the remaining area, with `Cell.Value` used as fill opacity.
/// `sortBy` receives the model's row trees, column trees and cells, and returns
/// the row and column nodes in display order. Every node id used by a label or
/// a cell must appear in that order; otherwise the position lookups throw.
/// Clicking a label dispatches the matching *Selected / *Deselected message.
let render config (model : Model) sortBy dispatch =
    let labelPadding = 8
    let cellPadding = 1
    let levelSizeDecrement = 15

    let rowLabelWidth = 250
    let rowLabelMargin = 1

    let columnLabelHeight = 140
    let columnLabelMargin = 1

    // Distance between the origins of two adjacent cells.
    let cellPitch = config.CellSize + cellPadding

    // Extent of `count` cells laid side by side (no trailing padding).
    // Clamped so an empty axis yields 0 instead of -cellPadding.
    let axisLength count = max 0 (count * cellPitch - cellPadding)

    let rowNodes = model.Rows |> List.collect Shared.Tree.flatten
    let columnNodes = model.Columns |> List.collect Shared.Tree.flatten

    let matrixWidth = axisLength columnNodes.Length
    let matrixHeight = axisLength rowNodes.Length

    let orderedRowNodes, orderedColumnNodes = sortBy model.Rows model.Columns model.Cells

    // Maps each node id to its display index.
    let nodePositionsMap (nodes : MatrixNode list) =
        nodes
        |> List.mapi (fun position node -> (node.Id, position))
        |> Map.ofList

    let rowPositionsMap = nodePositionsMap orderedRowNodes
    let columnPositionsMap = nodePositionsMap orderedColumnNodes

    // Animated-transform style for a label group; the delay grows with the
    // label's display position so that reordering animates in a staggered way.
    let labelTransitionStyle position =
        Style [
            Transition (transition config.TransitionDuration)
            TransitionDelay (sprintf "%dms" (config.TransitionDelay * position)) ]

    // CSS classes shared by row and column labels.
    let labelClasses (node : MatrixNode) =
        classList [
            (sprintf "adjacency-matrix__label level-%d" node.Level, true)
            ("adjacency-matrix__label--selected", node.Selected)
            ("adjacency-matrix__label--empty", node.EdgeCount = 0) ]

    let rowLabels =
        rowNodes
        |> List.map (fun row ->
            let position = rowPositionsMap.[row.Id]
            let sizeDecrement = levelSizeDecrement * row.Level
            g [ labelClasses row
                OnClick (fun _ -> row.Id |> (if row.Selected then RowDeselected else RowSelected) |> dispatch)
                SVGAttr.Transform (sprintf "translate(%d, %d)" 0 (position * cellPitch))
                labelTransitionStyle position ]
              [ renderRect sizeDecrement 0 (rowLabelWidth - sizeDecrement) config.CellSize 1.0 0 [] config.TransitionDelay config.TransitionDuration
                text [
                    SVGAttr.X (rowLabelWidth - labelPadding)
                    SVGAttr.Y ((float config.CellSize) * 0.5)
                    ] [ str row.Name ]
                line [ Class "adjacency-matrix__line" ; SVGAttr.X1 rowLabelWidth ; SVGAttr.Y1 config.CellSize ; SVGAttr.X2 (matrixWidth + rowLabelWidth) ; SVGAttr.Y2 config.CellSize ] []
              ])

    let columnLabels =
        columnNodes
        |> List.map (fun column ->
            let position = columnPositionsMap.[column.Id]
            let sizeDecrement = levelSizeDecrement * column.Level
            g [ labelClasses column
                OnClick (fun _ -> column.Id |> (if column.Selected then ColumnDeselected else ColumnSelected) |> dispatch)
                SVGAttr.Transform (sprintf "translate(%d, %d)" (position * cellPitch) 0)
                labelTransitionStyle position ]
              [ renderRect 0 sizeDecrement config.CellSize (columnLabelHeight - sizeDecrement) 1.0 0 [] config.TransitionDelay config.TransitionDuration
                g [ SVGAttr.Transform (sprintf "translate(%f %d)" (float config.CellSize * 0.5) (columnLabelHeight - labelPadding)) ] [
                    text [ SVGAttr.Transform "rotate(-90)" ] [ str column.Name ] ]
                line [ Class "adjacency-matrix__line" ; SVGAttr.X1 config.CellSize ; SVGAttr.Y1 columnLabelHeight ; SVGAttr.X2 config.CellSize ; SVGAttr.Y2 (matrixHeight + columnLabelHeight) ] []
              ])

    let cells =
        model.Cells
        |> List.map (fun cell ->
            let column = columnPositionsMap.[cell.Column]
            let row = rowPositionsMap.[cell.Row]
            // Stagger cells along the anti-diagonal. Positions are non-negative, so
            // integer division equals the previous float-divide-and-truncate.
            let delay = (column + row) / 2
            g [ Class "adjacency-matrix__cell" ] [
                renderRect
                    (column * cellPitch)
                    (row * cellPitch)
                    config.CellSize config.CellSize cell.Value
                    delay
                    []
                    config.TransitionDelay
                    config.TransitionDuration ])

    svg [ Class "adjacency-matrix"
          SVGAttr.Width (matrixWidth + rowLabelWidth + rowLabelMargin)
          SVGAttr.Height (matrixHeight + columnLabelHeight + columnLabelMargin)
        ] [ g [ Class "adjacency-matrix__row-labels"
                SVGAttr.Transform (sprintf "translate(%d, %d)" 0 (columnLabelHeight + columnLabelMargin)) ] rowLabels
            g [ Class "adjacency-matrix__column-labels"
                SVGAttr.Transform (sprintf "translate(%d, %d)" (rowLabelWidth + rowLabelMargin) 0) ] columnLabels
            g [ Class "adjacency-matrix__cells"
                SVGAttr.Transform (sprintf "translate(%d, %d)" (rowLabelWidth + rowLabelMargin) (columnLabelHeight + columnLabelMargin)) ] cells
        ]
