Archive

Archive for the ‘Code Samples’ Category

Static Allocator

August 28th, 2010 admin No comments

This is a simple static allocator I wrote to do block allocations of a fixed size. The fixed size allowed for very quick allocations and deallocations and improved cache locality (when objects of the same type were used, especially with components).

I did not modify this in any way before posting, except to remove commented out code blocks. Anything you see here directly corresponds to my own personal coding style and standards.

/******************************************************************************
  Author: Trevor Sundberg
    Date: 03/12/2008

 Purpose:
          An allocator that allocates objects in fixed size blocks.
          All content © 2009 DigiPen (USA) Corporation, all rights reserved.
******************************************************************************/


// Includes
#include "stdafx.h"
#include "static_allocator.h"

// Creates the allocator and builds its first page.
//   object_size      - size in bytes of each block handed out by Allocate
//   objects_per_page - number of blocks carved out of every page
// Every free block stores a "next free" pointer inside itself, so the block
// size is raised to at least sizeof(void*); a page always holds at least one
// block so the free list can never start outside the page.
// NOTE(review): the counters in mStats are only incremented here and in
// BuildPage; this assumes SAStats zero-initializes them - confirm.
StaticAllocator::StaticAllocator(unsigned object_size, unsigned objects_per_page)
{
  // A block smaller than a pointer cannot hold its free-list link
  if (object_size < sizeof(void*))
    object_size = static_cast<unsigned>(sizeof(void*));

  // A page with no blocks would leave the free list pointing outside the page
  if (objects_per_page == 0)
    objects_per_page = 1;

  // Store the (possibly adjusted) object size and objects per page
  mObjectSize     = object_size;
  mObjectsPerPage = objects_per_page;

  // Pre-calculate the size of a single page
  // (one pointer at the front links the page to the previously built page)
  mPageSize = objects_per_page * mObjectSize + sizeof(void*);

  // Set up the initial stats that don't change
  mStats.mPageSize   = mPageSize;
  mStats.mObjectSize = mObjectSize;

  // Build the first page (there is no previous page, so pass it null)
  char* new_page = BuildPage(NULL);

  // The page list starts at the page we just built
  mPageList = new_page;

  // The last block in the page is the head of the free list that BuildPage threaded
  mFreeList = &new_page[mPageSize - mObjectSize];
}

// Allocates and builds a page
char* StaticAllocator::BuildPage(char* last_page)
{
  // Allocate the memory for a single page
  char *allocated_memory = new char[mPageSize];

  // Initialize the first 4 bytes to null, since it's a new page
  // it should point at the previous page
  *reinterpret_cast<char**>(allocated_memory) = last_page;

  // Store two pointers that will point at the previous free list ptr and the current
  char* last;
  char* current = NULL;

  // Initialize the "free list pointers" on the object
  // Loop through the number of objects that's supposed to be on a page
  for (unsigned i = 0; i < mObjectsPerPage; ++i)
  {
    // Update the last and current
    last = current;
    //                          Offset objects     +   sizeof(void*) (for next page ptr)
    current = &allocated_memory[mObjectSize * i   +   sizeof(void*)];

    // Set the free list pointer to point at the previous node
    *reinterpret_cast<char**>(current) = last;
  }

  // Update the stats to reflect the allocated page
  mStats.mFreeObjects += mObjectsPerPage;
  ++mStats.mPagesInUse;

  // Return the allocated memory as a void pointer
  return allocated_memory;
}

// Releases every page owned by the allocator.
// Pages are chained through the pointer stored at the front of each page,
// so the chain is followed from mPageList until the NULL link is reached.
StaticAllocator::~StaticAllocator()
{
  char* page = reinterpret_cast<char*>(mPageList);

  while (page != NULL)
  {
    // Read the link before the page holding it is released
    char* older_page = *reinterpret_cast<char**>(page);

    delete[] page;

    page = older_page;
  }
}



// Take an object from the free list and give it to the client (simulates new)
void* StaticAllocator::Allocate()
{
  // Update the stats to reflect the allocation
  --mStats.mFreeObjects;
  ++mStats.mObjectsInUse;
  ++mStats.mAllocations;

  // Check to see if we have more objects then the current "max/most", if so set the new "most"
  if (mStats.mObjectsInUse > mStats.mMostObjects)
    mStats.mMostObjects = mStats.mObjectsInUse;

  // Check if the free list is empty, if so add a new page
  if (mFreeList == NULL)
  {
    // Build a new page (give it the last page)
    char* new_page = BuildPage(reinterpret_cast<char*>(mPageList));

    // Iterate through the page list until we reach the end
    mPageList = new_page;

    // Setup the free list pointer to point at the next free block
    mFreeList = &new_page[mPageSize - mObjectSize];
  }

  // Create a temporary pointer that points at the current node in the free list
  void* temp = mFreeList;

  // Remove the node from the free list
  mFreeList = *reinterpret_cast<void**>(mFreeList);

  // Return the temporary pointer
  return temp;
}


// Returns an object to the free list for the client (simulates delete)
// Throws an exception if the the object can't be freed. (Invalid object)
void StaticAllocator::Free(void* object)
{
  // Check for a null pointer first
  if (object == NULL)
    return;

  // Update the stats to reflect the free
  ++mStats.mDeallocations;

  // Set the free list ptr in the freed object to point at the free list
  *reinterpret_cast<void**>(object) = mFreeList;

  // Update the stats to reflect the free
  ++mStats.mFreeObjects;
  --mStats.mObjectsInUse;

  // Add the object to the free list
  mFreeList = object;
}

// Hands back a copy of the allocator's current statistics
SAStats StaticAllocator::GetStats(void) const
{
  return this->mStats;
}
Categories: Code Samples Tags:

Network Nodes

August 28th, 2010 admin No comments

This is a simple networking class I made to manage UDP ‘connections’ between peers and servers/clients. One part of this that I was particularly happy with was our network input simulator. The reason stemmed from us having a heap corruption that wasn’t easily reproducible or deterministic. The network input simulator allowed us to replay back entire games until the point of the crash (and luckily for us, the simulated input caused the game to crash!). We were then able to hunt down the exact reason partially by examining the packets themselves and by debugging while running the simulator. For those curious, the problem came from joining a game late while another player had already attached a graphics object to themselves. When the player removed the object, it would doubly add a graphics component pointer into the component list (by accident) and would cause access of deleted memory (since we only removed the first pointer we found upon deletion of the component).

I did not modify this in any way before posting, except to remove commented out code blocks. Anything you see here directly corresponds to my own personal coding style and standards.

/******************************************************************************
  Author: Trevor Sundberg
    Date: 09/19/2009

 Purpose:
          The network node class implementation.
          All content © 2009 DigiPen (USA) Corporation, all rights reserved.
******************************************************************************/


// Includes
#include <stdio.h>
#include <stdlib.h>
#include "network_node.h"

// Using directives
using namespace Wallaby;

// Specifies how many instances we have (by default we start with zero!)
// Shared by every NetworkNode: the constructor increments it and the
// destructor decrements it, so Winsock is started by the first node and
// cleaned up after the last one is destroyed.
unsigned NetworkNode::mInstances = 0;

// Defines and macros
// Loop header with no exit condition; the loop body must leave via
// break or return.
#define infinite_loop for(;;)

// Constants for file sections
// Single-character tags that prefix each chunk of the simulation file.
// FILE_NEW_PACKET and FILE_NEW_CONNECTION are written with fputc before the
// chunk's ip/port (and, for packets, length and payload).
// NOTE(review): the writers of FILE_END_UPDATE and FILE_NEW_DISCONNECTION
// are not visible in this part of the file.
#define FILE_END_UPDATE         'U'
#define FILE_NEW_PACKET         'P'
#define FILE_NEW_CONNECTION     'C'
#define FILE_NEW_DISCONNECTION  'D'

// Constructor
NetworkNode::NetworkNode(u32         program_id,
                         f32         connection_timeout,
                         const char* debug_name) :
// Initializer list
mDebugName(debug_name)
{
  // Increment the number of instances we have
  ++mInstances;

  // If this is the only instance running, instantiate winsock data
  if (mInstances == 1)
  {
    // Initialize winsock
    WSADATA winsock_data;
    WSAStartup(MAKEWORD(2,2), &winsock_data);
  }

  // Set the default program ID
  SetProgramID(program_id);

  // Makes the class unusable (if some really bad error occurs)
  mUnusable = false;

  // By default we're not bound
  mBound = false;

  // By default, there is a one to one ratio of updates to updates in simulation
  mFrameProcessedCount = 1;

  // No file to write to or read from by default
  mWriteToFile  = 0;
  mReadFromFile = 0;

  // By default, debug mode is off
  mDebugMode = DEBUG_MODE_OFF;

  // Set the default text-out callback
  mTextOutCallback = DefaultTextOut;

  // Set the timeout time
  mTimeoutSeconds = connection_timeout;

  // First create the socket that we'll be using over and over
  mSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);

  // If we got an invalid socket, get the error and set the socket to unusable
  if (mSocket == INVALID_SOCKET)
  {
    // Get the error associated
    mError = WSAGetLastError();

    // Show an error, then set the socket to be unusable
    Error("Networking: Unable to create socket (Error: %d)\n", mError);
    mUnusable = true;
    return;
  }

  // Used to set the socket to non blocking mode
  unsigned long non_blocking = 1;

  // Set the socket into non-blocking mode
  mError = ioctlsocket(mSocket, FIONBIO, &non_blocking);

  // If there was an error
  if (mError != 0)
  {
    // Show an error, then set the socket to be unusable
    Error("Networking: Unable to set socket to non-blocking\n");
    mUnusable = true;
    return;
  }
 
  // Used to set the socket to allow broadcasts
  unsigned long allow_broadcasts = 1;

  // Set the socket into broadcast mode
  mError = setsockopt(mSocket, SOL_SOCKET, SO_BROADCAST, (const char*) &allow_broadcasts, sizeof(allow_broadcasts));

  // If there was an error
  if (mError != 0)
  {
    // Show an error, then set the socket to be unusable
    Error("Networking: Unable to set socket to accept broadcasts\n");
    mUnusable = true;
    return;
  }
 
  // Set the receive and send buffer sizes
  int buffer_size = 4194304;
  mError = setsockopt(mSocket, SOL_SOCKET, SO_RCVBUF, (const char*) &buffer_size, sizeof(buffer_size));
  mError = setsockopt(mSocket, SOL_SOCKET, SO_SNDBUF, (const char*) &buffer_size, sizeof(buffer_size));
 

  // If there was an error
  if (mError != 0)
  {
    // Show an error, then set the socket to be unusable
    Error("Networking: Unable to set socket to accept broadcasts\n");
    mUnusable = true;
    return;
  }

}

// Destructor
// Closes any simulation files that were opened, releases the socket, and
// shuts Winsock down once the last instance is gone.
NetworkNode::~NetworkNode()
{
  // Release the recording file, if one was opened
  if (mWriteToFile)
  {
    fclose(mWriteToFile);
  }

  // Release the playback file, if one was opened
  if (mReadFromFile)
  {
    fclose(mReadFromFile);
  }

  // Give the socket back to the system
  closesocket(mSocket);

  // One fewer instance; the last one out cleans up winsock
  if (--mInstances == 0)
  {
    WSACleanup();
  }
}


// Stores the program ID used to identify connections.
// The value is kept in network byte order (converted with htonl).
void NetworkNode::SetProgramID(u32 id)
{
  this->mProgramID = htonl(id);
}

// Stores the connection timeout, given in seconds
void NetworkNode::SetTimeoutTime(f32 timeout_seconds)
{
  this->mTimeoutSeconds = timeout_seconds;
}

// Update the network node
void NetworkNode::Update()
{
  // If we have data to be sent, send it now
  SendQueuedData();

  // Receive data from the socket
  ReceiveData();

  // Update all the connections (drop connections, send heartbeats, etc)
  UpdateConnections();

  // Process any file input data
  ProcessFileData();
}


// Sets a flag to write incoming data to a file
void NetworkNode::WriteSimulationToFile(const char* file)
{
  // Attempt to open the file
  mWriteToFile = fopen(file, "wb");
}

// Simulates network input from a given file
void NetworkNode::SimulateFromFile(const char* file)
{
  // Read data from a file
  mReadFromFile = fopen(file, "rb");
}

// Sets the simulation count
void NetworkNode::SetSimulationsPerFrame(int count)
{
  // The number of updates processed for every frame
  mFrameProcessedCount = count;
}

// Tells the socket to connect to the outside
void NetworkNode::Connect(const char* host_or_ip, port_type port)
{
  // Get the ip address from a string
  ip_type ip = inet_addr(host_or_ip);

  // Check if the ip is valid or not
  if (ip == INADDR_NONE)
  {
    // Get the ip of the given host (if it is a host)
    hostent* host = gethostbyname(host_or_ip);

    // If the host was found
    if (host != 0)
      ip = *(ip_type*) host->h_addr_list[0];
    // Otherwise, exit early and give an error
    else
    {
      // Show an error, then exit out
      Error("Invalid host or IP given to Connect (%s)\n", host_or_ip);
      return;
    }
  }

  // Add a new connection to the list
  Connection* connection = CreateConnection(ip, port);

  // Check if that a connection didn't already exist
  if (connection != 0)
  {
    // Send a connection request
    SendConnectionSYN(ip, port, connection);

    // Set the connection state to be authenticated
    connection->SetState(Connection::STATE_SENT_SYN);
  }
  // Otherwise, a connection existed, show an error
  else
  {
    // Show an error
    Error("Attempt to connect to a host/port that was already connected (%s : %d)\n", host_or_ip, port);
  }
}

// Binds the socket to the given port on any local address.
// Returns true when bind reported success; on failure an error is shown and
// the node stays unbound.
bool NetworkNode::Listen(port_type port)
{
  // Describe what we want to bind to
  mBinderInfo.sin_addr.s_addr = INADDR_ANY;
  mBinderInfo.sin_port        = htons(port);
  mBinderInfo.sin_family      = AF_INET;

  // Ask the socket to take the port
  const int bind_result = bind(mSocket, (const sockaddr*) &mBinderInfo, sizeof(mBinderInfo));

  // Remember that we're bound, or explain why we aren't
  if (bind_result != SOCKET_ERROR)
  {
    mBound = true;
  }
  else
  {
    Error("Unable to bind on port %d\n", port);
  }

  return bind_result == 0;
}

    // Closes all connections (safe to call at all times)
void NetworkNode::Disconnect()
{
  mBound = false;
}


// Queues data to be sent to the given ip and port.
//   data_in - the bytes to send
//   size    - number of bytes, or DATA_IS_NULL_TERMINATED to measure data_in
//             with strlen
//   options - send flags stored on the data-gram
// Data of MAXIMUM_PACKET_SIZE bytes or more is rejected with an error.
void NetworkNode::SendTo(ip_type ip, port_type port, const char* data_in, size_t size, unsigned options)
{
  // Check if the data is null terminated
  if (size == DATA_IS_NULL_TERMINATED)
  {
    // Get the string length of the data
    size = strlen(data_in);
  }

  // Check if the data is below the max size
  if (size < MAXIMUM_PACKET_SIZE)
  {
    // Create a new data-gram without a header
    // The special IP and port will be checked when sending to all
    DataGram data = DataGram(data_in, size, ip, port, 0, options);

    // Add the data-gram to the "to be sent" list
    AddToBeSentDataGram(&data);
  }
  else
  {
    // Show an error that the packet size was too big
    // (the arguments are converted to int so they match the %d specifiers;
    //  passing a size_t through varargs for %d is a type mismatch)
    Error("Networking: Packet send size too big "
          "(attempt to send packet of size %d, max is %d)\n",
          static_cast<int>(size), static_cast<int>(MAXIMUM_PACKET_SIZE));
  }
}

// Queues data for everyone except the given ip and port.
// Forwards to SendTo with SEND_TO_ALL_EXCEPT added to the options.
void NetworkNode::SendToAllExcept(ip_type ip, port_type port, const char* data_in, size_t size, unsigned options)
{
  const unsigned send_flags = options | SEND_TO_ALL_EXCEPT;

  SendTo(ip, port, data_in, size, send_flags);
}


// Sends data to all the connections
void NetworkNode::SendToAll(const char* data_in, size_t size, unsigned options)
{
  // Use the auxiliary function to send data to everyone
  SendTo(0, 0, data_in, size, options | SEND_TO_ALL);
}

// Queues a broadcast on the given port.
// Forwards to SendTo using the all-ones address and SEND_TO_NON_CONNECTION.
void NetworkNode::SendBroadcast(port_type port, const char* data_in, size_t size)
{
  SendTo(0xFFFFFFFF,              // all-ones address
         port,
         data_in,
         size,
         SEND_TO_NON_CONNECTION);
}

// Queues a raw packet for an address that is not one of our connections.
// Forwards to SendTo with SEND_TO_NON_CONNECTION as the only option.
void NetworkNode::SendToRaw(ip_type ip, port_type port, const char* data_in, size_t size)
{
  SendTo(ip,
         port,
         data_in,
         size,
         SEND_TO_NON_CONNECTION);
}


// Returns if we have data
bool NetworkNode::HasData()
{
  return mReceivedDataGrams.size() > 0;
}

// Sender ip of the next data-gram to be read, or 0 when nothing is waiting
ip_type NetworkNode::GetNextDataIP()
{
  // Nothing queued, nothing to report
  if (!HasData())
    return 0;

  return mReceivedDataGrams.front().GetAddress();
}

// Sender port of the next data-gram to be read, or 0 when nothing is waiting
port_type NetworkNode::GetNextDataPort()
{
  // Nothing queued, nothing to report
  if (!HasData())
    return 0;

  return mReceivedDataGrams.front().GetPortNum();
}

// Gets the size of the next message to receive
// Returns zero if no messages are available
size_t NetworkNode::GetNextDataSize()
{
  // Check if we actually have data, and if so return the next size
  if (HasData())
    return mReceivedDataGrams.front().GetContentsSize();
  // Otherwise, return zero
  else
    return 0;
}

// Receives a message and writes it out to a buffer
void NetworkNode::ReadData(char* out_data)
{
  // Check if we actually have data
  if (HasData())
  {
    // Copy all the data out
    mReceivedDataGrams.front().CopyContentsInto(out_data);

    // Pop the data-gram off the back  
    mReceivedDataGrams.pop_front();
  }
  // For some reason, the user is trying to read data without having it
  else
  {
    // Give an error
    Error("Attempt to read data from the network node when it has none\n");
  }
}

// Returns the contents of the next data-gram as a byte buffer and removes
// that data-gram from the front of the received queue.
// Reports an error and returns an empty buffer when nothing is waiting.
vector<char> NetworkNode::ReadData()
{
  // Reading with an empty queue is a caller mistake
  if (!HasData())
  {
    Error("Attempt to read data from the network node when it has none\n");
    return vector<char>();
  }

  // Look at the oldest data-gram
  DataGram& next = mReceivedDataGrams.front();

  // Copy its contents out before it is discarded
  const char* first = next.GetContents();
  unsigned    count = static_cast<unsigned>(next.GetContentsSize());
  vector<char> data(first, first + count);

  // Drop it from the front of the queue
  mReceivedDataGrams.pop_front();

  return data;
}

// Writes the packet data to a file as a single packet simulation
void NetworkNode::WritePacketToFile(const char* file)
{
  // Attempt to open the file
  FILE* temp_file = fopen(file, "wb");

  // If the file opening was as success
  if (temp_file)
  {
    // Write that we got a packet
    fputc(FILE_NEW_PACKET, temp_file);

    // Get a reference to the current datagram
    DataGram& data = mReceivedDataGrams.front();

    // Get the length of the data
    size_t length = data.GetAllDataSize();

    // Get the ip and port
    ip_type ip      = data.GetAddress();
    port_type port  = data.GetPortNum();

    // Write the ip and port to a file
    fwrite(&ip,   sizeof(ip),   1, temp_file);
    fwrite(&port, sizeof(port), 1, temp_file);

    // Write the length of the data to a file
    fwrite(&length, sizeof(size_t), 1, temp_file);

    // Copy the contents into a buffer
    data.CopyAllDataInto(mBuffer);

    // Write the buffer to a file
    fwrite(mBuffer, length, 1, temp_file);

    // Close the file handle
    fclose(temp_file);
  }
}


// Checks if we have any new connections
bool NetworkNode::HasNewConnection()
{
  return mNewConnections.size() > 0;
}

// Gets the most recent connection ip (the entry stays queued until
// PopNextConnection is called).
// Returns 0 when there is no new connection, matching GetNextDataIP;
// calling back() on an empty container would be invalid.
ip_type NetworkNode::GetNextConnectionIP()
{
  // Nothing queued, nothing to report
  if (HasNewConnection() == false)
    return 0;

  return mNewConnections.back().ip;
}

// Gets the most recent connection port. This does not remove the entry;
// PopNextConnection does that.
// Returns 0 when there is no new connection, matching GetNextDataPort;
// calling back() on an empty container would be invalid.
port_type NetworkNode::GetNextConnectionPort()
{
  // Nothing queued, nothing to report
  if (HasNewConnection() == false)
    return 0;

  return mNewConnections.back().port;
}

// Pops the new connection
void NetworkNode::PopNextConnection()
{
  mNewConnections.pop_back();
}

// Checks if we have any new disconnections
bool NetworkNode::HasNewDisconnection()
{
  return mNewDisconnections.size() > 0;
}

// Gets the most recent disconnection ip (the entry stays queued until
// PopNextDisconnection is called).
// Returns 0 when there is no new disconnection, matching GetNextDataIP;
// calling back() on an empty container would be invalid.
ip_type NetworkNode::GetNextDisconnectionIP()
{
  // Nothing queued, nothing to report
  if (HasNewDisconnection() == false)
    return 0;

  return mNewDisconnections.back().ip;
}

// Gets the most recent disconnection port (the entry stays queued until
// PopNextDisconnection is called).
// Returns 0 when there is no new disconnection, matching GetNextDataPort;
// calling back() on an empty container would be invalid.
port_type NetworkNode::GetNextDisconnectionPort()
{
  // Nothing queued, nothing to report
  if (HasNewDisconnection() == false)
    return 0;

  return mNewDisconnections.back().port;
}

// Pops the new disconnection
void NetworkNode::PopNextDisconnection()
{
  mNewDisconnections.pop_back();
}


// Checks if the client is connected
bool NetworkNode::IsConnected()
{
  // Loop through all the connections
  for (ConnectIter it = mConnections.begin(); it != mConnections.end(); ++it)
  {
    // If we found a single authenticated connection, return true
    if (it->second.GetState() == Connection::STATE_AUTHENTICATED)
      return true;
  }

  // Since we got here, we found no authenticated connections
  return false;
}

// Checks if the client is listening
bool NetworkNode::IsListening()
{
  return mBound;
}

// Checks if a client connection is closed
bool NetworkNode::IsClosed()
{
  // If we're not connected and not listening, consider it closed
  return IsConnected() == false && IsListening() == false;
}

// Stores the debug mode that Error, Warning and Message check before
// producing any output
void NetworkNode::SetDebugMode(DebugMode mode)
{
  this->mDebugMode = mode;
}

// Sets the specific debug mode
void NetworkNode::SetDebugMode(unsigned mode)
{
  mDebugMode = (DebugMode) mode;
}

// Raise an error
void NetworkNode::Error(const char* format, ...)
{
  // If debug mode is not set to show warnings, don't show them
  if (mDebugMode < DEBUG_MODE_ERRORS)
    return;

  // Create a character buffer
  char buffer[1024];

  // Print out the debug name
  sprintf(buffer, "%s: ", mDebugName.c_str());

  // Create a va list then print the data
  va_list  args;
  va_start(args, format);
  vsprintf(buffer, format, args);
  va_end  (args);

  // Send the buffer to the text out callback
  mTextOutCallback(buffer, NET_MESSAGE_ERROR);
}

// Show a warning
void NetworkNode::Warning(const char* format, ...)
{
  // If debug mode is not set to show warnings, don't show them
  if (mDebugMode < DEBUG_MODE_WARNINGS_ERRORS)
    return;

  // Create a character buffer
  char buffer[1024];

  // Print out the debug name
  sprintf(buffer, "%s: ", mDebugName.c_str());

  // Create a va list then print the data
  va_list  args;
  va_start(args, format);
  vsprintf(buffer, format, args);
  va_end  (args);

  // Send the buffer to the text out callback
  mTextOutCallback(buffer, NET_MESSAGE_WARNING);
}

// Show a message
void NetworkNode::Message(const char* format, ...)
{
  // If debug mode is not set to show messages, don't show them
  if (mDebugMode < DEBUG_MODE_ALL)
    return;

  // Create a character buffer
  char buffer[1024];

  // Print out the debug name
  sprintf(buffer, "%s: ", mDebugName.c_str());

  // Create a va list then print the data
  va_list  args;
  va_start(args, format);
  vsprintf(buffer, format, args);
  va_end  (args);

  // Send the buffer to the text out callback
  mTextOutCallback(buffer, NET_MESSAGE_GENERAL);
}

// Replaces the callback that Error, Warning and Message hand their text to
void NetworkNode::SetTextOutCallback(TextOutCB callback)
{
  this->mTextOutCallback = callback;
}

// Receives data from the socket
void NetworkNode::ReceiveData()
{
  // Get the size of the receive socket info
  int size_of_sockaddr = sizeof(mReceiveInfo);

  // Receive the buffer
  size_t length = 0;

  // A counter to count how many packets we receive
  unsigned packet_counter = 0;

  // Loop until one of the quitting conditions happens below
  infinite_loop
  {
    // Receive the latest packet and get the length of the data
    length = recvfrom(mSocket, (char*) mBuffer, MAXIMUM_PACKET_SIZE, 0, (sockaddr*) &mReceiveInfo, &size_of_sockaddr);

    // Get the ip from the current packet
    ip_type ip = mReceiveInfo.sin_addr.s_addr;

    // Get the port from the current packet
    port_type port = ntohs(mReceiveInfo.sin_port);

    // Check if the socket had an error
    if (length == SOCKET_ERROR)
    {
      // Now grab the last error value
      mError = WSAGetLastError();

      // Based on the error, give meaningful error messages
      switch (mError)
      {
      // Do nothing if the connection wasn't bound or we got the would-block error
      case WSAEWOULDBLOCK:
      case WSAEINVAL:
        {
          // Break out (we're done!)
          return;
        }

      // In the case of this error, most likely the connection was forcibly closed
      case WSAECONNRESET:
        {
          // Show an error and the IP address
          Error("Connection to (%s) forcibly closed\n", IPToString(ip));

          // Remove the connection
          RemoveConnection(mConnections.find(IPPort(ip, port)));
          return;
        }

      // Some error case that we didn't know about
      default:
        {
          // Show an error then break out
          Error("Unknown error in ReceiveData (Error: %d)\n", mError);
          return;
        }

      } // End error switch
    }


    // Check if the length of the packet is less than the minimum size
    if (length < MINIMUM_PACKET_SIZE)
    {
      // Show an error then skip this data-gram
      Error("Received data-gram smaller than the minimum packet size\n");
      continue;
    }


    // Show a message that we received a packet
    Message("Received %d bytes of data from %s on port %d...\n",
      length, IPToString(ip), port);

    // Create a new data-gram
    DataGram data = DataGram(mBuffer, length, ip, port, sizeof(PacketHeader), 0);

    // Process the received data
    ProcessReceivedData(&data);


    // Count up the number of packets we've got
    ++packet_counter;

    // If the number of packets we've gotten is equal to the max...
    if (packet_counter == MAXIMUM_RECEIVED_PER_CYCLE)
    {
      // Show an error then return out
      Error("Received too many packets in a single cycle\n");
      break;
    }

  } // End loop
}

// Processes a single received data-gram.
//   data - the data-gram to process; its header is read as a PacketHeader
// Steps, in order:
//   1. If a connection already exists for the sender, its remote sequence
//      number and last-packet time are updated.
//   2. Data-grams whose program ID differs from ours are reported and dropped.
//   3. HandleMissedPackets is called for the data-gram.
//   4. Handshake packets (SYN, SYNACK, ACK) and HEARTBEATs are handled here
//      and never reach the user.
//   5. Anything else is copied onto the received queue for the user (and
//      recorded to the simulation file if one is open), unless the queue
//      already holds MAXIMUM_BUFFERED_RECEIVED data-grams.
void NetworkNode::ProcessReceivedData(DataGram* data)
{
  // For convenience, extract the ip and port
  ip_type   ip    = data->GetAddress();
  port_type port  = data->GetPortNum();

  // Get the header from the data (this is safe because we should have checked size by here)
  PacketHeader* header = (PacketHeader*) data->GetHeader();

  // Immediately try and grab a connection with the given ip and port
  Connection* connection = GetConnection(ip, port);


  // If the connection is valid...
  // NOTE(review): this runs before the program ID check below, so a data-gram
  // with the wrong program ID still updates the connection - confirm intended.
  if (connection != 0)
  {
    // Set the last received sequence number
    connection->AddRemoteSequenceNum(header->sequence_num);

    // Set the last received packet time (now basically)
    connection->SetLastPacketTime(clock() / (f32) CLOCKS_PER_SEC);
  }


  // Check if the program ID is valid
  if (header->program_id != mProgramID)
  {
    // Show an error then skip this data-gram
    Error("Received data-gram with invalid program ID\n");
    return;
  }


  // Handle acking and sequence number updating
  HandleMissedPackets(data, connection);


  // Check if the packet is of type SYN
  if (header->packet_type == PacketHeader::TYPE_CONNECT_SYN)
  {
    // Check if that there isn't already a connection
    if (connection == 0)
    {
      // Show a message that we received a connection request
      Message("Received connection SYN\n");

      // Add the connection
      // (this local deliberately or accidentally shadows the outer
      //  'connection', which stays null for the rest of the function)
      // NOTE(review): the result is used without a null check - confirm
      // CreateConnection cannot return 0 here.
      Connection* connection = CreateConnection(ip, port);

      // Reply to the message
      SendConnectionSYNACK(ip, port, connection);

      // Record that we received a SYN and replied with a SYNACK
      connection->SetState(Connection::STATE_RECEIVED_SYN_SENT_SYNACK);
    }
    else
    {
      // Show a message that we received a duplicate connection request
      Message("Received duplicate connection SYN\n");
    }

    // Return, since we've handled it
    return;
  }

  // Check if the packet is of type SYNACK
  if (header->packet_type == PacketHeader::TYPE_CONNECT_SYNACK)
  {
    // Check if the connection actually exists
    if (connection != 0)
    {
      // Check if the state of the connection is in the handshake sent
      // (Make sure that it should be receiving this packet)
      if (connection->GetState() == Connection::STATE_SENT_SYN)
      {
        // Show a message that we received a connection ack
        Message("Received connection SYNACK\n");

        // Reply to the message
        SendConnectionACK(ip, port, connection);

        // Set the connection state to be authenticated
        connection->SetState(Connection::STATE_AUTHENTICATED);

        // Add the new connection to the queue
        mNewConnections.push_back(IPPort(ip, port));

        // Record the new connection to the simulation file if there is one open
        if (mWriteToFile)
        {
          // Write that we got a new connection
          fputc(FILE_NEW_CONNECTION, mWriteToFile);

          // Write the ip and port to a file
          fwrite(&ip,   sizeof(ip),   1, mWriteToFile);
          fwrite(&port, sizeof(port), 1, mWriteToFile);
        }
      }
      // Otherwise, the packet somehow got misordered...
      else
      {
        // Show a message that we received it out of order
        Message("Received a SYNACK out of order\n");
      }
    }
    // Otherwise, the connection was not found
    else
    {
      // Show a message that we received an unexpected connection ack
      Message("Received a connection SYNACK when no connection was initiated\n");
    }

    // Return, since we've handled it
    return;
  }

  // Check if the packet is of type ACK
  if (header->packet_type == PacketHeader::TYPE_CONNECT_ACK)
  {
    // Check if the connection actually exists
    if (connection != 0)
    {
      // Check that we are the side that received the SYN and sent the SYNACK
      // (Make sure that it should be receiving this packet)
      if (connection->GetState() == Connection::STATE_RECEIVED_SYN_SENT_SYNACK)
      {
        // Show a message that we received a connection ack
        Message("Received connection ACK\n");

        // Set the connection state to be authenticated
        connection->SetState(Connection::STATE_AUTHENTICATED);

        // Add the new connection to the queue
        mNewConnections.push_back(IPPort(ip, port));

        // Record the new connection to the simulation file if there is one open
        if (mWriteToFile)
        {
          // Write that we got a new connection
          fputc(FILE_NEW_CONNECTION, mWriteToFile);

          // Write the ip and port to a file
          fwrite(&ip,   sizeof(ip),   1, mWriteToFile);
          fwrite(&port, sizeof(port), 1, mWriteToFile);
        }
      }
      // Otherwise, the packet somehow got misordered...
      else
      {
        // Show a message that we received it out of order
        Message("Received a ACK out of order\n");
      }
    }
    // Otherwise, the connection was not found
    else
    {
      // Show a message that we received an unexpected connection ack
      Message("Received a connection ACK when no connection was initiated\n");
    }

    // Return, since we've handled it
    return;
  }

  // Check if the packet is of type HEARTBEAT
  if (header->packet_type == PacketHeader::TYPE_HEARTBEAT)
  {
    // Check if the connection actually exists
    if (connection != 0)
    {
      // Heartbeats are only expected on an authenticated connection
      // (Make sure that it should be receiving this packet)
      if (connection->GetState() == Connection::STATE_AUTHENTICATED)
      {
        // Show a message that we received a heartbeat
        Message("Received connection HEARTBEAT\n");

        // Re-apply the authenticated state (the state is already
        // STATE_AUTHENTICATED at this point)
        connection->SetState(Connection::STATE_AUTHENTICATED);
      }
      // Otherwise, the packet somehow got misordered...
      else
      {
        // Show a message that we received it out of order
        Message("Received a HEARTBEAT out of order\n");
      }
    }
    // Otherwise, the connection was not found
    else
    {
      // Show a message that we received an unexpected HEARTBEAT
      Message("Received a HEARTBEAT when no connection was initiated\n");
    }

    // Return, since we've handled it
    return;
  }


  // Only buffer the data-gram while the received queue is below its limit
  if (mReceivedDataGrams.size() < MAXIMUM_BUFFERED_RECEIVED)
  {
    // Write this current packet to a file if there is one open
    if (mWriteToFile)
    {
      // Write that we got a packet
      fputc(FILE_NEW_PACKET, mWriteToFile);

      // Get the length of the data
      size_t length = data->GetAllDataSize();

      // Write the ip and port to a file
      fwrite(&ip,   sizeof(ip),   1, mWriteToFile);
      fwrite(&port, sizeof(port), 1, mWriteToFile);

      // Write the length of the data to a file
      fwrite(&length, sizeof(size_t), 1, mWriteToFile);

      // Write the buffer to a file
      // NOTE(review): this writes mBuffer rather than the data-gram's own
      // bytes, so it relies on mBuffer still holding this packet - confirm.
      fwrite(mBuffer, length, 1, mWriteToFile);
    }

    // Push a copy of the data-gram onto the queue (this one is for the user!)
    mReceivedDataGrams.push_back(*data);
  }
  // Show an error since the user is not checking the packets
  else
  {
    // Show an error that we've buffered too many received packets
    Error("Maximum buffered receive-packets reached (please read the data out!)\n");
  }
}

// Process any file input data
void NetworkNode::ProcessFileData()
{
  // Check if we should be receiving packets from a file
  if (mReadFromFile != 0 && mFrameProcessedCount != 0)
  {
    // The update counter (how many updates chunks per actual update)
    int update_counter = 0;

    // The file header of each chunk
    int file_header;

    // Loop until we get a frame update
    infinite_loop
    {
      // Read the next file chunk header
      file_header = fgetc(mReadFromFile);

      switch (file_header)
      {
      // We got someone new connecting, tell everyone about it
      case FILE_NEW_CONNECTION:
        {
          // The ip and port of the connection
          ip_type   ip;
          port_type port;

          // Read the ip and port from the file
          fread(&ip,   sizeof(ip),   1, mReadFromFile);
          fread(&port, sizeof(port), 1, mReadFromFile);

          // Inform everone of a new connection
          mNewConnections.push_back(IPPort(ip, port));
        }
        break;

      // We got a new packet, give it to the receive function
      case FILE_NEW_DISCONNECTION:
        {
          // The ip and port of the connection
          ip_type   ip;
          port_type port;

          // Read the ip and port from the file
          fread(&ip,   sizeof(ip),   1, mReadFromFile);
          fread(&port, sizeof(port), 1, mReadFromFile);

          // Inform everone of a new disconnection
          mNewDisconnections.push_back(IPPort(ip, port));
        }
        break;

      // We got a new packet, give it to the receive function
      case FILE_NEW_PACKET:
        {
          // The ip and port of the connection
          ip_type   ip;
          port_type port;

          // Read the ip and port from the file
          fread(&ip,   sizeof(ip),   1, mReadFromFile);
          fread(&port, sizeof(port), 1, mReadFromFile);

          // Holds the length of the data
          size_t length;

          // Read the length of the data from the file
          fread(&length, sizeof(size_t), 1, mReadFromFile);

          // Read the pack/buffer from the file
          fread(mBuffer, length, 1, mReadFromFile);

          // Create a new data-gram
          DataGram data = DataGram(mBuffer, length, ip, port, sizeof(PacketHeader), 0);

          // Process the received data
          ProcessReceivedData(&data);
        }
        break;

      // We reached the end of the file
      case EOF:
        {
          // Close the file and set the handle to null
          fclose(mReadFromFile);
          mReadFromFile = 0;
        }
        break;

      } // End file header switch


      // If we reached a file end update, we need to break out
      if (file_header == FILE_END_UPDATE)
      {
        // Increment the update counter
        ++update_counter;

        // All of the frames updated per process
        if (update_counter == mFrameProcessedCount)
          break;
      }

      // If we reach the end of the file, stop...
      if (file_header == EOF)
        break;

    } // End looping through data
  }

  // Write a new frame to the file if there is a file open
  if (mWriteToFile)
  {
    // Write an end update
    fputc(FILE_END_UPDATE, mWriteToFile);

    // Flush the stream
    fflush(mWriteToFile);
  }
}

// If we have data to be sent, send it now (but respect flow control)
void NetworkNode::SendQueuedData()
{
  // If we have any data-grams to be sent out
  while (mToSendDataGrams.size() > 0)
  {
    // Flow control here

    // Get a pointer to the data gram we're using
    DataGram* data = &mToSendDataGrams.front();

    // Get the options for this data-gram
    unsigned options = data->GetOptions();

    // Get the ip address
    ip_type ip = data->GetAddress();

    // Get the port
    port_type port = data->GetPortNum();


    // Check if the options is set to send to all connections (not broadcast!)
    if (options & SEND_TO_ALL)
    {
      // Loop through all the existing connections
      for (ConnectIter it = mConnections.begin(); it != mConnections.end(); ++it)
      {
        // Create a pointer to the current connection for ease of use
        Connection* connection = &it->second;

        // Create a pointer to the IP/Port interface (for convenience)
        const IPPort* ip_port = &it->first;

        // Modify the data-gram to send to the current connection (ip and port)
        data->SetAddress(ip_port->ip);
        data->SetPortNum(ip_port->port);

        // Print out the data as readable text (for debugging)
        //string readable = ToReadableText(data->GetContents(), data->GetContentsSize());
        //ShowMessage(Engine::MESSAGE_GENERAL, "%s SENT DATA (%d, %d): [%s]",
        //  mDebugName.c_str(), ip_port->ip, ip_port->port, readable.c_str());

        // Perform a raw send to the current connection
        RawSendTo(data, connection);
       
      } // End connection loop
    }
    // Otherwise, look for the send to all except option
    else if (options & SEND_TO_ALL_EXCEPT)
    {
      // Loop through all the existing connections
      for (ConnectIter it = mConnections.begin(); it != mConnections.end(); ++it)
      {
        // Create a pointer to the current connection for ease of use
        Connection* connection = &it->second;

        // Create a pointer to the IP/Port interface (for convenience)
        const IPPort* ip_port = &it->first;

        // If the connection has the same ip and port as specified, skip it
        if (ip_port->ip == ip && ip_port->port == port)
          continue;

        // Modify the data-gram to send to the current connection (ip and port)
        data->SetAddress(ip_port->ip);
        data->SetPortNum(ip_port->port);

        // Perform a raw send to the current connection
        RawSendTo(data, connection);
       
      } // End connection loop
    }
    // Check if it's supposed to be sent to a client that isn't connected
    else if (options & SEND_TO_NON_CONNECTION)
    {
      // Perform a raw send to the connection
      RawSendTo(data, 0);
    }
    // Otherwise, just send it normally if an actual connection exists
    else if (Connection* connection = GetConnection(ip, port))
    {
      // Perform a raw send to the connection
      RawSendTo(data, connection);
    }
    // The connection was somehow dropped or never existed, only show this if we're not simulating
    else if (mReadFromFile == 0)
    {
      // Show an error and the IP address
      Error("Attempt to send to an unknown or dropped connection (%s) on port %d\n",
            IPToString(ip), port);
    }

    // Pop the data-gram off since we sent it
    mToSendDataGrams.pop_front();

  } // End data-gram availability check
}

// Send a raw packet
void NetworkNode::RawSendTo(DataGram* data, Connection* connection)
{
  // Check if the state of the connection is valid
  if (connection != 0 && connection->GetState() != Connection::STATE_AUTHENTICATED && (data->GetOptions() & SEND_AS_CONNECTION_PACKET) == 0)
    return;

  // Check if the data-gram has a header
  if (data->HasHeader() == false)
  {
    // If not, append a generic header to the data
    AppendGenericHeader(data, connection);
  }

  // Create a socket address (used for sending)
  sockaddr_in sending_addr;

  // Setup the socket address (with ip and port)
  sending_addr.sin_family       = AF_INET;
  sending_addr.sin_port         = htons(data->GetPortNum());
  sending_addr.sin_addr.s_addr  = data->GetAddress();

  // Read the data from the current data-gram (the size should be checked already)
  data->CopyAllDataInto(mBuffer);

  // Show a message that we received a connection
  Message("Sending %d bytes of data to %s on port %d...\n",
    data->GetAllDataSize(), IPToString(data->GetAddress()), data->GetPortNum());

  // Send the data gram to the given IP
  mError = sendto(mSocket,
                  mBuffer,
                  (int) data->GetAllDataSize(),
                  NULL,
                  (const sockaddr*) &sending_addr,
                  sizeof(sending_addr));

  // If the connection is valid, we should do somethign with the sequence number
  if (connection != 0)
  {
    // Since we're sending a packet, increment the sequence number
    connection->IncrementSequenceNum();
  }

  // Check if the socket had an error
  if (mError == SOCKET_ERROR)
  {
    // Now grab the last error value
    mError = WSAGetLastError();

    // Based on the error, give meaningful error messages
    switch (mError)
    {
    // Do nothing if we got the would-block error
    case WSAEWOULDBLOCK:
      {
        // Break out (we're done!)
        break;
      }

    // Some error case that we didn't know about
    default:
      {
        // Show an error then break out
        Error("Unknown error in RawSendTo (Error: %d)\n", mError);
        break;
      }

    } // End error switch
  } // End error check
}

// Queues a copy of the given data-gram for sending.
// If MAXIMUM_BUFFERED_SENT data-grams are already waiting, the data-gram is
// not queued and an error is reported instead.
void NetworkNode::AddToBeSentDataGram(DataGram* data)
{
  // Refuse the data-gram when the queue is full (the user is sending too fast)
  if (mToSendDataGrams.size() >= MAXIMUM_BUFFERED_SENT)
  {
    Error("Maximum buffered sent-packets reached "
          "(please reduce the rate at which you are sending packets!)\n");
    return;
  }

  // There is room, so append a copy to the "to be sent" list
  mToSendDataGrams.push_back(*data);
}


// Compares the ack carried by a received data-gram against the connection's
// current sequence number to detect packets the remote side has not
// acknowledged. The resend itself is not implemented yet, so this currently
// has no effect.
void NetworkNode::HandleMissedPackets(DataGram* data, Connection* connection)
{
  // Pull the packet header out of the data-gram
  PacketHeader* header = (PacketHeader*) data->GetHeader();

  // Nothing can be compared without both a header and a connection
  if (header == 0 || connection == 0)
    return;

  // Our sequence number being ahead of (ack + 1) means some of the packets
  // we sent were never acknowledged
  const bool has_unacknowledged = header->ack + 1 < connection->GetCurrentSequenceNum();

  if (has_unacknowledged)
  {
    // Resend the packets that weren't acknowledged

  }
}


// Queues a connection request (SYN) for the given ip and port.
// The packet consists of just a header, initialized from the given connection.
void NetworkNode::SendConnectionSYN(ip_type ip, port_type port, Connection* connection)
{
  // Build the header and mark it as a connection request
  PacketHeader syn_header;
  InitPacketHeader(&syn_header, connection);
  syn_header.packet_type = PacketHeader::TYPE_CONNECT_SYN;

  // Wrap the header in a connection-packet data-gram
  DataGram packet(ip, port, SEND_AS_CONNECTION_PACKET);
  packet.SetHeader(&syn_header, sizeof(syn_header));

  // Hand it to the "to be sent" list and report it
  AddToBeSentDataGram(&packet);
  Message("Sent connection SYN\n");
}

// Queues the reply to a connection request (SYNACK) for the given ip and port.
// The packet consists of just a header, initialized from the given connection.
void NetworkNode::SendConnectionSYNACK(ip_type ip, port_type port, Connection* connection)
{
  // Build the header and mark it as a connection request reply
  PacketHeader synack_header;
  InitPacketHeader(&synack_header, connection);
  synack_header.packet_type = PacketHeader::TYPE_CONNECT_SYNACK;

  // Wrap the header in a connection-packet data-gram
  DataGram packet(ip, port, SEND_AS_CONNECTION_PACKET);
  packet.SetHeader(&synack_header, sizeof(synack_header));

  // Hand it to the "to be sent" list and report it
  AddToBeSentDataGram(&packet);
  Message("Sent connection SYNACK\n");
}

// Queues the acknowledgment of a connection request (ACK) for the given ip
// and port. The packet consists of just a header, initialized from the given
// connection.
void NetworkNode::SendConnectionACK(ip_type ip, port_type port, Connection* connection)
{
  // Build the header and mark it as a connection acknowledgment
  PacketHeader ack_header;
  InitPacketHeader(&ack_header, connection);
  ack_header.packet_type = PacketHeader::TYPE_CONNECT_ACK;

  // Wrap the header in a connection-packet data-gram
  DataGram packet(ip, port, SEND_AS_CONNECTION_PACKET);
  packet.SetHeader(&ack_header, sizeof(ack_header));

  // Hand it to the "to be sent" list and report it
  AddToBeSentDataGram(&packet);
  Message("Sent connection ACK\n");
}

// Queues a non-acknowledgment (NAK) for the given ip and port.
// The packet consists of just a header, initialized from the given connection,
// and is created with no send options.
void NetworkNode::SendNAK(ip_type ip, port_type port, Connection* connection)
{
  // Build the header and mark it as a non-acknowledgment
  PacketHeader nak_header;
  InitPacketHeader(&nak_header, connection);
  nak_header.packet_type = PacketHeader::TYPE_NAK;

  // Wrap the header in a data-gram without any options
  DataGram packet(ip, port, 0);
  packet.SetHeader(&nak_header, sizeof(nak_header));

  // Hand it to the "to be sent" list and report it
  AddToBeSentDataGram(&packet);
  Message("Sent connection NAK\n");
}

// Queues a heartbeat for the given ip and port.
// The packet consists of just a header, initialized from the given connection,
// and is created with no send options.
void NetworkNode::SendHEARTBEAT(ip_type ip, port_type port, Connection* connection)
{
  // Build the header and mark it as a heartbeat
  PacketHeader heartbeat_header;
  InitPacketHeader(&heartbeat_header, connection);
  heartbeat_header.packet_type = PacketHeader::TYPE_HEARTBEAT;

  // Wrap the header in a data-gram without any options
  DataGram packet(ip, port, 0);
  packet.SetHeader(&heartbeat_header, sizeof(heartbeat_header));

  // Hand it to the "to be sent" list and report it
  AddToBeSentDataGram(&packet);
  Message("Sent HEARTBEAT\n");
}


// Registers a new, default-constructed connection under the given ip and port.
// Returns a pointer to the stored connection, or 0 if a connection with that
// ip and port already exists (the existing connection is left untouched).
Connection* NetworkNode::CreateConnection(ip_type ip, port_type port)
{
  // Try to insert; the bool of the result tells whether the key was new
  pair<ConnectIter, bool> inserted =
    mConnections.insert(ConnectPair(IPPort(ip, port), Connection()));

  // Hand out the stored connection only when it was actually created
  return inserted.second ? &inserted.first->second : 0;
}

// Looks up the connection registered under the given ip and port.
// Returns a pointer to it, or 0 if there is no such connection.
Connection* NetworkNode::GetConnection(ip_type ip, port_type port)
{
  ConnectIter found = mConnections.find(IPPort(ip, port));

  // No entry under that ip and port
  if (found == mConnections.end())
    return 0;

  return &found->second;
}

// Update all the connections
void NetworkNode::UpdateConnections()
{
  // Loop through all the connections
  for (ConnectIter it = mConnections.begin(); it != mConnections.end();)
  {
    // Create a pointer to the current connection for ease of use
    Connection* connection = &it->second;

    // Create a pointer to the IP/Port interface (for convenience)
    const IPPort* ip_port = &it->first;

    // Check if we should send a heartbeat
    if (connection->GetState() == Connection::STATE_AUTHENTICATED &&
        connection->ShouldSendHeartbeat(mTimeoutSeconds, HEARTBEAT_TIMEOUT_RATIO))
    {
      // Send a heartbeat packet
      SendHEARTBEAT(ip_port->ip, ip_port->port, connection);

      // Tell the connection that we sent a heartbeat
      connection->SentHeartbeat();
    }

    // Check if the connection is dead with the current threshold
    if (connection->IsDead(mTimeoutSeconds))
    {
      // Show a message that a particular connection was dropped
      Message("Connection with IP (%s) and port %d dropped due to timeout\n",
        IPToString(ip_port->ip), ip_port->port);

      // Remove the connection from the list
      it = RemoveConnection(it);
    }
    // The connection wasn't dead, iterate to the next connection
    else
    {
      // Increment the iterator
      ++it;
    }

  } // End connection loop
}

// Removes the connection the iterator refers to and records it as a
// disconnection (in mNewDisconnections, and in the output file if one is open).
// Returns whatever mConnections.erase(iter) returns, or mConnections.end()
// when given the end iterator (in which case nothing is removed or recorded).
NetworkNode::ConnectIter NetworkNode::RemoveConnection(ConnectIter iter)
{
  // Nothing to remove for the end iterator
  if (iter == mConnections.end())
    return mConnections.end();

  // The ip and port the connection is registered under
  const IPPort& address = iter->first;

  // Count this as a disconnect
  mNewDisconnections.push_back(address);

  // Record the disconnection in the output file if there is one open
  if (mWriteToFile)
  {
    // Chunk header first, then the ip and port
    fputc(FILE_NEW_DISCONNECTION, mWriteToFile);
    fwrite(&address.ip,   sizeof(address.ip),   1, mWriteToFile);
    fwrite(&address.port, sizeof(address.port), 1, mWriteToFile);
  }

  // Erase the entry and hand back the result of the erase
  return mConnections.erase(iter);
}


// Initializes a packet header
void NetworkNode::InitPacketHeader(PacketHeader* header_to_initialize, Connection* connection)
{
  // Check if we got nothing for the connection
  if (connection != 0)
  {
    // Set the program ID and sequence number
    header_to_initialize->program_id    = mProgramID;
    header_to_initialize->sequence_num  = connection->GetCurrentSequenceNum();
    header_to_initialize->ack           = connection->GetLatestRemoteSequenceNum();
    header_to_initialize->packet_type   = PacketHeader::TYPE_NONE;
  }
  else
  {
    // Set the program ID and sequence number
    header_to_initialize->program_id    = mProgramID;
    header_to_initialize->sequence_num  = 0;
    header_to_initialize->ack           = 0;
    header_to_initialize->packet_type   = PacketHeader::TYPE_NONE;
  }
}

// Gives a data-gram a default packet header (see InitPacketHeader), built
// from the given connection (which may be 0).
void NetworkNode::AppendGenericHeader(DataGram* append_to, Connection* connection)
{
  // Build a header holding nothing but the defaults
  PacketHeader generic_header;
  InitPacketHeader(&generic_header, connection);

  // Set it as the header of the data-gram
  append_to->SetHeader(&generic_header, sizeof(generic_header));
}

// Turns packet data into readable text (for debugging).
//
//   data - the bytes to convert (size bytes are read from it)
//   size - the number of bytes to convert
//
// Returns a string of exactly size characters in which every byte that
// isprint reports as printable is kept, and every other byte is shown as '.'.
string NetworkNode::ToReadableText(const char* data, size_t size)
{
  // The output string we'll be returning (its final length is known up front)
  string output;
  output.reserve(size);

  // Loop through all the data bytes
  for (size_t i = 0; i < size; ++i)
  {
    // isprint requires a value representable as unsigned char (or EOF);
    // a plain char may be negative for bytes >= 0x80, so convert it first
    const unsigned char byte = static_cast<unsigned char>(data[i]);

    // Keep readable characters, stand in a '.' for everything else
    output += isprint(byte) ? data[i] : '.';
  }

  // Return the output string
  return output;
}


// Converts an IP into its textual form using inet_ntoa.
// The returned pointer is whatever inet_ntoa returns; the caller does not
// own it.
const char* NetworkNode::IPToString(ip_type ip)
{
  // Wrap the raw ip in the address struct that inet_ntoa expects
  in_addr wrapped;
  wrapped.s_addr = ip;

  return inet_ntoa(wrapped);
}

// Default text out callback
void NetworkNode::DefaultTextOut(const char* text, NetMessageType type)
{
  // Based off the type, do something...
  switch (type)
  {
  // If it's a normal message, print it normally
  case NET_MESSAGE_GENERAL:
    {
      // Set the color of the console text
      SetConsoleTextAttribute(GetStdHandle(STD_OUTPUT_HANDLE), 15 /*WHITE*/);

      // Output the text to stdout
      fprintf(stdout, text);
    }
    break;

  // If it's a warning message, print it with a special color
  case NET_MESSAGE_WARNING:
    {
      // Set the color of the console text
      SetConsoleTextAttribute(GetStdHandle(STD_OUTPUT_HANDLE), 14 /*YELLOW*/);

      // Output the text to stdout
      fprintf(stdout, text);
    }
    break;

  // If it's a error message, print it with a special color
  case NET_MESSAGE_ERROR:
    {
      // Set the color of the console text
      SetConsoleTextAttribute(GetStdHandle(STD_OUTPUT_HANDLE), 12 /*RED*/);

      // Output the text to stderr
      fprintf(stderr, text);
    }
    break;
  }
}
Categories: Code Samples Tags:

Minkowski Portal Refinement

August 28th, 2010 admin No comments

This is the implementation I made for MPR.

I did not modify this in any way before posting, except to remove commented out code blocks. Anything you see here directly corresponds to my own personal coding style and standards.

/******************************************************************************
  Author: Trevor Sundberg
    Date: 9/21/2008

 Purpose:
          This file contains all the elements for using the MPR algorithm to
          detect collisions between convex polyhedra, as well as find contact
          information.
          All content © 2009 DigiPen (USA) Corporation, all rights reserved.
******************************************************************************/


// Includes
#include "mpr.h"
#include "matrix4.h"
#include "float.h"
#include "math.h"

// Using directives
using namespace Wallaby;
using namespace Collision;

// Helper defines
#define SECOND_PHASE_ACCURACY   10
#define DEEP_PENETRATION        0.3F
#define NORMALIZATION_EPSILON   0.000000001F
#define DIPSLACE_EPSILON        0.001F
#define INIT_RANDOM_SEED        2457891
#define RANDOM_DISPLACE_RATIO   0.15F
#define INFINITE_LOOP           for(;;)
#define DEBUG_NO_BREAK          ((unsigned) -1)
#define PI                      3.141592653589793238462643F
#define V1                      0
#define V2                      1
#define V3                      2

// Debug defines
//#define MPR_DEBUG_INFO
//#define MPR_ERRORS


// Pre-defined vectors
static const Vector3 X(1.0F, 0.0F, 0.0F);
static const Vector3 Y(0.0F, 1.0F, 0.0F);
static const Vector3 Z(0.0F, 0.0F, 1.0F);
static const Vector3 ORIGIN(0.0F, 0.0F, 0.0F);

// If debug is enabled, define a check debug macro that helps debug drawing
#ifdef MPR_DEBUG_INFO
  #define CheckDebug(stage) if (DebugBreak(stage)) return false;
  #define DebugComplete() DebugDraw(STAGE_CONTACT_COMPLETED)
#else
  #define CheckDebug(stage)
  #define DebugComplete()
#endif

#if defined(MPR_DEBUG_INFO) || defined(MPR_ERRORS)
  // Prevent min/max from being #defined
  #define NOMINMAX
  // Prevent infrequently used stuff
  #define WIN32_LEAN_AND_MEAN
  #include "windows.h"
  #include "../architecture/wallaby.h"
#endif

// Constructor: starts out with no debug breakpoint set (DEBUG_NO_BREAK)
MPR::MPR()
  : debug_breakpoint_(DEBUG_NO_BREAK)
{
}

// Tests to see if two convex shapes are intersecting
//
//   shape_pair          - the two shapes to test (stored for the helper steps)
//   collision_info_out  - receives the collision results; its was_collision
//                         flag is cleared by Initialize before testing
//   collision_direction - optional direction to check in; may be null, in
//                         which case no direction vector is used
//
// Returns true once all three phases completed and the contact data was
// computed. Returns false when the refinement phase rules out a collision
// (origin outside the support plane, or portal close to the support plane),
// or when any phase exceeds its maximum iteration count. When MPR_DEBUG_INFO
// is defined, each CheckDebug also returns false if DebugBreak reports a
// break for that stage.
bool MPR::TestIntersection(const ShapePair& shape_pair,
                           CollisionInfo&   collision_info_out,
                           const Vector3*   collision_direction)
{
  /********************************\
   * Phase #1: Portal Discovery
  \********************************/


  // Initialize for collision detection
  Initialize(&shape_pair, collision_info_out, collision_direction);

  // Find the geometric center that we'll use as a starting point
  FindGeometricCenter();

  // Find the origin ray from an interior point in the CSO
  FindOriginRay();

  // Finds the starting portal that we'll use, THIS MAY FAIL IN SPECIAL CASES!
  FindStartingPortal();

  // This do/while loop is for MTD correction: all three phases are repeated
  // for as long as CorrectCenterForMTD (below) asks for another pass
  do
  {
    // Initialize an iteration counter for this phase to avoid failure states
    InitIterationCounter();

    // Checks for the current debug breakpoints
    CheckDebug(STAGE_PORTAL_DISCOVERY);

    // Loop until we find a valid portal that the origin ray passes through (or we timeout)
    while (!CheckOriginRayIntersectPortal())
    {
      // Check if we timed out
      if (IterationTimedOut(MAX_PORTAL_DISCOVERY_STEPS))
      {
        // Show an error and return no collision
        AssertError("Max iterations reached in Portal Discovery!");
        return false;
      }

      // Count up an iteration
      CountIteration();

      // Choose the new starting portal
      ChooseNewStartingPortal();

      // Checks for the current debug breakpoints
      CheckDebug(STAGE_PORTAL_DISCOVERY);
    }


    /********************************\
    * Phase #2: Portal Refinement
    \********************************/


    // Initialize an iteration counter for this phase to avoid failure states
    InitIterationCounter();

    // Checks for the current debug breakpoints
    CheckDebug(STAGE_ENTERING_STAGE_PORTAL_REFINEMENT);

    // Loop until we either break out with a hit or return with a miss
    INFINITE_LOOP
    {
      // Check if we timed out
      if (IterationTimedOut(MAX_PORTAL_REFINEMENT_STEPS))
      {
        // Show an error and return no collision
        AssertError("Max iterations reached in Portal Refinement!");
        return false;
      }

      // Count up an iteration
      CountIteration();

      // If the origin is inside the portal, there's a collision!
      // Break out of the loop if this happens, since we'll continue below
      // Also, this will compute the normal of the portal
      if (CheckOriginInsidePortal())
        break;

      // Find the new supporting point, since we didn't find that it was colliding yet
      FindSupportFromPortalOutNormal();

      // If the origin is outside the supporting plane, no collision
      if (CheckOriginOutsideSupportPlane())
        return false;

      // If the support plane and portal are close enough, terminate
      // (the origin was not inside the portal, so this counts as no collision)
      if (CheckPortalCloseToSupportPlane())
        return false;

      // Since we got here, we didn't find a portal that gave us any information
      // The origin is within the supporting plane, but outside the portal
      // Find a new portal
      ChooseNewPortal();

      // Checks for the current debug breakpoints
      CheckDebug(STAGE_PORTAL_REFINEMENT);
    }


    /********************************\
    * Phase #3: Contact Discovery
    \********************************/


    // Initialize an iteration counter for this phase to avoid failure states
    InitIterationCounter();

    // Checks for the current debug breakpoints
    CheckDebug(STAGE_ENTERING_STAGE_CONTACT_DISCOVERY);

    // Loop and choose a new portal, but stop if the new portal's normal
    // is by some epsilon "close enough" to the last normal
    do
    {
      // Check if we timed out
      if (IterationTimedOut(MAX_CONTACT_DISCOVERY_STEPS))
      {
        // Show an error and return no collision
        AssertError("Max iterations reached in Contact Discovery!");
        return false;
      }

      // Count up an iteration
      CountIteration();

      // Save the last portal normal
      SaveLastPortalNormal();

      // Find the new supporting point for the current portal
      FindSupportFromPortalOutNormal();

      // Move on to a new portal using that supporting point
      ChooseNewPortal();

      // Compute the current portal's normal
      ComputeCurrentPortalNormal();

      // Checks for the current debug breakpoints
      CheckDebug(STAGE_CONTACT_DISCOVERY);
    }
    // Loop until we find a normal that is close enough
    while (!CheckCurrentNormalCloseEnough());
  }
  // Loop until we no longer need to correct the center for the MTD
  while (CorrectCenterForMTD());


  // The last step is to compute the contact data: the collision normal,
  // penetration depth (MTD), and two applicable points of contact for each shape
  ComputeContactData();

  // The algorithm has now completed successfully, inform debug
  DebugComplete();

  // Since we made it here, there was a collision
  return true;
}

// Set the debug breaking point
// Stores the given value in debug_breakpoint_, replacing the DEBUG_NO_BREAK
// default set by the constructor.
// NOTE(review): how the value is interpreted is decided by the debug code
// that reads debug_breakpoint_, which is not shown here - confirm there.
void MPR::SetDebugBreakpoint(unsigned breakpoint)
{
  debug_breakpoint_ = breakpoint;
}

// Initialize the MPR given the passed in information (basically just stores it)
void MPR::Initialize(const ShapePair* pair, CollisionInfo& CI, const Vector3* direction)
{
  // Store away pointers
  pair_           = pair;
  collision_info_ = &CI;

  // If we were given a direction to check in
  if (direction != 0)
  {
    // Normalize the direction and set the direction vector
    direction->Normalize(collision_direction_);

    // Set the flag that we're using a direction vector
    using_direction_ = true;
  }
  else
  {
    // Set the flag that we're not using a direction vector
    using_direction_ = false;
  }

  // The MTD has not yet been corrected
  mtd_corrected_ = false;

  // Assume for right now that there is no collision
  collision_info_->was_collision = false;

  // Initialize the debug iteration counter
  debug_iteration_ = DEBUG_NO_BREAK;
}

// Find the the geometric centers of the objects, then translate them
// to a single center point in the Minkowski Difference space
void MPR::FindGeometricCenter()
{
  // Perform the Minkowski Difference on the two points, to get an interior point on the CSO
  // (B - A) centers
  V0_  = pair_->B_center;
  V0_ -= pair_->A_center;

  // Check if the difference point is the origin, if so offset the point
  if (V0_ == ORIGIN)
  {
    // Generate a random vector
    GenerateRandomVector(V0_);

    // Scale the vector down
    V0_ *= DIPSLACE_EPSILON;
  }
}

// Find the starting origin ray from the geometric centers to the origin
// Reads the interior point V0_ and writes the result into origin_ray_
void MPR::FindOriginRay()
{
  // The 'origin ray' is simply just the negative ray to the center
  // (Opposite writes its result into the vector passed to it)
  V0_.Opposite(origin_ray_);
}

// Finds the starting candidate portal, not necessarily a portal that the origin ray goes through
// This can be found through several methods, but basically needs to find 3 non-collinear points
void MPR::FindStartingPortal()
{
  // Find the support point in the direction of the origin ray
  Support(origin_ray_, current_portal_.point[V1], shape_A_points_.point[V1], shape_B_points_.point[V1]);

  // Find the support point that is perpendicular to the plane formed by
  // the origin ray and the ray to the first supporting point
  // NOTE: This could be problematic if the support point was exactly in the same
  //       direction as the ray, thus resulting in a bad cross product
  Vector3 new_direction;
  current_portal_.point[V1].CrossProduct(V0_, new_direction);

  // Initialize an iteration counter for this phase to avoid failure states
  InitIterationCounter();

  // Seed random since we might use it below
  SeedRandom(INIT_RANDOM_SEED);

  // Loop until we find a valid first point (this might involve moving the center)
  // This will only loop if the support point is in the same direction as the origin ray
  while (new_direction == ORIGIN)
  {
    // Generate a random direction vector
    Vector3 random_dir;
    GenerateRandomVector(random_dir);

    // Since we're going to move the center, and we don't want it to leave the shape...
    // we know that supporting points are the exterior most points, thus we'll find a
    // random supporting point, and move by a little bit in that direction
    // Get the supporting point in the random direction
    Vector3 rand_support;
    Support(random_dir, rand_support, shape_A_points_.point[V1], shape_B_points_.point[V1]);

    // Use the random support and direction to displace the center
    rand_support -= V0_;
    rand_support *= RANDOM_DISPLACE_RATIO;
    random_dir   *= DIPSLACE_EPSILON;
    V0_ += rand_support;
    V0_ += random_dir;

    // Now recompute the origin ray
    FindOriginRay();

    // Find the support point in the direction of the origin ray
    Support(origin_ray_, current_portal_.point[V1], shape_A_points_.point[V1], shape_B_points_.point[V1]);

    // Find the perpendicular normal to the plane
    current_portal_.point[V1].CrossProduct(V0_, new_direction);

    // Count up an iteration
    CountIteration();

    // Check if we timed out
    if (IterationTimedOut(MAX_RANDOM_CENTER_MOVE_STEPS))
      // Show an error and return no collision
      AssertError("Max iterations reached in Find Starting Portal!");
  }

  // Now get the second supporting point in the new direction
  Support(new_direction, current_portal_.point[V2], shape_A_points_.point[V2], shape_B_points_.point[V2]);

  // Find the support that is perpendicular to the plane containing the
  // interior point and the first two supporting points
  Vector3 temp1, temp2;
  current_portal_.point[V2].Subtract(V0_, temp1);
  current_portal_.point[V1].Subtract(V0_, temp2);
  temp1.CrossProduct(temp2, new_direction);
  Support(new_direction, current_portal_.point[V3], shape_A_points_.point[V3], shape_B_points_.point[V3]);
}

// Checks if the origin ray intersects the portal by checking if the origin is inside of the
// three planes made up by the portal points and the start point, if it does: it returns true,
// if not: it returns false and fills out the plane that that the origin was outside of
// By the end of this method, the normal member will be filled in
bool MPR::CheckOriginRayIntersectPortal()
{
  // NOTE: We need to check that all 3 normal directions are outward facing (point ordering!)

  // Make a plane between (V0, V1, V2)
  temp_plane_.point[0] = V0_;
  temp_plane_.point[1] = current_portal_.point[V1];
  temp_plane_.point[2] = current_portal_.point[V2];

  // Grab the actual shape points and store them away
  temp_shape_A_points_.point[1] = shape_A_points_.point[V1];
  temp_shape_B_points_.point[1] = shape_B_points_.point[V1];
  temp_shape_A_points_.point[2] = shape_A_points_.point[V2];
  temp_shape_B_points_.point[2] = shape_B_points_.point[V2];

  // Test if the origin is outside the plane
  if (PlaneOriginOutside(temp_plane_))
    return false;

  // Make a plane between (V0, V2, V3)
  temp_plane_.point[0] = V0_;
  temp_plane_.point[1] = current_portal_.point[V2];
  temp_plane_.point[2] = current_portal_.point[V3];

  // Grab the actual shape points and store them away
  temp_shape_A_points_.point[1] = shape_A_points_.point[V2];
  temp_shape_B_points_.point[1] = shape_B_points_.point[V2];
  temp_shape_A_points_.point[2] = shape_A_points_.point[V3];
  temp_shape_B_points_.point[2] = shape_B_points_.point[V3];

  // Test if the origin is outside the plane
  if (PlaneOriginOutside(temp_plane_))
    return false;

  // Make a plane between (V0, V3, V1)
  temp_plane_.point[0] = V0_;
  temp_plane_.point[1] = current_portal_.point[V3];
  temp_plane_.point[2] = current_portal_.point[V1];

  // Grab the actual shape points and store them away
  temp_shape_A_points_.point[1] = shape_A_points_.point[V3];
  temp_shape_B_points_.point[1] = shape_B_points_.point[V3];
  temp_shape_A_points_.point[2] = shape_A_points_.point[V1];
  temp_shape_B_points_.point[2] = shape_B_points_.point[V1];

  // Test if the origin is outside the plane
  if (PlaneOriginOutside(temp_plane_))
    return false;

  // Otherwise, it's inside so return true
  return true;
}

// Chooses a new starting portal since the last one didn't work; the input is the point plane made up by
// the starting point and two other old points, and it returns out a new portal as current_portal
// The first point will always be V0, keep this in mind!
void MPR::ChooseNewStartingPortal()
{
  // Replace V0 with the supporting point in the normal direction
  Support(temp_plane_.normal, temp_plane_.point[0], temp_shape_A_points_.point[0], temp_shape_B_points_.point[0]);

  // Now swap the order of the points to make all planes face outward
  Vector3 temp = temp_plane_.point[0];
  temp_plane_.point[0] = temp_plane_.point[1];
  temp_plane_.point[1] = temp;

  // Swap the order for the shape points too (shape A)
  temp = temp_shape_A_points_.point[0];
  temp_shape_A_points_.point[0] = temp_shape_A_points_.point[1];
  temp_shape_A_points_.point[1] = temp;

  // Swap the order for the shape points too (shape B)
  temp = temp_shape_B_points_.point[0];
  temp_shape_B_points_.point[0] = temp_shape_B_points_.point[1];
  temp_shape_B_points_.point[1] = temp;

  // The current portal is now the plane portal, but the portal normal is not valid!
  current_portal_ = temp_plane_;

  // Copy over the shape points also
  shape_A_points_ = temp_shape_A_points_;
  shape_B_points_ = temp_shape_B_points_;
}

// Since we now have a valid portal...
// Checks to see if the origin ray intersects with the current portal: it does so by checking
// if the origin is inside the portal plane or outside the portal plane (inside means hit!)
// This method needs to fill out the normal field for the portal
bool MPR::CheckOriginInsidePortal()
{
  // Return if it's inside the portal
  return PlaneOriginOutsideOrOn(current_portal_);
}

// Finds the supporting point in the direction of the portal's outward normal,
// only runs if the origin was found to be outside the current portal
void MPR::FindSupportFromPortalOutNormal()
{
  // Since the normal always points inside here, we'll use the negative normal
  Vector3 neg_normal;
  current_portal_.normal.Opposite(neg_normal);

  // Find the supporting point in the portal's normal direction
  Support(neg_normal, new_support_, new_support_A_, new_support_B_);
}

// Checks if the origin is outside the supporting plane, which means that the
// origin is not contained within the CSO and the shapes are not intersecting
bool MPR::CheckOriginOutsideSupportPlane()
{
  // Use the dot product between the point and the normal to compare if it's inside
  // The normal points inward, so we want to check if the origin is on the negative side
  return new_support_.DotProduct(current_portal_.normal) > 0.0F;
}

// Intended to report when the supporting plane is very close to the portal (a test that
// matters for curved shapes, based on the distance between the two parallel planes).
// Currently a stub: only discrete shapes are used, so it always returns false.
bool MPR::CheckPortalCloseToSupportPlane()
{
  const bool portal_is_close = false;
  return portal_is_close;
}

// Generates a new portal by finding which face the origin ray will go through
void MPR::ChooseNewPortal()
{
  // NOTE: At first I didn't understand why we used points (V4, V0, V[1-3])
  //       Since we're trying to find what triangle to use as the new portal, why wouldn't
  //       we just use the points made by some combo of V[1-3] and V4... why V0?
  //       Well, without a ray intersect triangle test, we need to use the V0 point to make
  //       planes that will give us the inside information


  // Create a plane that we'll use (the only thing that changes is the last point)
  PointPlane plane;
  plane.point[0] = V0_;
  plane.point[2] = new_support_;

  // Bit field to represent which planes it was inside
  unsigned bit_field = 0x00;
  unsigned result = 0;

  // Add the point to the plane
  plane.point[1] = current_portal_.point[V1];
  // Check if the origin is outside or on the plane, and put the result in the bit field
  result     = (int) PlaneOriginOutsideOrOn(plane);
  //           (Positive     )   (Negative      )
  bit_field |= (result * 0x01) + (!result * 0x02);

  // Add the point to the plane
  plane.point[1] = current_portal_.point[V2];
  // Check if the origin is outside or on the plane, and put the result in the bit field
  result     = (int) PlaneOriginOutsideOrOn(plane);
  //           (Positive     )   (Negative      )
  bit_field |= (result * 0x04) + (!result * 0x08);

  // Add the point to the plane
  plane.point[1] = current_portal_.point[V3];
  // Check if the origin is outside or on the plane, and put the result in the bit field
  result     = (int) PlaneOriginOutsideOrOn(plane);
  //           (Positive     )   (Negative      )
  bit_field |= (result * 0x10) + (!result * 0x20);


  // Based on the cases, we choose the new portal
  switch (bit_field)
  {
  // Plane #1 positive and #2 negative, #3 is anything (and kill point #3)
  case (0x01 | 0x08) | 0x10:
  case (0x01 | 0x08) | 0x20:
      // Get rid of vertex 3
      current_portal_.point[V3] = new_support_;
      shape_A_points_.point[V3] = new_support_A_;
      shape_B_points_.point[V3] = new_support_B_;
    break;

  // Plane #2 positive and #3 negative, #1 is anything (and kill point #1)
  case (0x04 | 0x20) | 0x01:
  case (0x04 | 0x20) | 0x02:
    // Get rid of vertex 1
    current_portal_.point[V1] = new_support_;
    shape_A_points_.point[V1] = new_support_A_;
    shape_B_points_.point[V1] = new_support_B_;
    break;

  // Plane #3 positive and #1 negative, #2 is anything (and kill point #2)
  case (0x10 | 0x02) | 0x04:
  case (0x10 | 0x02) | 0x08:
    // Get rid of vertex 2
    current_portal_.point[V2] = new_support_;
    shape_A_points_.point[V2] = new_support_A_;
    shape_B_points_.point[V2] = new_support_B_;
    break;

  /*
  // Special case where it lies on all planes, since it lies in the direct center of
  // the portal. This is very very rare, but... we'll just choose any one of the sides
  case (0x10 | 0x04 | 0x01):
    // Note: I was wrong, this case can also be bad!
    break;
  */


  // Something went terribly wrong! This case should never happen!
  default:
    AssertError("Error occurred in ChooseNewPortal(), an invalid case!");
    break;
  }
}

// Saves the last portal normal
void MPR::SaveLastPortalNormal()
{
  last_normal_ = current_portal_.normal;
}

// Computes the current portal's normal, only explicitly used in finding contact data
void MPR::ComputeCurrentPortalNormal()
{
  // Create the first and second side
  Vector3 side1, side2;
  current_portal_.point[0].Subtract(current_portal_.point[2], side1);
  current_portal_.point[0].Subtract(current_portal_.point[1], side2);

  // The cross product of the two sides is the normal
  side2.CrossProduct(side1, current_portal_.normal);
}

// Checks the condition that the previous normal is close enough to the current normal
bool MPR::CheckCurrentNormalCloseEnough()
{
  // Unfortunately, we need to normalize the normals here
  last_normal_.Normalize();
  current_portal_.normal.Normalize();

  // Compute the dot product
  float dot = last_normal_.DotProduct(current_portal_.normal);

  // Now check if the dot product between them shows that they are close enough
  return dot > (1.0F - CLOSENESS_NORMAL_EPSILON);
}

// Repositions the center so we can get the correct Minimum Translational Distance
// Returns true if we should restart and correct the center for the MTD
bool MPR::CorrectCenterForMTD()
{
  // As long as we haven't corrected the center to account for the MTD...
  if (mtd_corrected_ == false)
  {
    // If we're using a direction vector, correct V0_ by changing the collision direction
    if (using_direction_)
    {
      // Set V0_ to be at the collision direction spot
      V0_ = collision_direction_;

      // Multiply V0_ by an epsilon to keep it small
      V0_ *= DIPSLACE_EPSILON;
    }
    else
    {
      // Compute the intersection between the origin ray and the current portal
      Vector3 intersection_point;
      IntersectLineTriangle(V0_, ORIGIN, current_portal_, intersection_point);

      // Check if the penetration distance is dep (relative to the object's own size)
      if (intersection_point.Length() > DEEP_PENETRATION)
      {
        // Stores the support point
        Vector3 support, temp;

        // Initialize a minimum distance value
        float minimum_dist = FLT_MAX;

        // Loop through all the stacks
        for (int g = 0; g <= SECOND_PHASE_ACCURACY; ++g)
        {
          // Loop through all the slices
          for (int t = 0; t <= SECOND_PHASE_ACCURACY; ++t)
          {
            // Get a normalized value for the current stack/slice
            float t_norm = (float)t / (float)(SECOND_PHASE_ACCURACY);
            float g_norm = (float)g / (float)(SECOND_PHASE_ACCURACY);

            // Create the direction vector from sphereical coordiantes
            Vector3 dir(sin(t_norm * PI * 2.0F) * sin(g_norm * PI),
                        cos(t_norm * PI * 2.0F) * sin(g_norm * PI),
                        cos(g_norm * PI));

            // Support in the direction
            Support(dir, support, temp, temp);

            // Compute the distance that the support lies along the direction
            float dist = support.DotProduct(dir);

            // Check if the distance is less than
            if (dist < minimum_dist)
            {
              // Store the new minimum distance
              minimum_dist = dist;

              // Store the opposite direction into V0_
              dir.Opposite(V0_);
            }
          }
        }
       
        // Multiply V0_ by an epsilon to keep it small
        V0_ *= DIPSLACE_EPSILON;
      }
      else
      {
        // The MTD didn't have to be corrected, just continue on
        return false;
      }
    }

    // Find the origin ray from an interior point in the CSO
    // This needs to be called since we just found ourselves a new center
    FindOriginRay();

    // The MTD will now be corrected
    mtd_corrected_ = true;
    return true;
  }

  // Since we got here, the MTD no longer needs to be corrected
  return false;
}

// Returns if the MTD has been corrected
bool MPR::IsMTDCorrected()
{
  return mtd_corrected_;
}

// Computes the contact data at the end
void MPR::ComputeContactData()
{
  // This will hold the barycentric coordinates of the final contact point
  Vector3 barycentric_coords;


  // If we're not using a direction
  if (using_direction_ == false)
  {
    // Now find the barycentric coordinates of the projected point in the triangle/portal
    // Note, we don't actually need to project the point, since barycentric coordinates
    // can only express positions along the plane, and thus any movement in the normal
    // direction of the plane is ignored
    ComputeBarycentricCoords(current_portal_, ORIGIN, barycentric_coords);

    // Check to see if the point is outside the triangle, if so we'll have to approximate
    // the MTD by using the intersection of the origin ray with the current portal
    if (barycentric_coords.x < 0.0F ||
        barycentric_coords.y < 0.0F ||
        barycentric_coords.z < 0.0F)
    {
      // Compute the intersection between the origin ray and the current portal
      Vector3 intersection_point;
      IntersectLineTriangle(V0_, ORIGIN, current_portal_, intersection_point);

      // Now compute the barycentric coordinates of the intersection point
      ComputeBarycentricCoords(current_portal_, intersection_point, barycentric_coords);
    }

    // Now use the shape points to determine contact points on each
    // Shape A:
    shape_A_points_.point[0] *= barycentric_coords.x;
    shape_A_points_.point[1] *= barycentric_coords.y;
    shape_A_points_.point[2] *= barycentric_coords.z;
    // Shape B:
    shape_B_points_.point[0] *= barycentric_coords.x;
    shape_B_points_.point[1] *= barycentric_coords.y;
    shape_B_points_.point[2] *= barycentric_coords.z;


    // The contact points are the sum of each barycentric weighted point
    // Shape A:
    contact_A_  = shape_A_points_.point[0];
    contact_A_ += shape_A_points_.point[1];
    contact_A_ += shape_A_points_.point[2];
    // Shape B:
    contact_B_  = shape_B_points_.point[0];
    contact_B_ += shape_B_points_.point[1];
    contact_B_ += shape_B_points_.point[2];

    // The contact for the shapes will be the average for now
    contact_A_.Add(contact_B_, collision_info_->point_of_contact);
    collision_info_->point_of_contact /= 2.0F;
  }
  else
  {
    // Compute the intersection between the origin ray and the current portal
    Vector3 intersection_point;
    IntersectLineTriangle(V0_, ORIGIN, current_portal_, intersection_point);

    // Now compute the barycentric coordinates of the intersection point
    ComputeBarycentricCoords(current_portal_, intersection_point, barycentric_coords);

    // Now use the shape points to determine contact points on each
    // Shape A:
    shape_A_points_.point[0] *= barycentric_coords.x;
    shape_A_points_.point[1] *= barycentric_coords.y;
    shape_A_points_.point[2] *= barycentric_coords.z;

    // The contact points are the sum of each barycentric weighted point
    // Shape A:
    contact_A_  = shape_A_points_.point[0];
    contact_A_ += shape_A_points_.point[1];
    contact_A_ += shape_A_points_.point[2];

    // The contact for the shapes will simply just be the contact_A_
    collision_info_->point_of_contact = contact_A_;
  }

  // Compute the collision normal between A and B
  contact_A_.Subtract(contact_B_, collision_info_->contact_normal);

  // Get the penetration depth/length (length of the normal)
  collision_info_->penetration_depth = collision_info_->contact_normal.Length();

  // Now normalize the collision normal by dividing by length
  // Add a little epsilon just in case length is zero for some reason
  collision_info_->contact_normal /= collision_info_->penetration_depth + NORMALIZATION_EPSILON;

  // Also specify that there was a collision (duh, I know)
  collision_info_->was_collision = true;

  // If some invalid case happened, return no collision
  if (collision_info_->contact_normal == ORIGIN)
    collision_info_->was_collision = false;
}

// The supporting function we'll use to sample support
void MPR::Support(const Vector3& dir, Vector3& support_out, Vector3& A_out, Vector3& B_out)
{
  // Call the supporting function
  pair_->MinkowskiSupport(dir, support_out, B_out, A_out);
}

// Strict plane test for the origin against the plane through the three points of
// plane_in_plane_out. As a side effect the plane normal is computed and written into
// plane_in_plane_out.normal (it is not normalized here).
// Returns true only when d = -(point[0] . normal) is strictly greater than zero, so an
// origin lying exactly on the plane yields false.
bool MPR::PlaneOriginOutside(PointPlane& plane_in_plane_out)
{
  PointPlane& plane = plane_in_plane_out;

  // Edge vectors between point 0 and each of the other two points
  Vector3 edge_to_1, edge_to_2;
  plane.point[0].Subtract(plane.point[1], edge_to_1);
  plane.point[0].Subtract(plane.point[2], edge_to_2);

  // Their cross product is the plane normal
  edge_to_1.CrossProduct(edge_to_2, plane.normal);

  // The d value of the plane, taken from point 0 and the normal
  const float d = -plane.point[0].DotProduct(plane.normal);

  return 0.0F < d;
}

// Non-strict plane test for the origin against the plane through the three points of
// plane_in_plane_out. As a side effect the plane normal is computed and written into
// plane_in_plane_out.normal (it is not normalized here).
// Returns true when d = -(point[0] . normal) is greater than or equal to zero, so an
// origin lying exactly on the plane yields true.
bool MPR::PlaneOriginOutsideOrOn(PointPlane& plane_in_plane_out)
{
  PointPlane& plane = plane_in_plane_out;

  // Edge vectors between point 0 and each of the other two points
  Vector3 edge_to_1, edge_to_2;
  plane.point[0].Subtract(plane.point[1], edge_to_1);
  plane.point[0].Subtract(plane.point[2], edge_to_2);

  // Their cross product is the plane normal
  edge_to_1.CrossProduct(edge_to_2, plane.normal);

  // The d value of the plane, taken from point 0 and the normal
  const float d = -plane.point[0].DotProduct(plane.normal);

  return 0.0F <= d;
}

// Computes the barycentric coordinates of a point inside of a triangle
void MPR::ComputeBarycentricCoords(const Triangle& triangle, const Vector3& point, Vector3& uvw_out)
{
  // You can find a great explanation of this function in Chapter 3 of
  // Real-time Collision Detection (Christer Ericson), starting on page 46

  // Will hold the side vectors and origin transform
  Vector3 S0, S1, S2;

  // Compute the side vector S0
  triangle.point[1].Subtract(triangle.point[0], S0);
  // Compute the side vector S1
  triangle.point[2].Subtract(triangle.point[0], S1);
  // Compute the origin transformation
  point.Subtract(triangle.point[0], S2);

  // Compute the dot products that we'll use to solve for a and b
  // In the equation a(S0) + b(S1) = S2
  float dot_00 = S0.DotProduct(S0);
  float dot_01 = S0.DotProduct(S1);
  float dot_11 = S1.DotProduct(S1);
  float dot_20 = S2.DotProduct(S0);
  float dot_21 = S2.DotProduct(S1);

  // Compute the denominator (again just to solve for a and b)
  float denominator = (dot_00 * dot_11) - (dot_01 * dot_01);

  // Now compute the u, v, and w values
  // Note, for efficiency we just used a vector, so u = x, v = y, w = z
  uvw_out.y = (dot_11 * dot_20 - dot_01 * dot_21) / denominator;  // v
  uvw_out.z = (dot_00 * dot_21 - dot_01 * dot_20) / denominator;  // w
  uvw_out.x = 1.0F - uvw_out.y - uvw_out.z;                       // u
}

// Intersect a line with a triangle and get the intersection point
void MPR::IntersectLineTriangle(const Vector3& P1, const Vector3& P2, const Triangle& tri, Vector3& point_out)
{
  // Compute the numerator and denominator values
  Vector3 N1, D1;
  tri.point[0].Subtract(P1, N1);
  P2.Subtract(P1, D1);
 
  // Compute the t-value for the intersection time
  float t = tri.normal.DotProduct(N1) / tri.normal.DotProduct(D1);

  // Now compute the actual intersection point
  D1.Scale(t, point_out);
  point_out += P1;
}

// Seeds the random number generator
void MPR::SeedRandom(unsigned seed)
{
  rand_seed_ = seed;
}

// Generates a pseudo-random vector into out_rand (the vector is not normalized).
// rand_seed_ is advanced once per component, in the order x, y, z, using
// seed = seed * 1103515245 + 12345; each component is ((seed / 65536) % 32768) / 32768
// shifted down by 0.5.
void MPR::GenerateRandomVector(Vector3& out_rand)
{
  float component[3];

  for (int axis = 0; axis < 3; ++axis)
  {
    // Step the generator, then turn the new state into a value
    rand_seed_ = rand_seed_ * 1103515245 + 12345;
    component[axis] = ((float)((rand_seed_ / 65536) % 32768) / 32768.0F) - 0.5F;
  }

  out_rand.x = component[0];
  out_rand.y = component[1];
  out_rand.z = component[2];
}

// Reports an error message. What happens depends on the build configuration:
//  - MPR_ERRORS without MPR_DEBUG_INFO: shows the message in a message box
//  - MPR_ERRORS with MPR_DEBUG_INFO:    shows the message in the console and draws it as
//                                       screen text through the "Graphics" manager
//  - MPR_ERRORS not defined:            does nothing
// Execution always continues after the report; nothing is thrown or aborted here.
void MPR::AssertError(const char* error)
{
  // Mark the parameter as used for builds where both branches below compile away
  (void) error;

#if defined(MPR_ERRORS) && !defined(MPR_DEBUG_INFO)

  // Show a message box
  ::MessageBox(0, error, "MPR Error:", 0);

#elif defined(MPR_ERRORS) && defined(MPR_DEBUG_INFO)

  // Show the error in the console
  ShowMessage(Engine::MESSAGE_ERROR, error);

  // Print the error on the top
  GetMgr("Graphics")->Call("DebugDrawScreenText", 0x00, 0xFFFF0000, Vector3(0.0F, 0.05F, 0.0F), error);

#endif
}

// Initializes the iteration counter
void MPR::InitIterationCounter()
{
  iteration_counter_ = 0;
}

// Counts up each time an iteration occurs
void MPR::CountIteration()
{
  ++iteration_counter_;
}

// Advances the debug iteration count and checks it against the debug breakpoint.
// When the breakpoint is hit, the current state is drawn for the given stage and true is
// returned (the caller should stop); otherwise nothing is drawn and false is returned.
bool MPR::DebugBreak(Stage stage)
{
  // One more debug iteration has passed
  ++debug_iteration_;

  // Not on the breakpoint yet: keep going
  if (debug_iteration_ != debug_breakpoint_)
    return false;

  // On the breakpoint: draw the current debug state and report the hit
  DebugDraw(stage);
  return true;
}


// Draws the current state
void MPR::DebugDraw(Stage stage)
{
  // Make sure we don't get any unused warnings
  (int) stage;

  // If debug is defined
#ifdef MPR_DEBUG_INFO

  // This defines what text to write
  const char* stage_text[] =
  {
    "Stage: Portal Discovery",
    "Stage: Entering Portal Refinement",
    "Stage: Portal Refinement",
    "Stage: Entering Contact Discovery",
    "Stage: Contact Discovery",
    "Stage: Contact Completed"
  };

  // Draw text based off the current stage
  GetMgr("Graphics")->Call("DebugDrawScreenText", 0x00, 0xFFFFFFFF, Vector3(0.0F, 0.0F, 0.0F), stage_text[stage]);

  // Create a vector of vertices
  vector<Vector3> convex_hull;
  convex_hull.reserve(pair_->A_num_verts * pair_->B_num_verts);

  // Loop through all vertices in shape A
  for (unsigned a = 0; a < pair_->A_num_verts; ++a)
  {
    // Loop through all vertices in shape B
    for (unsigned b = 0; b < pair_->B_num_verts; ++b)
    {
      // Compute the Minkowski difference point and add it to the vector
      Vector3 pt = pair_->B_vertices[b];
      pt -= pair_->A_vertices[a];
      convex_hull.push_back(pt);
    }
  }

  // Draw the hull
  GetMgr("Graphics")->Call("DebugDrawHull", 0x00, 0xFF000000, &convex_hull[0], convex_hull.size());

  // Draw the inner point
  GetMgr("Graphics")->Call("DebugDrawPoint", 0x00, 0xFF000000, V0_, 2.0F);

  // Draw the origin ray
  GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFF0000FF, V0_, origin_ray_);

  // Draw the current portal
  GetMgr("Graphics")->Call("DebugDrawTriangle", 0x00, (unsigned) 0x40AA00AA, current_portal_.point[0], current_portal_.point[1], current_portal_.point[2]);

  // Draw the sides of the current portal
  GetMgr("Graphics")->Call("DebugDrawTriangle", 0x00, (unsigned) 0x150000AA, V0_, current_portal_.point[1], current_portal_.point[2]);
  GetMgr("Graphics")->Call("DebugDrawTriangle", 0x00, (unsigned) 0x150000AA, V0_, current_portal_.point[2], current_portal_.point[0]);
  GetMgr("Graphics")->Call("DebugDrawTriangle", 0x00, (unsigned) 0x150000AA, V0_, current_portal_.point[0], current_portal_.point[1]);


  // Draw which points are which
  GetMgr("Graphics")->Call("DebugDrawWorldText", 0x00, 0xFFFFFFFF, V0_, "V0");
  GetMgr("Graphics")->Call("DebugDrawWorldText", 0x00, 0xFFFFFFFF, current_portal_.point[V1], "V1");
  GetMgr("Graphics")->Call("DebugDrawWorldText", 0x00, 0xFFFFFFFF, current_portal_.point[V2], "V2");
  GetMgr("Graphics")->Call("DebugDrawWorldText", 0x00, 0xFFFFFFFF, current_portal_.point[V3], "V3");
  GetMgr("Graphics")->Call("DebugDrawWorldText", 0x00, 0xFFFF0000, ORIGIN, "Origin");


  // Find an average point in the center of the portal (just for drawing the normal)
  Vector3 avg = current_portal_.point[0];
  avg += current_portal_.point[1];
  avg += current_portal_.point[2];
  avg /= 3.0F;

  // Compute the current portal's normal, just to ensure it's correct
  Vector3 side1, side2, portal_normal;
  current_portal_.point[0].Subtract(current_portal_.point[2], side1);
  current_portal_.point[0].Subtract(current_portal_.point[1], side2);

  // The cross product of the two sides is the normal
  side2.CrossProduct(side1, portal_normal);
  portal_normal.Normalize();

  // Draw the portal's normal in the center of the portal
  GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFF0000FF, avg, portal_normal);

  // Draw the origin coordinates
  GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFFFF0000, ORIGIN, Vector3(1.0F, 0.0F, 0.0F));
  GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFFFF0000, ORIGIN, Vector3(0.0F, 1.0F, 0.0F));
  GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFFFF0000, ORIGIN, Vector3(0.0F, 0.0F, 1.0F));


  // If the current stage is contact completed
  if (stage == STAGE_CONTACT_COMPLETED)
  {
    // Draw the points of contact
    GetMgr("Graphics")->Call("DebugDrawSphere", 0x00, 0xFF0000FF, contact_A_, 0.5F);
    GetMgr("Graphics")->Call("DebugDrawSphere", 0x00, 0xFF00FF00, contact_B_, 0.5F);

    // Draw the contact point on the portal
    GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFF00FF00, contact_point, Vector3(1.0F, 0.0F, 0.0F));
    GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFF00FF00, contact_point, Vector3(0.0F, 1.0F, 0.0F));
    GetMgr("Graphics")->Call("DebugDrawVector", 0x00, 0xFF00FF00, contact_point, Vector3(0.0F, 0.0F, 1.0F));

    // Draw the point of intersection text
    GetMgr("Graphics")->Call("DebugDrawWorldText", 0x00, 0xFF00FF00, contact_point, "POI");
  }

#endif
}


// Triggers a collision/algorithm timeout if it reaches the given number of iterations
bool MPR::IterationTimedOut(unsigned max_iterations)
{
  return iteration_counter_ == max_iterations;
}
Categories: Code Samples Tags: