## Make an interactive 3D scatterplot of the mtcars data, then overlay a
## fitted linear regression plane (mpg ~ hp + wt); adapted from:
## https://stackoverflow.com/questions/38331198/add-regression-plane-to-3d-scatter-plot-in-plotly


library(plotly)
library(reshape2)

## load 'mtcars' data, see ?mtcars
data(mtcars)

## recode transmission type 'am' (0 = automatic, 1 = manual) as a
## labelled factor; factor() maps the numeric codes in one step instead of
## overwriting a numeric column with strings and then relying on the
## implicit coercion "1" == 1 for the second replacement
mtcars$am <- factor(mtcars$am, levels = c(0, 1),
                    labels = c("Automatic", "Manual"))

## build the plotly graph object; assigning it to 'p' does not draw
## anything yet
p <- plot_ly(
  data = mtcars,
  x = ~wt, y = ~hp, z = ~mpg,
  color = ~am,
  colors = c("#BF382A", "#0C4B8E")
) %>%
  add_markers() %>%
  layout(
    scene = list(
      xaxis = list(title = "Weight"),
      yaxis = list(title = "Gross horsepower"),
      zaxis = list(title = "Miles per gallon")
    )
  )
## render graph (top-level auto-print, same as calling print(p))
p

## Optional: create a shareable link to your chart
## set up API credentials: https://plot.ly/r/getting-started
#chart_link = api_create(p, filename="scatter3d-basic")
#chart_link

## fit the regression plane mpg ~ hp + wt
cars_lm <- lm(mpg ~ hp + wt, data = mtcars)

## number of grid points along each axis of the plane. A fixed count is
## used instead of one shared step size: 'wt' spans about 4 units while
## 'hp' spans about 280, so a common step of 0.05 produced ~5,600 hp values
## and a ~450,000-point prediction grid for what is only a flat plane.
graph_n <- 50L

## setup axes; length.out guarantees that both the data minimum and the
## data maximum are grid points, whereas seq(by = ) can stop short of max
axis_x <- seq(min(mtcars$wt), max(mtcars$wt), length.out = graph_n)
axis_y <- seq(min(mtcars$hp), max(mtcars$hp), length.out = graph_n)

## fitted mpg over the (wt, hp) grid, as a matrix with one row per hp
## value (y) and one column per wt value (x), i.e. the y ~ x layout used
## for the surface's z values. outer() passes the expanded axis_y values
## as the first argument and axis_x as the second.
cars_lm_surface <- outer(axis_y, axis_x, function(hp, wt) {
  predict(cars_lm, newdata = data.frame(wt = wt, hp = hp))
})

## colours for points, according to 'am'. Uses the same palette as the
## first plot (Automatic = "#BF382A", Manual = "#0C4B8E"), and looks colours
## up by label rather than by the factor's integer codes, so the mapping
## does not depend on the factor's level order.
am_colours <- c(Automatic = "#BF382A", Manual = "#0C4B8E")
hcolors <- unname(am_colours[as.character(mtcars$am)])

## make 3D scatterplot of the observed data
cars_plot <- plot_ly(mtcars,
                     x = ~wt,
                     y = ~hp,
                     z = ~mpg,
                     type = "scatter3d",
                     mode = "markers",
                     marker = list(color = hcolors))

## add regression plane: z is the fitted-mpg matrix (rows follow axis_y,
## columns follow axis_x). inherit = FALSE keeps the surface trace from
## inheriting plot_ly()'s scatter-only attributes (mode, marker), which
## otherwise makes plotly warn that 'surface' objects lack those attributes.
cars_plot <- add_trace(p = cars_plot,
                       z = cars_lm_surface,
                       x = axis_x,
                       y = axis_y,
                       type = "surface",
                       inherit = FALSE)

## render plot
cars_plot