Reinforcement Learning Penguins (Part 2/4) | Unity ML-Agents

Penguins with Unity ML-Agents Project

In this tutorial you’ll write all of the C# code needed for the penguin ML-Agents. These scripts manage Scene setup (such as randomized placement of the penguin agent), penguin decision making, fish movement, and interaction between the penguin agent and the Scene.

Writing the Code

First, you’ll create all of the C# scripts needed for this project. After you’ve created them, we’ll walk through the code for each.

  • Create a new folder in Unity called Scripts inside the Penguin folder.
  • Create three new C# scripts inside the Scripts folder (Figure 01):
    • PenguinArea
    • PenguinAgent
    • Fish

Figure 01: C# scripts in the Scripts folder.

PenguinAcademy.cs

The PenguinAcademy script is not needed as of version 0.14. Academy is now a singleton and we don't need our own version. I'm leaving this here in case you are coming from the older tutorial and are confused about where it went.

PenguinArea.cs

The PenguinArea (Figure 02) will manage a training area with one penguin, one baby, and multiple fish. It has the responsibilities of removing fish, spawning fish, and random placement of the penguins. There might be multiple PenguinAreas in a Scene for more efficient training.

Figure 02: A single PenguinArea with penguins and fish. You will create this area later and attach the PenguinArea.cs script to it.

  • Open PenguinArea.cs.
  • Delete the Start() function.
  • Delete the Update() function.
  • Add using statements for Unity.MLAgents and TMPro.
  • Leave the class definition inheriting from MonoBehaviour.
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using TMPro;

public class PenguinArea : MonoBehaviour
{

}
  • Add the following variables inside the class (between { }).
[Tooltip("The agent inside the area")]
    public PenguinAgent penguinAgent;

    [Tooltip("The baby penguin inside the area")]
    public GameObject penguinBaby;

    [Tooltip("The TextMeshPro text that shows the cumulative reward of the agent")]
    public TextMeshPro cumulativeRewardText;

    [Tooltip("Prefab of a live fish")]
    public Fish fishPrefab;

    private List<GameObject> fishList;

These variables will keep track of important objects in the Scene. You will hook up objects to the public variables in Unity later in this tutorial.

  • Add a new ResetArea() function inside the class after the last private variable.
/// <summary>
    /// Reset the area, including fish and penguin placement
    /// </summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        PlaceBaby();
        SpawnFish(4, .5f);
    }

The functions in the code above do not exist yet, but we will create them later in this script.

  • Add a new RemoveSpecificFish() function.
  • Add a new FishRemaining property.
/// <summary>
    /// Remove a specific fish from the area when it is eaten
    /// </summary>
    /// <param name="fishObject">The fish to remove</param>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        fishList.Remove(fishObject);
        Destroy(fishObject);
    }

    /// <summary>
    /// The number of fish remaining
    /// </summary>
    public int FishRemaining
    {
        get { return fishList.Count; }
    }

When the penguin catches a fish, the PenguinAgent script will call RemoveSpecificFish() to remove it from the water.

The next few functions will handle placement of the animals in the area. It makes the most sense to spawn fish in the water and place the baby penguin on land. The penguin can move between land and water, so it can be placed in either. In Figure 03 below, you can see where the script will randomly position each type of animal.

Figure 03: Placement regions for the penguin, the baby penguin, and fish.

  • Add a new ChooseRandomPosition() function.
/// <summary>
    /// Choose a random position on the X-Z plane within a partial donut shape
    /// </summary>
    /// <param name="center">The center of the donut</param>
    /// <param name="minAngle">Minimum angle of the wedge</param>
    /// <param name="maxAngle">Maximum angle of the wedge</param>
    /// <param name="minRadius">Minimum distance from the center</param>
    /// <param name="maxRadius">Maximum distance from the center</param>
    /// <returns>A position falling within the specified region</returns>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        float radius = minRadius;
        float angle = minAngle;

        if (maxRadius > minRadius)
        {
            // Pick a random radius
            radius = UnityEngine.Random.Range(minRadius, maxRadius);
        }

        if (maxAngle > minAngle)
        {
            // Pick a random angle
            angle = UnityEngine.Random.Range(minAngle, maxAngle);
        }

        // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplied by "radius"
        return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;
    }

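This function picks a random position inside a wedge of a ring (the "partial donut") around the central point of the area, bounded by the minimum and maximum radius and angle you pass in. If a maximum is not greater than its minimum, the minimum value is used as-is. Read the comments in the code for more detail.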
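
To see what the final line of ChooseRandomPosition() computes, here is a minimal sketch that runs outside Unity. It assumes System.Numerics.Vector3 in place of Unity's Vector3 and System.Random in place of UnityEngine.Random; the DonutSampler class is a hypothetical stand-in, not part of the project. Rotating Unity's forward vector (0, 0, 1) around the Y axis by "angle" degrees gives (sin angle, 0, cos angle), so the sketch uses sine and cosine directly instead of a quaternion.

```csharp
using System;
using System.Numerics;

// Hypothetical stand-in for PenguinArea.ChooseRandomPosition() that runs without Unity
public static class DonutSampler
{
    private static readonly Random Rng = new Random();

    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        float radius = minRadius;
        float angle = minAngle;

        // Only randomize when the range is non-empty, as in the original function
        if (maxRadius > minRadius)
        {
            radius = minRadius + (float)Rng.NextDouble() * (maxRadius - minRadius);
        }

        if (maxAngle > minAngle)
        {
            angle = minAngle + (float)Rng.NextDouble() * (maxAngle - minAngle);
        }

        // Forward (0, 0, 1) rotated around the Y axis by "angle" degrees is (sin, 0, cos)
        double radians = angle * Math.PI / 180.0;
        Vector3 direction = new Vector3((float)Math.Sin(radians), 0f, (float)Math.Cos(radians));
        return center + direction * radius;
    }
}
```

With the fish arguments used later in this script (100f, 260f, 2f, 13f), every sampled point lies between 2 and 13 units from the center, in the wedge centered on the area's negative Z direction. The baby's arguments (-45f, 45f, 4f, 9f) produce a wedge centered on the positive Z direction.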

  • Add a new RemoveAllFish() function.
/// <summary>
    /// Remove all fish from the area
    /// </summary>
    private void RemoveAllFish()
    {
        if (fishList != null)
        {
            for (int i = 0; i < fishList.Count; i++)
            {
                if (fishList[i] != null)
                {
                    Destroy(fishList[i]);
                }
            }
        }

        fishList = new List<GameObject>();
    }

The ResetArea() function calls RemoveAllFish() to make sure no fish are in the area before spawning new fish.

  • Add a new PlacePenguin() function.
  • Add a new PlaceBaby() function.
/// <summary>
    /// Place the penguin in the area
    /// </summary>
    private void PlacePenguin()
    {
        Rigidbody rigidbody = penguinAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>
    /// Place the baby in the area
    /// </summary>
    private void PlaceBaby()
    {
        Rigidbody rigidbody = penguinBaby.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        penguinBaby.transform.position = ChooseRandomPosition(transform.position, -45f, 45f, 4f, 9f) + Vector3.up * .5f;
        penguinBaby.transform.rotation = Quaternion.Euler(0f, 180f, 0f);
    }

These functions place the penguins. In both cases, they set rigidbody velocities to zero because unexpected things can happen when training for long periods of time at 100x speed. For example, the penguin could fall through the floor, then accelerate downward. When the area resets, the position would be reset, but if the downward velocity is not reset, the penguin might blast through the ground.

  • Add a new SpawnFish() function.
/// <summary>
    /// Spawn some number of fish in the area and set their swim speed
    /// </summary>
    /// <param name="count">The number to spawn</param>
    /// <param name="fishSpeed">The swim speed</param>
    private void SpawnFish(int count, float fishSpeed)
    {
        for (int i = 0; i < count; i++)
        {
            // Spawn and place the fish
            GameObject fishObject = Instantiate<GameObject>(fishPrefab.gameObject);
            fishObject.transform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            fishObject.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

            // Set the fish's parent to this area's transform
            fishObject.transform.SetParent(transform);

            // Keep track of the fish
            fishList.Add(fishObject);

            // Set the fish speed
            fishObject.GetComponent<Fish>().fishSpeed = fishSpeed;
        }
    }

This code places a specified number of fish in the area and sets their default swim speed. See the comments in the code for more detail.

  • Add a new Start() function and call ResetArea().
/// <summary>
    /// Called when the game starts
    /// </summary>
    private void Start()
    {
        ResetArea();
    }
  • Add a new Update() function.
/// <summary>
    /// Called every frame
    /// </summary>
    private void Update()
    {
        // Update the cumulative reward text
        cumulativeRewardText.text = penguinAgent.GetCumulativeReward().ToString("0.00");
    }

This function updates the cumulative reward display text on the back wall of the area every frame. It is not necessary for training, but it helps you see how well the penguins are performing.

That’s all for the PenguinArea script!

PenguinAgent.cs

The PenguinAgent class, which inherits from the Agent class, is where the cool stuff happens. It handles observing the environment, taking action, interacting, and accepting player input.

  • Open PenguinAgent.cs.
  • Delete the Start() function.
  • Delete the Update() function.
  • Add using statements for Unity.MLAgents, Unity.MLAgents.Actuators, and Unity.MLAgents.Sensors.
  • Change the class definition to inherit from Agent instead of MonoBehaviour.
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

public class PenguinAgent : Agent
{

}
  • Add public variables to keep track of the move and turn speed of the penguin agent as well as the Prefabs for the heart and regurgitated fish.
[Tooltip("How fast the agent moves forward")]
    public float moveSpeed = 5f;

    [Tooltip("How fast the agent turns")]
    public float turnSpeed = 180f;

    [Tooltip("Prefab of the heart that appears when the baby is fed")]
    public GameObject heartPrefab;

    [Tooltip("Prefab of the regurgitated fish that appears when the baby is fed")]
    public GameObject regurgitatedFishPrefab;
  • Add private variables to keep track of the area, the penguin's Rigidbody, the baby, and whether the penguin's stomach is full.
private PenguinArea penguinArea;
    new private Rigidbody rigidbody;
    private GameObject baby;
    private bool isFull; // If true, penguin has a full stomach

Initialize() is called once, automatically, when the agent is first enabled. It is not called every time the agent is reset, which is why there is a separate OnEpisodeBegin() function. We’ll use it to find a few objects in our Scene.

  • Override the Initialize() function.
/// <summary>
    /// Initial setup, called when the agent is enabled
    /// </summary>
    public override void Initialize()
    {
        base.Initialize();
        penguinArea = GetComponentInParent<PenguinArea>();
        baby = penguinArea.penguinBaby;
        rigidbody = GetComponent<Rigidbody>();
    }

OnActionReceived() is where the agent receives and responds to commands. These commands may originate from a neural network or a human player, but this function treats them the same.

The actionBuffers parameter is a struct that contains an array of numerical values that correspond to actions the agent should take. For this project, we are using "discrete" actions, which means each integer value (e.g., 0, 1, 2, …) corresponds to a choice. The alternative is "continuous" actions, which instead allows a choice of any fractional value between -1 and +1 (e.g., -.7, 0.23, .4, …). Discrete actions allow only one choice at a time with no in-between.

In this case:

  • actionBuffers.DiscreteActions[0] can either be 0 or 1, indicating whether to remain in place (0) or move forward at full speed (1).
  • actionBuffers.DiscreteActions[1] can either be 0, 1, or 2, indicating whether to not turn (0), turn in the negative direction (1), or turn in the positive direction (2).

The neural network, when trained, actually has no concept of what these actions do. It only knows that when it sees the environment a certain way, some actions tend to result in more reward points. This is why it will be very important to create an effective observation of the environment later in this script.

After interpreting the vector actions, the OnActionReceived() function applies the movement and rotation and then adds a small negative reward. This small negative reward encourages the agent to complete its task as quickly as possible.

In this case, a reward of -1 / 5000 is given for each of the 5,000 steps (the 5,000 comes from the variable MaxStep, which we'll set later). If the penguin finishes early, in 3,000 steps for example, the negative reward accumulated from this line of code would be -3000 / 5000 = -0.6. If the penguin takes all 5,000 steps, the total negative reward would be -5000 / 5000 = -1.
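
The same arithmetic can be written compactly. As a sketch, let T be the number of steps the episode lasts (a symbol introduced here only for illustration) and take MaxStep = 5000:

```latex
R_{\text{time}} = \sum_{t=1}^{T} \left( -\frac{1}{\text{MaxStep}} \right) = -\frac{T}{\text{MaxStep}},
\qquad
T = 3000 \;\Rightarrow\; R_{\text{time}} = -\frac{3000}{5000} = -0.6,
\qquad
T = 5000 \;\Rightarrow\; R_{\text{time}} = -\frac{5000}{5000} = -1
```

Because the per-step penalty is scaled by MaxStep, the total time penalty is never less than -1, regardless of the value MaxStep is set to.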

  • Override the OnActionReceived() function.
/// <summary>
    /// Perform actions based on a vector of numbers
    /// </summary>
    /// <param name="actionBuffers">The struct of actions to take</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Convert the first action to forward movement
        float forwardAmount = actionBuffers.DiscreteActions[0];

        // Convert the second action to turning left or right
        float turnAmount = 0f;
        if (actionBuffers.DiscreteActions[1] == 1)
        {
            turnAmount = -1f;
        }
        else if (actionBuffers.DiscreteActions[1] == 2)
        {
            turnAmount = 1f;
        }

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        transform.Rotate(transform.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        if (MaxStep > 0) AddReward(-1f / MaxStep);
    }

The Heuristic() function allows control of the agent without a neural network. This function will read inputs from the human player via the keyboard, convert them into actions, and place those actions into an array called DiscreteActions. This same array is what is read in the OnActionReceived function when a human is playing (rather than an AI). In our project:

  • The default forwardAction will be 0, but if the player presses 'W' on the keyboard, this value will be set to 1.
  • The default turnAction will be 0, but if the player presses 'A' or 'D' on the keyboard, the value will be set to 1 or 2 respectively to turn left or right.

  • Override the Heuristic() function.

/// <summary>
    /// Read inputs from the keyboard and convert them to a list of actions.
    /// This is called only when the player wants to control the agent and has set
    /// Behavior Type to "Heuristic Only" in the Behavior Parameters inspector.
    /// </summary>
    /// <param name="actionsOut">The action buffers to write the player's actions into</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = 0;
        int turnAction = 0;
        if (Input.GetKey(KeyCode.W))
        {
            // move forward
            forwardAction = 1;
        }
        if (Input.GetKey(KeyCode.A))
        {
            // turn left
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            // turn right
            turnAction = 2;
        }

        // Put the actions into the array
        actionsOut.DiscreteActions.Array[0] = forwardAction;
        actionsOut.DiscreteActions.Array[1] = turnAction;
    }

The base Agent class calls the OnEpisodeBegin() function automatically at the start of each episode, which happens after the agent has fed the baby all of the fish or reached the max number of steps. We will use it to empty the penguin’s belly and reset the area.

  • Override the OnEpisodeBegin() function.
/// <summary>
    /// When a new episode begins, reset the agent and area
    /// </summary>
    public override void OnEpisodeBegin()
    {
        isFull = false;
        penguinArea.ResetArea();
    }

The penguin agent observes the environment in two different ways. The first way is with raycasts. This is like shining a bunch of laser pointers out from the penguin and seeing if they hit anything. It's similar to LIDAR, which is used by autonomous cars and robots. Raycast observations are added via a RayPerceptionSensor component, which we'll add in the Unity Editor later.

The second way the agent observes the environment is with numerical values. Whether it's a true/false value, a distance, an XYZ position in space, or a quaternion rotation, you can convert an observation into a list of numbers and add it as an observation for the agent. Check out the comments in the code to understand what we're adding.

You need to be very thoughtful when choosing what to observe. If the agent doesn't have enough information about its environment, it will not be able to complete its task. Imagine your agent is floating in space, blindfolded. What would it need to be told about its environment to make an intelligent decision?

This penguin agent, as currently implemented, doesn't have any memory. We need to help it out by telling it where things are every update step so that it can make a decision. It’s possible to use memory in ML-Agents, but that’s beyond the scope of this tutorial. You can read more about it in the Memory-enhanced Agents using Recurrent Neural Networks documentation.

  • Override the CollectObservations() function.
/// <summary>
    /// Collect all non-Raycast observations
    /// </summary>
    /// <param name="sensor">The vector sensor to add observations to</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Whether the penguin has eaten a fish (1 float = 1 value)
        sensor.AddObservation(isFull);

        // Distance to the baby (1 float = 1 value)
        sensor.AddObservation(Vector3.Distance(baby.transform.position, transform.position));

        // Direction to baby (1 Vector3 = 3 values)
        sensor.AddObservation((baby.transform.position - transform.position).normalized);

        // Direction penguin is facing (1 Vector3 = 3 values)
        sensor.AddObservation(transform.forward);

        // 1 + 1 + 3 + 3 = 8 total values
    }
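
The eight values can be pictured as one flat array of floats. The sketch below is not ML-Agents code: it assumes System.Numerics.Vector3 in place of Unity's Vector3, and the ObservationSketch class and BuildObservation helper are hypothetical. It only illustrates the order and count of the values, with the bool stored as 1 or 0.

```csharp
using System;
using System.Collections.Generic;
using System.Numerics;

// Hypothetical illustration of the observation layout; not part of the project
public static class ObservationSketch
{
    public static float[] BuildObservation(bool isFull, Vector3 penguinPosition, Vector3 babyPosition, Vector3 penguinForward)
    {
        var values = new List<float>();

        // Whether the penguin has eaten a fish (1 value, stored as 1 or 0)
        values.Add(isFull ? 1f : 0f);

        // Distance to the baby (1 value)
        values.Add(Vector3.Distance(babyPosition, penguinPosition));

        // Direction to the baby (3 values); guard against normalizing a zero vector,
        // which System.Numerics would turn into NaN
        Vector3 toBaby = babyPosition - penguinPosition;
        Vector3 direction = toBaby.Length() > 0f ? Vector3.Normalize(toBaby) : Vector3.Zero;
        values.Add(direction.X);
        values.Add(direction.Y);
        values.Add(direction.Z);

        // Direction the penguin is facing (3 values)
        values.Add(penguinForward.X);
        values.Add(penguinForward.Y);
        values.Add(penguinForward.Z);

        // 1 + 1 + 3 + 3 = 8 total values
        return values.ToArray();
    }
}
```

The count matters because the agent's Behavior Parameters, configured in the Unity Editor, must be told how many vector observation values to expect.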

Next we'll implement OnCollisionEnter() and test for collisions with items that have the tag "fish" or "baby" and respond accordingly.

  • Add a new OnCollisionEnter() function.
/// <summary>
    /// When the agent collides with something, take action
    /// </summary>
    /// <param name="collision">The collision info</param>
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("fish"))
        {
            // Try to eat the fish
            EatFish(collision.gameObject);
        }
        else if (collision.transform.CompareTag("baby"))
        {
            // Try to feed the baby
            RegurgitateFish();
        }
    }

Now we can add a function to eat fish, assuming the penguin doesn't already have a full stomach. It will remove that fish from the area and get a reward.

  • Add a new EatFish() function.
/// <summary>
    /// Check if agent is full, if not, eat the fish and get a reward
    /// </summary>
    /// <param name="fishObject">The fish to eat</param>
    private void EatFish(GameObject fishObject)
    {
        if (isFull) return; // Can't eat another fish while full
        isFull = true;

        penguinArea.RemoveSpecificFish(fishObject);

        AddReward(1f);
    }

Finally, we'll add a function to regurgitate fish and feed the baby. We’ll spawn a regurgitated fish blob on the ground as well as a heart floating in the air to show how much the baby loves its parent for feeding it. Both objects get an auto-destroy timer of four seconds. The agent gets a reward, and if there are no fish remaining, we call EndEpisode(), which ends the current episode and causes OnEpisodeBegin() to be called for the next one.

  • Create a new RegurgitateFish() function.
/// <summary>
    /// Check if agent is full, if yes, feed the baby
    /// </summary>
    private void RegurgitateFish()
    {
        if (!isFull) return; // Nothing to regurgitate
        isFull = false;

        // Spawn regurgitated fish
        GameObject regurgitatedFish = Instantiate<GameObject>(regurgitatedFishPrefab);
        regurgitatedFish.transform.parent = transform.parent;
        regurgitatedFish.transform.position = baby.transform.position;
        Destroy(regurgitatedFish, 4f);

        // Spawn heart
        GameObject heart = Instantiate<GameObject>(heartPrefab);
        heart.transform.parent = transform.parent;
        heart.transform.position = baby.transform.position + Vector3.up;
        Destroy(heart, 4f);

        AddReward(1f);

        if (penguinArea.FishRemaining <= 0)
        {
            EndEpisode();
        }
    }
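
To check the reward logic in isolation, the following sketch models only the isFull flag, the fish count, and the rewards, with no Unity types. The FeedingSketch class and its member names are hypothetical; the fish count passed in would be 4 to match the SpawnFish(4, .5f) call in PenguinArea.

```csharp
using System;

// Hypothetical model of the eat/feed cycle; not part of the project
public class FeedingSketch
{
    public bool IsFull { get; private set; }
    public int FishRemaining { get; private set; }
    public float CumulativeReward { get; private set; }
    public bool EpisodeEnded { get; private set; }

    public FeedingSketch(int fishCount)
    {
        FishRemaining = fishCount;
    }

    // Mirrors EatFish(): ignored while full, otherwise remove a fish and add 1 reward
    public void TouchFish()
    {
        if (IsFull || FishRemaining <= 0) return;
        IsFull = true;
        FishRemaining--;
        CumulativeReward += 1f;
    }

    // Mirrors RegurgitateFish(): ignored while empty, otherwise add 1 reward
    // and end the episode once no fish remain
    public void TouchBaby()
    {
        if (!IsFull) return;
        IsFull = false;
        CumulativeReward += 1f;
        if (FishRemaining <= 0) EpisodeEnded = true;
    }
}
```

Under these assumptions, catching and delivering all four fish yields 4 × 2 = 8 reward before the per-step penalty, so the cumulative reward for a complete episode stays just below 8.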

That's all for the PenguinAgent script!

Fish.cs

The Fish class will attach to each fish and make it swim. Unity doesn’t have water physics built in, so our code just moves them in a straight line toward a target destination to keep things simple.

  • Open Fish.cs.
  • Delete the Start() function.
  • Delete the Update() function.
  • Add several variables as shown.

Here’s an overview of the variables:

  • fishSpeed controls the average speed of the fish.
  • randomizedSpeed is a slightly altered speed that we will change randomly each time a new swim destination is picked.
  • nextActionTime is used to trigger the selection of a new swim destination.
  • targetPosition is the position of the destination the fish is swimming toward.
using UnityEngine;

public class Fish : MonoBehaviour
{
    [Tooltip("The swim speed")]
    public float fishSpeed;

    private float randomizedSpeed = 0f;
    private float nextActionTime = -1f;
    private Vector3 targetPosition;
}

FixedUpdate() is called at a fixed interval (every 0.02 seconds of game time by default) that is independent of frame rate. This keeps the fish moving consistently even when the agent is training at an increased game speed, which is common for training ML-Agents. In it, we check if the fish should swim and, if so, call the Swim() function.

  • Add a new FixedUpdate() function.
/// <summary>
    /// Called every timestep
    /// </summary>
    private void FixedUpdate()
    {
        if (fishSpeed > 0f)
        {
            Swim();
        }
    }

Next, we’ll add swim functionality. At any given update, the fish will either pick a new speed and destination, or move toward its current destination.

When it is time to take a new action, the fish will:

  • Choose a new randomized speed between 50% and 150% of the average fish speed.
  • Pick a new random target position (in the water) to swim toward.
  • Rotate the fish to face the target.
  • Calculate the time needed to get there.

Otherwise, the fish will move toward the target and make sure it doesn't swim past it.
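
The two branches described above can be sketched without Unity as follows. This assumes System.Numerics.Vector3 in place of Unity's Vector3; the SwimSketch class and its parameters are hypothetical, and the caller supplies the time step (0.02 seconds by default in Unity's FixedUpdate). The sketch computes the travel direction from the two positions, whereas the real script rotates the fish to face the target and then moves along transform.forward.

```csharp
using System;
using System.Numerics;

// Hypothetical illustration of the swim logic; not part of the project
public static class SwimSketch
{
    // Time needed to reach the target at a constant speed
    public static float TimeToGetThere(Vector3 position, Vector3 target, float speed)
    {
        return Vector3.Distance(position, target) / speed;
    }

    // Advance one time step toward the target without swimming past it
    public static Vector3 Step(Vector3 position, Vector3 target, float speed, float deltaTime)
    {
        float remaining = Vector3.Distance(position, target);
        float stepLength = speed * deltaTime;

        // Snap to the target when this step would reach or pass it
        if (stepLength >= remaining)
        {
            return target;
        }

        Vector3 direction = (target - position) / remaining;
        return position + direction * stepLength;
    }
}
```

Snapping to the target in the final step is what lets the real script set nextActionTime to the current time and pick a new destination on the next update.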

  • Add a new Swim() function.

/// <summary>
    /// Swim between random positions
    /// </summary>
    private void Swim()
    {
        // If it's time for the next action, pick a new speed and destination
        // Else, swim toward the destination
        if (Time.fixedTime >= nextActionTime)
        {
            // Randomize the speed
            randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);

            // Pick a random target
            targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

            // Rotate toward the target
            transform.rotation = Quaternion.LookRotation(targetPosition - transform.position, Vector3.up);

            // Calculate the time to get there
            float timeToGetThere = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
            nextActionTime = Time.fixedTime + timeToGetThere;
        }
        else
        {
            // Make sure that the fish does not swim past the target
            Vector3 moveVector = randomizedSpeed * transform.forward * Time.fixedDeltaTime;
            if (moveVector.magnitude <= Vector3.Distance(transform.position, targetPosition))
            {
                transform.position += moveVector;
            }
            else
            {
                transform.position = targetPosition;
                nextActionTime = Time.fixedTime;
            }
        }
    }

That’s all for the Fish script!

Conclusion

You should now have all of the code you need to train the penguins to catch fish and feed their babies. In the next tutorial, you will set up your Scene to use this code.

Tutorial Parts

Reinforcement Learning Penguins (Part 1/4)
Reinforcement Learning Penguins (Part 2/4)
Reinforcement Learning Penguins (Part 3/4)
Reinforcement Learning Penguins (Part 4/4)