# -*- coding: utf-8 -*-
# =============================================================================
# Created on Sun Jul 21 00:51:34 2019
#
# @author: Brénainn Woodsend
#
#
# Scatter.py creates scatter plots using spheres or 3D cursors as markers.
# Copyright (C) 2019  Brénainn Woodsend
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# =============================================================================

from builtins import super

import vtk
import numpy as np
from vtk.util.numpy_support import (
                                    numpy_to_vtk,
                                    numpy_to_vtkIdTypeArray,
                                    vtk_to_numpy,
                                    )



from vtkplotlib.plots.BasePlot import SourcedPlot, _iter_colors, _iter_points, _iter_scalar


class Sphere(SourcedPlot):
    """Plot an individual sphere.

    The geometry is provided by a ``vtk.vtkSphereSource``. The sphere's
    center and radius are exposed as the :attr:`point` and :attr:`radius`
    properties, which read from and write straight through to that source.
    """

    def __init__(self, point, color=None, opacity=None, radius=1., fig="gcf"):
        super().__init__(fig)

        # Create and configure the geometry before registering the plot.
        source = vtk.vtkSphereSource()
        self.source = source
        self.point = point
        self.radius = radius

        self.add_to_plot()
        self.color_opacity(color, opacity)

    @property
    def point(self):
        """The sphere's center (x, y, z), as stored on the source."""
        center = self.source.GetCenter()
        return center

    @point.setter
    def point(self, point):
        # Unpack so any 3-element sequence (tuple, list, array) is accepted.
        self.source.SetCenter(*point)

    @property
    def radius(self):
        """The sphere's radius, as stored on the source."""
        current = self.source.GetRadius()
        return current

    @radius.setter
    def radius(self, r):
        self.source.SetRadius(r)


class Cursor(SourcedPlot):
    """Plot a 3D cursor marker (a ``vtk.vtkCursor3D``) at a point.

    The source has its outline turned off and translation mode turned on. In
    translation mode, moving the cursor's focal point (:attr:`point`) also
    moves its model bounds, so the cursor's extent (:attr:`radius`) follows it.
    """

    def __init__(self, point, color=None, opacity=None, radius=1, fig="gcf"):
        super().__init__(fig)

        self.source = vtk.vtkCursor3D()
        self.source.SetTranslationMode(True)
        self.source.OutlineOff()

        self.add_to_plot()

        # The extent is first laid out around the source's current focal
        # point. Because translation mode is on, assigning ``point`` afterwards
        # shifts the bounds along with the focal point.
        self.radius = radius
        self.point = point

        self.color_opacity(color, opacity)

    @property
    def point(self):
        """The cursor's focal point (x, y, z)."""
        return self.source.GetFocalPoint()

    @point.setter
    def point(self, point):
        self.source.SetFocalPoint(*point)

    @property
    def radius(self):
        """The cursor's extent relative to its focal point.

        Reading returns a ``(2, 3)`` array. Row 0 holds each axis' lower model
        bound minus the focal point. Row 1 holds each axis' upper model bound
        minus the focal point.

        Assigning accepts anything that broadcasts to shape ``(2, 3)``: a
        scalar, a per-axis length-3 array, or a ``(2, 3)`` array of
        [lower, upper] magnitudes. Row 0 is negated to give the lower bounds.
        Invalid input raises ``ValueError``.

        NOTE(review): the setter treats values as magnitudes while the getter
        returns signed offsets, so ``cursor.radius = cursor.radius`` does not
        round-trip the lower bounds.
        """
        # GetModelBounds() -> (x_min, x_max, y_min, y_max, z_min, z_max).
        bounds = np.reshape(self.source.GetModelBounds(), (3, 2))
        return bounds.T - self.point

    @radius.setter
    def radius(self, r):
        r = np.asarray(r)
        message = ("Cursor radius must be a scalar, a length 3 array or "
                   "anything else that broadcasts to shape (2, 3). "
                   "Got an array of shape {}.".format(r.shape))
        try:
            # Row 0 (negated) -> offsets to the lower bounds, row 1 -> upper.
            offsets = r * np.array([[-1, -1, -1], [1, 1, 1]])
        except Exception:
            raise ValueError(message)
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        # Inputs with extra leading dimensions broadcast fine but would give
        # more than the 6 values SetModelBounds needs.
        if offsets.size != 6:
            raise ValueError(message)

        # Transpose to (x_min, x_max, y_min, y_max, z_min, z_max) order.
        bounds = (offsets.reshape((2, 3)) + self.point).T.ravel()
        self.source.SetModelBounds(*bounds)






def _iter_opacities(opacity, shape):
    """Return an iterator giving the opacity for each marker of ``scatter``.

    :param opacity: None, a single opacity, or an array-like of opacities.
    :param shape: The shape of the marker array (``points.shape[:-1]``).

    None, a scalar, or an array holding a single value is repeated
    indefinitely, so it never shortens the ``zip`` in :func:`scatter`. Any
    other array is broadcast against ``shape``. It then yields one float per
    marker in C (row-major) order, which is the order ``scatter`` fills its
    output via ``out.ravel()``.

    :raises ValueError: If ``opacity`` cannot be broadcast to ``shape``.
    """
    import itertools

    if opacity is None or np.ndim(opacity) == 0:
        return itertools.repeat(opacity)
    opacity = np.asarray(opacity, dtype=float)
    if opacity.size == 1:
        return itertools.repeat(opacity.item())
    return iter(np.broadcast_to(opacity, shape).ravel().tolist())


def scatter(points, color=None, opacity=None, radius=1., use_cursors=False, fig="gcf"):
    """Scatter plot using little spheres or cursor objects.

    :param points: The point(s) to place the marker(s) at.
    :type points: np.array with ``points.shape[-1] == 3``

    :param color: The color of the markers, can be singular or per marker, defaults to white.
    :type color: str, 3-tuple, 4-tuple, np.array with same shape as `points`, optional

    :param opacity: The translucencies of the markers, 0 is invisible, 1 is solid, can be singular or per marker, defaults to solid.
    :type opacity: float, np.array broadcastable to ``points.shape[:-1]``, optional

    :param radius: The radius of each marker, defaults to 1.0.
    :type radius: float, np.array, optional

    :param use_cursors: If false use spheres, if true use cursors, defaults to False.
    :type use_cursors: bool, optional

    :param fig: The figure to plot into, can be None, defaults to vpl.gcf().
    :type fig: vpl.figure, vpl.QtFigure, optional


    :return: The marker or an array of markers.
    :rtype: vtkplotlib.plots.Scatter.Sphere or vtkplotlib.plots.Scatter.Cursor or np.array


    """
    points = np.asarray(points)
    shape = points.shape[:-1]

    # The marker type is the same for every point, so choose it once.
    cls = Cursor if use_cursors else Sphere

    out = np.empty(shape, object)
    # ``out`` is freshly allocated (contiguous), so ``ravel`` returns a view
    # and writing into ``out_flat`` fills ``out``.
    out_flat = out.ravel()
    markers = zip(_iter_points(points),
                  _iter_colors(color, shape),
                  _iter_scalar(radius, shape),
                  _iter_opacities(opacity, shape))
    for (i, (xyz, c, r, o)) in enumerate(markers):
        out_flat[i] = cls(xyz, c, o, r, fig)

    # A single point (0-d marker array) returns the marker itself.
    if out.ndim:
        return out
    return out_flat[0]



def test():
    """Interactive demo: scatter 30 random points as spheres and as cursors.

    Opens a figure with ``vpl.show()``. Returns ``(spheres, cursors)``, the
    two marker arrays created by :func:`scatter`, so they can be inspected
    afterwards without leaking locals into the module namespace.
    """
    import vtkplotlib as vpl

    points = np.random.uniform(-10, 10, (30, 3))

    # Colors derived from the point coordinates; radii grow with |x|.
    colors = vpl.colors.normalise(points)
    radii = np.abs(points[:, 0]) ** .5

    spheres = vpl.scatter(points,
                          color=colors,
                          radius=radii,
                          use_cursors=False,
                          )
    cursors = vpl.scatter(points,
                          color=colors,
                          radius=radii,
                          use_cursors=True,
                          )

    vpl.show()

    return spheres, cursors


if __name__ == "__main__":
    # Run the interactive demo when executed as a script.
    test()


