module SimpleRayTracer

using LinearAlgebra
using StaticArrays
using Images
using GeometryBasics
using Rotations

export OrthogonalCamera, PinHoleCamera
export World, Ray
export Sphere

"""Supertype of all cameras; concrete cameras implement `get_rays(camera, resolution)`."""
abstract type AbstractCamera end

"""Supertype of all scene objects; concrete objects implement `check_hit(ray, obj)` and have a `color` field."""
abstract type AbstractObject end

"""
    OrthogonalCamera(origin, size, rotation)

Orthographic camera: every pixel emits a ray with the same direction.

# Fields
- `origin`: centre of the screen plane.
- `size`: `(X, Y)` extent of the screen plane.
- `rotation`: orientation applied to both the screen plane and the ray direction.
"""
struct OrthogonalCamera{T<:AbstractFloat} <: AbstractCamera
    origin::Vec3{T}
    size::Tuple{T, T}
    rotation::RotXYZ{T}
end

"""
    get_rays(camera::OrthogonalCamera, resolution::Tuple{I, I})

Return an `Ny × Nx` matrix of rays for `resolution = (Nx, Ny)`.

Pixel positions are laid out on the rotated screen plane: rows run from `+Y/2` down to
`-Y/2` and columns from `-X/2` to `+X/2`. Every ray points along the rotated `+z` axis.
"""
function get_rays(camera::OrthogonalCamera, resolution::Tuple{I, I}) where {I<:Integer}
    R = camera.rotation
    X, Y = camera.size
    Nx, Ny = resolution
    screen_pixel_position = [R * Vec3(x, y, 0.0) + camera.origin
                             for y in LinRange(Y / 2, -Y / 2, Ny),
                                 x in LinRange(-X / 2, X / 2, Nx)]
    return [Ray(p, R * Vec3(0.0, 0.0, 1.0)) for p in screen_pixel_position]
end

"""
    PinHoleCamera(origin, distance, base, size)

Perspective camera: all rays start at `origin` and pass through an image plane that lies
`distance` away along `-w`, where `base = [u v w]` holds the camera basis as columns.

# Fields
- `origin`: eye position; every ray starts here.
- `distance`: distance from the eye to the image plane.
- `base`: 3×3 basis matrix with columns `u` (right), `v` (up), `w` (backward).
- `size`: `(X, Y)` extent of the image plane.
"""
struct PinHoleCamera{T<:AbstractFloat} <: AbstractCamera
    origin::Vec3{T}
    distance::T
    # Fully specify the length parameter (3*3 = 9) so the field type is concrete.
    base::SMatrix{3, 3, T, 9}
    size::Tuple{T, T}
end

"""
    PinHoleCamera(origin, lookAt, up, distance, size)

Build a pin-hole camera at `origin` that looks towards `lookAt`, using `up` to fix its roll.

The basis is `w = normalize(origin - lookAt)`, `u = normalize(up × w)` and `v = w × u`.

# Throws
- `ArgumentError` if `up` is parallel to the viewing direction, because no orthonormal
  basis can be formed in that case.
"""
function PinHoleCamera(origin, lookAt, up, distance, size)
    w = normalize(origin - lookAt)
    side = up × w
    # A zero cross product would make `normalize` return NaNs and poison every ray.
    iszero(norm(side)) &&
        throw(ArgumentError("`up` must not be parallel to the viewing direction"))
    u = normalize(side)
    v = w × u
    B = [u v w]
    return PinHoleCamera(origin, distance, B, size)
end

"""
    get_rays(camera::PinHoleCamera, resolution::Tuple{I, I})

Return an `Ny × Nx` matrix of rays for `resolution = (Nx, Ny)`.

All rays start at `camera.origin`. Their directions are `B * (x, y, -d)`, with rows running
from `y = +Y/2` to `-Y/2` and columns from `x = -X/2` to `+X/2`.
"""
function get_rays(camera::PinHoleCamera, resolution::Tuple{I, I}) where {I<:Integer}
    origin = camera.origin
    d = camera.distance
    X, Y = camera.size
    B = camera.base
    Nx, Ny = resolution
    return [Ray(origin, B * Vec3(x, y, -d))
            for y in LinRange(Y / 2, -Y / 2, Ny), x in LinRange(-X / 2, X / 2, Nx)]
end

"""
    Ray(origin, direction)

Half-line starting at `origin`. The constructor normalizes `direction` to unit length.
"""
struct Ray{T<:AbstractFloat}
    origin::Vec3{T}
    direction::Vec3{T}
    function Ray(o::Vec3{T}, d::Vec3{T}) where {T<:AbstractFloat}
        return new{T}(o, normalize(d))
    end
end

"""
    HitInfo(hit, color)

Result of tracing one ray: whether any object was hit, and the colour to draw.
"""
struct HitInfo
    hit::Bool
    color::RGB{N0f8}
end

"""
    World(background)

Scene container: a list of objects plus the background colour shown where no object is hit.
"""
struct World
    # Typed as AbstractObject (not a bare `Vector`, i.e. `Vector{Any}`) so that only
    # scene objects can be stored; `push!` below already enforces this.
    objects::Vector{AbstractObject}
    background::RGB{N0f8}
    World(color) = new(AbstractObject[], color)
end

"""Add the scene object `x` to world `w`."""
Base.push!(w::World, x::AbstractObject) = push!(w.objects, x)

"""
    Sphere(origin, radius, color)

Solid sphere with centre `origin`, radius `radius` and flat colour `color`.
"""
struct Sphere{T<:AbstractFloat} <: AbstractObject
    origin::Vec3{T}
    radius::T
    color::RGB{N0f8}
end

"""
    check_hit(ray::Ray{T}, sphere::Sphere{T}) -> (hit::Bool, t::T)

Intersect `ray` with `sphere` by solving `|o + t d - c|² = r²` for `t`.

Return `(true, t)` for the nearest root with `t ≥ eps(T)`. Otherwise return
`(false, -1)`, either because there is no real root or because both roots lie at or
behind the ray origin.
"""
function check_hit(ray::Ray{T}, sphere::Sphere{T})::Tuple{Bool, T} where {T<:AbstractFloat}
    p = ray.origin - sphere.origin
    a = ray.direction ⋅ ray.direction
    b = 2 * (p ⋅ ray.direction)
    c = p ⋅ p - sphere.radius^2
    Δ = b^2 - 4a * c
    # Negative discriminant: the ray's supporting line misses the sphere.
    Δ < 0 && return (false, -one(T))
    # Tolerance in the ray's own precision (the old `eps()` was always Float64).
    tol = eps(T)
    mid, half = -b / (2a), √Δ / (2a)
    t = mid - half
    # The nearer root is behind (or on) the origin, e.g. the ray starts inside the
    # sphere, so fall back to the farther root.
    t = t < tol ? mid + half : t
    return t < tol ? (false, -one(T)) : (true, t)
end

"""
    trace_ray(world::World, ray::Ray) -> HitInfo

Test `ray` against every object in `world`.

The returned `hit` is `true` if at least one object reported a hit. The returned colour is
that of the nearest hit object, or `world.background` if nothing was hit.
"""
function trace_ray(world::World, ray::Ray)::HitInfo
    mindist = floatmax(Float64)
    anyhit = false
    color = world.background
    for obj in world.objects
        # Use distinct names here. Reusing `anyhit` would overwrite the
        # accumulated flag with the result for the last object only.
        ishit, dist = check_hit(ray, obj)
        if ishit && dist < mindist
            mindist = dist
            anyhit = true
            color = obj.color
        end
    end
    return HitInfo(anyhit, color)
end

"""
    render(world, camera, resolution)

Trace one ray per pixel and return the colour matrix, which has the same shape as
`get_rays(camera, resolution)`.
"""
function render(world, camera, resolution)
    rays = get_rays(camera, resolution)
    return [trace_ray(world, ray).color for ray in rays]
end

end