/*
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. The ASF licenses this
file to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License.  You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied.  See the License for the
specific language governing permissions and limitations
under the License.   
*/

#include "max.h"
#include "primitiveMoving.h"

using namespace VR;

// Accumulate the contribution of one BVH step's particle container into field.
// The overload is chosen from the arrays actually passed in: sizes[0] == NULL means
// the particles carry no per-particle radii (same test intersect() uses for the
// normal pass). `linear` is only forwarded to the per-particle-radii overload.
void BerconMetaballPrimitiveMoving::getField(BVHMStep* cur, VR::real time, FieldMoving* field, VR::Vector* pointCloud[2], VR::Ireal* sizes[2], bool linear) {
	if (sizes[0])
		cur->cont->getField(pointCloud, sizes, time, field, linear);
	else
		cur->cont->getField(pointCloud, time, field);
}

/*
 * Finalize an averaged field sample taken at pos and store it in field.res.
 *
 * The result is a blended radius minus the distance from pos to the
 * weighted-average centre (totalP / totalW). The radius is totalR / totalW when
 * per-particle radii exist, blobSize otherwise, and is faded with smoothstep1
 * while the accumulated weight is below blobSize.
 *
 * If no positive weight was accumulated the averages are 0/0 (NaN), and NaN
 * compares false against any threshold, so the inside/outside decision would be
 * arbitrary. Callers treat res > blobThreshold as "inside", so a large negative
 * value is stored instead to mark the sample as outside.
 */
void BerconMetaballPrimitiveMoving::getAverageResult(FieldMovingAvg &field, Vector &pos) {
	if (field.totalW <= 0) {
		field.res = -1e30f;
		return;
	}
	float radius = radii[0] ? field.totalR/field.totalW : blobSize;
	if (field.totalW < blobSize)
		radius *= smoothstep1(field.totalW / trim, cutoff);
	field.res = radius - (pos - field.totalP / field.totalW).length();
}

/*
 * Finalize the four averaged samples used for the finite-difference normal.
 * For each sample k: res[k] = radius - |location[k] - totalP[k] / totalW[k]|,
 * where radius is totalR[k] / totalW[k] with per-particle radii (blobSize
 * otherwise), faded by smoothstep1 while totalW[k] is below blobSize.
 */
void BerconMetaballPrimitiveMoving::getAverageResult(FieldGradMovingAvg &grad) {
	const bool perParticleRadii = radii[0] != NULL;
	for (int k = 0; k < 4; ++k) {
		float r = blobSize;
		if (perParticleRadii)
			r = grad.totalR[k] / grad.totalW[k];
		if (grad.totalW[k] < blobSize)
			r *= smoothstep1(grad.totalW[k] / trim, cutoff);
		grad.res[k] = r - (grad.location[k] - grad.totalP[k] / grad.totalW[k]).length();
	}
}

/*
 * Ray-march the (non-averaged) field to find the first threshold crossing,
 * then refine the hit distance by bisection.
 *
 * rO, rD      ray origin and direction; samples are taken at rO + rD*dist.
 * time        time forwarded to the particle containers (intersect() passes the
 *             step-local rescaled time).
 * step        in/out: head of the BVH step list. Steps the march has passed are
 *             unlinked on the fly, one sample late (distLate), because the
 *             bisection still needs the step information around the hit. On
 *             success it is the list the caller reuses for the normal.
 * stepLength  march increment in ray-parameter units.
 * maxError    bisection runs while the bracket half-width exceeds maxError.
 * dist        in: start distance; out (on success): refined hit distance.
 * pointCloud, sizes  the two per-step particle arrays forwarded to getField.
 *
 * Returns false if the step list runs out before a crossing is found.
 */
bool BerconMetaballPrimitiveMoving::intersectTraditional(Vector rO, Vector rD, real time, BVHMStep* &step, Ireal stepLength, Ireal maxError, Ireal &dist, Vector* pointCloud[2], Ireal* sizes[2]) {
	BerconShadeContext sc;
	Ireal textureVal = 0; // Only read when tex is set
	BVHMStep* cur;
	BVHMStep* prev;

	// Calculate if we are inside or outside the field at the starting distance
	Vector pos = rO+rD*dist;
	FieldMoving field = FieldMoving(pos, function, blobSize2);
	cur = step;
	prev = NULL;
	Ireal distLate = dist;
	while (cur != NULL && dist > cur->min) {
		if (distLate > cur->max) {
			// Step ends before the current distance: unlink it from the list
			if (cur == step)
				step = step->next;
			else {
				prev->next = cur->next;
				cur->deleteThisOnly(); // Free the unlinked node, as the marching loops below do
				cur = prev;
			}
			if (step == NULL) return false;
			cur = cur->next;
			continue;
		}
		if (dist < cur->max) {
			getField(cur, time, &field, pointCloud, sizes, false);
		}
		prev = cur;
		cur = cur->next;
	}

	if (tex) applyTexmap(sc, pos, field.res);
	bool insideField = field.res > blobThreshold;

	// Find intersection along the ray
	if (insideField) { // Inside
		while (step != NULL) {
			pos = rO+rD*dist;
			if (tex) textureVal = getTexmap(sc, pos);
			field = FieldMoving(pos, function, blobSize2);
			cur = step;
			prev = NULL;
			while (cur != NULL && dist > cur->min) {
				if (distLate > cur->max) {
					if (cur == step)
						step = step->next;
					else {
						prev->next = cur->next;
						cur->deleteThisOnly();
						cur = prev;
					}
					if (step == NULL) return false;
					cur = cur->next;
					continue;
				}
				if (dist < cur->max) {
					getField(cur, time, &field, pointCloud, sizes, false);
					// Early out when the textured partial sum already passes
					// testThreshold(..., false). field.res is restored to its untextured
					// value BEFORE leaving the loop, so the texture below is applied
					// exactly once (a break used to skip the restore, texturing twice).
					Ireal fieldRes = field.res;
					if (tex) applyTexmap(textureVal, field.res);
					bool earlyOut = testThreshold(field.res, false);
					field.res = fieldRes;
					if (earlyOut)
						break;
				}
				prev = cur;
				cur = cur->next;
			}
			if (tex) applyTexmap(textureVal, field.res);
			if (testThreshold(field.res, insideField))
				break;
			distLate = dist;
			dist += stepLength;
		}
	} else { // Outside
		while (step != NULL) {
			pos = rO+rD*dist;
			field = FieldMoving(pos, function, blobSize2);
			cur = step;
			prev = NULL;
			while (cur != NULL && dist > cur->min) {
				// Optimize the ray BVHRay on the fly, however do it one step late because
				// we might need to use the step information when searching the exact hit.
				if (distLate > cur->max) {
					if (cur == step)
						step = step->next;
					else {
						prev->next = cur->next;
						cur->deleteThisOnly();
						cur = prev;
					}
					if (step == NULL) return false;
					cur = cur->next;
					continue;
				}
				if (dist < cur->max) {
					getField(cur, time, &field, pointCloud, sizes, false);
				}
				prev = cur;
				cur = cur->next;
			}
			if (tex) applyTexmap(sc, pos, field.res);
			if (testThreshold(field.res, insideField))
				break;
			// Skip empty space: jump straight to the start of the next step if it lies ahead
			if (dist < step->min) {
				dist = step->min;
				distLate = dist;
			} else {
				distLate = dist;
				dist += stepLength;
			}
		}
	}

	if (step == NULL) return false; // No hit with the field

	// Confirmed hit, now bisect within the last march step until the desired maximum error is reached
	Ireal error = stepLength / 2.;
	dist -= error;
	while (error > maxError) {
		pos = rO+rD*dist;
		field = FieldMoving(pos, function, blobSize2);
		cur = step;
		while (cur != NULL && dist > cur->min) {
			if (dist < cur->max) {
				getField(cur, time, &field, pointCloud, sizes, false);
			}
			cur = cur->next;
		}
		error /= 2.;
		if (tex) applyTexmap(sc, pos, field.res);
		if (testThreshold(field.res, insideField))
			dist -= error;
		else
			dist += error;
	}
	return true;
}

/*
 * Ray-march the averaged ("average" mode) field to find the first threshold
 * crossing, then refine the hit distance by bisection. Each sample is finalized
 * with getAverageResult(field, pos) before texturing and thresholding.
 *
 * rO, rD      ray origin and direction; samples are taken at rO + rD*dist.
 * time        time forwarded to the particle containers (intersect() passes the
 *             step-local rescaled time).
 * step        in/out: head of the BVH step list. Steps the march has passed are
 *             unlinked on the fly, one sample late (distLate), because the
 *             bisection still needs the step information around the hit. On
 *             success it is the list the caller reuses for the normal.
 * stepLength  march increment in ray-parameter units.
 * maxError    bisection runs while the bracket half-width exceeds maxError.
 * dist        in: start distance; out (on success): refined hit distance.
 * pointCloud, sizes  the two per-step particle arrays forwarded to getField.
 *
 * Returns false if the step list runs out before a crossing is found.
 */
bool BerconMetaballPrimitiveMoving::intersectAverage(Vector rO, Vector rD, real time, BVHMStep* &step, Ireal stepLength, Ireal maxError, Ireal &dist, Vector* pointCloud[2], Ireal* sizes[2]) {
	BerconShadeContext sc;
	BVHMStep* cur;
	BVHMStep* prev;

	// Calculate if we are inside or outside the field at the starting distance
	Vector pos = rO + rD * dist;
	FieldMovingAvg field = FieldMovingAvg(pos, function, blobSize2);
	cur = step;
	prev = NULL;
	Ireal distLate = dist;
	while (cur != NULL && dist > cur->min) {
		if (distLate > cur->max) {
			// Step ends before the current distance: unlink it from the list
			if (cur == step)
				step = step->next;
			else {
				prev->next = cur->next;
				cur->deleteThisOnly(); // Free the unlinked node, as the marching loops below do
				cur = prev;
			}
			if (step == NULL) return false;
			cur = cur->next;
			continue;
		}
		if (dist < cur->max) {
			getField(cur, time, &field, pointCloud, sizes, true);
		}
		prev = cur;
		cur = cur->next;
	}
	getAverageResult(field, pos);
	if (tex) applyTexmap(sc, pos, field.res);

	bool insideField = field.res > blobThreshold;

	// Find intersection along the ray
	if (insideField) { // Inside
		while (step != NULL) {
			pos = rO + rD * dist;
			field = FieldMovingAvg(pos, function, blobSize2);
			cur = step;
			prev = NULL;
			while (cur != NULL && dist > cur->min) {
				if (distLate > cur->max) {
					if (cur == step)
						step = step->next;
					else {
						prev->next = cur->next;
						cur->deleteThisOnly();
						cur = prev;
					}
					if (step == NULL) return false;
					cur = cur->next;
					continue;
				}
				if (dist < cur->max)
					getField(cur, time, &field, pointCloud, sizes, true);
				prev = cur;
				cur = cur->next;
			}
			getAverageResult(field, pos);
			if (tex) applyTexmap(sc, pos, field.res);
			if (testThreshold(field.res, insideField))
				break;
			distLate = dist;
			dist += stepLength;
		}
	} else { // Outside
		while (step != NULL) {
			pos = rO + rD * dist;
			field = FieldMovingAvg(pos, function, blobSize2);
			cur = step;
			prev = NULL;
			while (cur != NULL && dist > cur->min) {
				// Optimize the ray BVHRay on the fly, however do it one step late because
				// we might need to use the step information when searching the exact hit.
				if (distLate > cur->max) {
					if (cur == step)
						step = step->next;
					else {
						prev->next = cur->next;
						cur->deleteThisOnly();
						cur = prev;
					}
					if (step == NULL) return false;
					cur = cur->next;
					continue;
				}
				if (dist < cur->max) {
					getField(cur, time, &field, pointCloud, sizes, true);
				}
				prev = cur;
				cur = cur->next;
			}
			getAverageResult(field, pos);
			if (tex) applyTexmap(sc, pos, field.res);
			if (testThreshold(field.res, insideField))
				break;
			// Skip empty space: jump straight to the start of the next step if it lies ahead
			if (dist < step->min) {
				dist = step->min;
				distLate = dist;
			} else {
				distLate = dist;
				dist += stepLength;
			}
		}
	}

	if (step == NULL) return false; // No hit with the field

	// Confirmed hit, now bisect within the last march step until the desired maximum error is reached
	Ireal error = stepLength / 2.;
	dist -= error;
	while (error > maxError) {
		pos = rO + rD * dist;
		field = FieldMovingAvg(pos, function, blobSize2);
		cur = step;
		while (cur != NULL && dist > cur->min) {
			if (dist < cur->max) {
				getField(cur, time, &field, pointCloud, sizes, true);
			}
			cur = cur->next;
		}
		error /= 2.;
		getAverageResult(field, pos);
		if (tex) applyTexmap(sc, pos, field.res);
		if (testThreshold(field.res, insideField))
			dist -= error;
		else
			dist += error;
	}
	return true;
}

/*
 * Ray-primitive intersection entry point.
 *
 * Picks the motion step (and its BVH) containing ray.time, rescales the ray time
 * into that step's local range, marches the field (average or traditional mode),
 * and on a hit that is closer than ray.is.t and inside (ray.cmint, ray.cmaxt)
 * computes the normal and colour, records the hit in ray.is and caches
 * (-normal, colour) in raycache for getGNormal().
 *
 * Returns true when a hit was recorded, false otherwise.
 */
int BerconMetaballPrimitiveMoving::intersect(RSRay &ray) {
	real dirLength = ray.dir.length();
	// Convert world-space step/error into ray-parameter units (ray.dir is not assumed normalized)
	Ireal stepLength = this->stepLength / dirLength;
	Ireal maxError = this->maxError / dirLength;
	Ireal dist = ray.cmint;

	// Start slightly past cmint for every ray. This used to be guarded by a
	// (now disabled) `ray.skipTag == this` check; presumably it avoids re-hitting
	// the surface a secondary ray was spawned from.
	dist += maxError * 2.;

	// Find the BVH of the motion step that contains ray.time
	int bvhID = (int)(ray.time / (1.f / (float)steps));
	if (bvhID < 0) bvhID = 0; // Never index before the first step
	if (bvhID >= steps) bvhID = steps-1;

	// Select the particle clouds of that step
	Vector* curBlobs[2];
	Ireal* curRadii[2];
	curBlobs[0] = blobs[bvhID*2]; curBlobs[1] = blobs[bvhID*2+1];
	curRadii[0] = radii[bvhID*2]; curRadii[1] = radii[bvhID*2+1];

	// Rescale time to the selected step's local range. Every field query below
	// (BVH trace, intersection search and normal) must use this same value.
	float time = (ray.time - times[bvhID]) / timeInterval;

	// Infinite range, not cmint/cmaxt, because that caused black "slice planes" along the x, y or z axis
	IRay tRay = IRay(ray.p, ray.dir);
	BVHMRay bRay = BVHMRay();
	bvh[bvhID].traceRay(&tRay, &bRay, time);

	if (bRay.first == NULL) // No intersection with the bounding volume
		return false;

	BVHMStep *step = bRay.first; // Advanced/pruned by the search; reused for the normal below
	if (useAverage) {
		if (!intersectAverage(ray.p, ray.dir, time, step, stepLength, maxError, dist, curBlobs, curRadii))
			return false; // No intersection found
	} else {
		if (!intersectTraditional(ray.p, ray.dir, time, step, stepLength, maxError, dist, curBlobs, curRadii))
			return false; // No intersection found
	}

	// Reject hits behind the current closest hit or outside (cmint, cmaxt)
	if (dist>ray.is.t) return false;
	if (dist<=ray.cmint) return false;
	if (dist>=ray.cmaxt) return false;

	// Calculate the normal here, reusing the step list left by the search. It is not
	// certain this is faster than computing it separately: normals are also computed
	// for rays that don't need them, e.g. shadow and occlusion rays.
	Vector N, col;
	BerconShadeContext sc;

	if (useAverage) {
		FieldGradMovingAvg grad = FieldGradMovingAvg(colors);
		grad.function = function; grad.function2 = function2; grad.size2 = blobSize2;
		computeNormalLocations(grad.location, ray.p + ray.dir * dist);
		while (step != NULL && dist > step->min) {
			if (dist < step->max) {
				// Use the step-local time, as the intersection search did
				// (ray.time here evaluated the normal at a different time than the hit
				// whenever more than one motion step is used).
				if (curRadii[0])
					step->cont->getFieldGrad(curBlobs, curRadii, time, &grad, true);
				else
					step->cont->getFieldGrad(curBlobs, time, &grad);
			}
			step = step->next;
		}
		getAverageResult(grad);
		if (tex)
			for (int i=0; i<4; i++)
				applyTexmap(sc, grad.location[i], grad.res[i]);
		computeNormalResult(grad.res);
		N = Vector(grad.res[1], grad.res[2], grad.res[3]);
		col = grad.col / grad.colr;
	} else {
		FieldGradMoving grad = FieldGradMoving(colors);
		grad.function = function; grad.function2 = function2; grad.size2 = blobSize2;
		computeNormalLocations(grad.location, ray.p + ray.dir * dist);
		while (step != NULL && dist > step->min) {
			if (dist < step->max) {
				// Use the step-local time, as the intersection search did
				if (curRadii[0])
					step->cont->getFieldGrad(curBlobs, curRadii, time, &grad, false);
				else
					step->cont->getFieldGrad(curBlobs, time, &grad);
			}
			step = step->next;
		}
		if (tex)
			for (int i=0; i<4; i++)
				applyTexmap(sc, grad.location[i], grad.res[i]);
		computeNormalResult(grad.res);
		N = Vector(grad.res[1], grad.res[2], grad.res[3]);
		col = grad.col / grad.colr;
	}

	// Set intersection
	MyCacheStruct cacheStruct;
	cacheStruct.rO = -N; // Cache normal
	cacheStruct.col = col;

	ray.is.t = dist;
	ray.is.primitive = (GenericPrimitive*) this;
	ray.is.skipTag = this;

	raycache->putCache(*(VR::VRayContext*) ray.rayparams, cacheStruct);

	return true;
}

// Return the normal that intersect() stored in the ray cache (field rO of
// MyCacheStruct), looked up by the ray's VRayContext.
Vector BerconMetaballPrimitiveMoving::getGNormal(RSRay &ray) const {
	VR::VRayContext &rc = *(VR::VRayContext*) ray.rayparams;
	MyCacheStruct cached;
	raycache->getCache(rc, cached);
	return cached.rO;
}

/*
Vector BerconMetaballPrimitiveMoving::getGNormal(RSRayRef& rayref) {
	MyCacheStruct cacheStruct;
	raycache->getCache(*(VR::VRayContext*) rayref.getRayParams(), cacheStruct);	
	return cacheStruct.rO;
}
*/

/*
 * Set up the primitive: store the particle data, compute a static bounding box
 * covering every key frame, build one BVH per motion step and precompute each
 * step's start time and length.
 *
 * points / sizes hold two arrays per step (indices s*2 and s*2+1). The end of the
 * final step is taken as the sum of the two arrays of that step. sizes[0] == NULL
 * selects the constant blobSize radius. pointCount[s] is the particle count of
 * step s.
 */
void BerconMetaballPrimitiveMoving::init(MetaballParams params, VR::Vector** points, VR::Ireal** sizes, VR::Vector* colors, int* pointCount, VR::GeometryGenerator *owner, int ownerIndex, VR::RayCache<MyCacheStruct> *rayc) {
	setFunction(params.field, params.color);
	setParams(params);

	this->blobs = points;
	this->radii = sizes;
	this->colors = colors;
	this->steps = params.steps;
	this->blobsCount = pointCount;

	// Static bounding box over all key frames; a dynamic one is pointless because the
	// per-step BVHs already cull as well as the renderer would.
	bb.init();
	bb.t[0] = 0.;
	bb.t[1] = 1.;

	// Starting state of every step (each step's end state duplicates the next step's start)
	for (int s = 0; s < steps; s++) {
		Vector* startPos = blobs[s*2];
		for (int i = 0; i < blobsCount[s]; i++) {
			if (radii[0])
				bb.b[0].addExpanded(startPos[i], radii[s*2][i]);
			else
				bb.b[0].addExpanded(startPos[i], blobSize);
		}
	}

	// End state of the final step, which no later step duplicates
	const int last = steps - 1;
	for (int i = 0; i < blobsCount[last]; i++) {
		if (radii[0])
			bb.b[0].addExpanded(blobs[last*2][i] + blobs[last*2+1][i], radii[last*2][i] + radii[last*2+1][i]);
		else
			bb.b[0].addExpanded(blobs[last*2][i] + blobs[last*2+1][i], blobSize);
	}
	bb.b[1] += Vector(0,0,0);

	// One BVH per motion step
	this->bvh = new BoundingVolumeHierarchyMoving[steps];
	for (int s = 0; s < steps; s++) {
		Vector* stepBlobs[2] = { blobs[s*2], blobs[s*2+1] };
		Ireal* stepRadii[2] = { radii[s*2], radii[s*2+1] };
		this->bvh[s].create(stepBlobs, stepRadii, blobsCount[s], (real)blobSize, params.depth, params.leafSize, blobSize*params.leafLength);
	}

	// Step s starts at s/steps; all steps share the same length
	this->times = new real[steps];
	for (int s = 0; s < steps; s++)
		times[s] = (float)s / (float)steps;
	timeInterval = 1.f / (float)steps;

	this->owner = owner;
	this->ownerIndex = ownerIndex;
	raycache = rayc;
}