Unity/게임 인공지능 프로그래밍(ML-Agent)

[ML-Agents] DinoRun (이미테이션 러닝)

치명적흑형 2021. 11. 16. 16:38

학습 목표

Dino가 점프를 해서 선인장을 피한다.

 

훈련 내용

선인장 위에 있는 동전을 먹으면 리워드를 준다.

Dino가 선인장에 닿으면 감점을 하고 에피소드를 종료한다.

바닥에 닿아 있으면 소량의 리워드를 준다.

 

Demonstration Recorder 컴포넌트로 사람이 직접 플레이한 시연(.demo 파일)을 녹화하고, 이를 이용해 모방 학습(GAIL, Behavioral Cloning)을 한다.

 




 

 


using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using UnityEngine.Events;

public class DinoAgent : Agent
{
    // Cached physics body used for jumping and for the vertical-velocity observation.
    private Rigidbody2D rBody;

    // Upward force applied by Jump(); tunable in the Inspector.
    [SerializeField]
    private float jumpForce = 300f;

    // Invoked when the agent hits a cactus (TrainingArea uses it to stop spawning and clear cacti).
    public UnityAction dieAction;

    // Invoked at the start of every episode (TrainingArea uses it to restart cactus spawning).
    public UnityAction episodeBeginAction;

    /// <summary>
    /// One-time setup. ML-Agents calls Initialize() once when the agent is first enabled,
    /// so the Rigidbody2D is cached before any episode logic runs.
    /// </summary>
    public override void Initialize()
    {
        this.rBody = this.GetComponent<Rigidbody2D>();
    }

    /// <summary>
    /// Puts the dino back at its start position, clears any leftover motion and
    /// notifies listeners that a new episode has begun.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        this.transform.localPosition = new Vector3(4.5f, 2f, 0);

        // The agent may have died mid-jump; without this the previous episode's
        // velocity would leak into the new one.
        this.rBody.velocity = Vector2.zero;
        this.rBody.angularVelocity = 0f;

        this.episodeBeginAction?.Invoke();
    }

    /// <summary>
    /// Observes the dino's own height and vertical speed (2 floats).
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.transform.localPosition.y);
        sensor.AddObservation(this.rBody.velocity.y);
    }

    /// <summary>
    /// Discrete branch 0: value 1 requests a jump; any other value does nothing.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var wantsJump = actions.DiscreteActions[0] == 1;

        // A vertical velocity of exactly zero is used as a cheap "standing still on
        // the ground" test so the dino cannot jump again while airborne.
        if (wantsJump && this.rBody.velocity.y == 0)
        {
            this.Jump();
        }
    }

    private void Jump()
    {
        this.rBody.AddForce(Vector2.up * this.jumpForce);
    }

    /// <summary>
    /// Hitting a cactus: remove it, apply the -2 penalty, notify listeners and end the episode.
    /// </summary>
    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (!collision.collider.CompareTag("Cactus"))
        {
            return;
        }

        Destroy(collision.gameObject);

        this.AddReward(-2f);
        this.dieAction?.Invoke();
        this.EndEpisode();
    }

    /// <summary>
    /// Collecting a coin (trigger tagged "Reward"): remove it and give +1.
    /// </summary>
    private void OnTriggerEnter2D(Collider2D collision)
    {
        // CompareTag matches the other handlers and avoids allocating the tag string.
        if (collision.CompareTag("Reward"))
        {
            Destroy(collision.gameObject);

            this.AddReward(1f);
        }
    }

    /// <summary>
    /// Small reward for every collision-stay callback while touching the ground.
    /// </summary>
    private void OnCollisionStay2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Ground"))
        {
            this.AddReward(0.001f);
        }
    }

    /// <summary>
    /// Manual control for demonstrations: hold Space to jump.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        // Always write the value so a previous "jump" can never linger in the buffer.
        action[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }
}
보상과 감점

if (collision.CompareTag("Reward"))
        {
            // Coin collected: remove it and give +1.
            Destroy(collision.gameObject);

            this.AddReward(1f);
        }

----------------------------------------------------------

if (collision.collider.CompareTag("Ground"))
        {
            // Small reward while the dino is touching the ground.
            AddReward(0.001f);
        }

----------------------------------------------------------

if (collision.collider.CompareTag("Cactus"))
        {
            Destroy(collision.gameObject);

            // Fixed penalty for hitting a cactus, then end the episode.
            this.AddReward(-2f);
            this.dieAction?.Invoke();
            this.EndEpisode();
        }

 

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Spawns cacti for one DinoAgent on a fixed interval and clears them when the agent dies.
/// </summary>
public class TrainingArea : MonoBehaviour
{
    [SerializeField]
    private GameObject cactusPrefab;
    // Seconds between cactus spawns.
    [SerializeField]
    private float spawnTime = 3f;
    // Parent transform for spawned cacti.
    [SerializeField]
    private Transform targetParent;
    [SerializeField]
    private DinoAgent agent;
    // Position, local to targetParent, at which each new cactus appears.
    [SerializeField]
    private Vector3 spawnLocalPosition = new Vector3(12f, 0, 0);
    private Coroutine routine;
    private List<GameObject> cactusGoList = new List<GameObject>();

    void Start()
    {
        this.agent.dieAction = () => {
            // Stop spawning and remove every cactus still on screen.
            this.StopSpawning();
            this.ClearCactusList();
        };
        this.agent.episodeBeginAction = () => {
            // Restart spawning for the new episode.
            this.StopSpawning();
            this.routine = StartCoroutine(this.GenerateCactus());
        };

        // Start spawning right away on scene start.
        this.routine = StartCoroutine(this.GenerateCactus());
    }

    /// <summary>
    /// Stops the spawn coroutine if it is running and forgets its handle.
    /// </summary>
    private void StopSpawning()
    {
        if (this.routine != null)
        {
            StopCoroutine(this.routine);
            this.routine = null;
        }
    }

    /// <summary>
    /// Destroys every tracked cactus that still exists and empties the list.
    /// </summary>
    private void ClearCactusList()
    {
        foreach (var go in this.cactusGoList)
        {
            // Some cacti were already destroyed elsewhere (off-screen in CactusMove,
            // or on contact by the agent); Unity's == null reports those as null.
            if (go != null)
            {
                Destroy(go);
            }
        }
        this.cactusGoList.Clear();
    }

    /// <summary>
    /// Endless loop: wait spawnTime seconds, then spawn one cactus under targetParent.
    /// </summary>
    private IEnumerator GenerateCactus()
    {
        while (true)
        {
            yield return new WaitForSeconds(this.spawnTime);

            // Drop references to cacti destroyed elsewhere so the list does not
            // grow without bound during a long episode.
            this.cactusGoList.RemoveAll(c => c == null);

            var go = Instantiate(this.cactusPrefab, this.targetParent);
            go.transform.localPosition = this.spawnLocalPosition;
            this.cactusGoList.Add(go);
        }
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Moves a cactus to the left every frame and destroys it once it passes despawnX.
/// </summary>
public class CactusMove : MonoBehaviour
{
    // Horizontal speed in units per second.
    [SerializeField]
    private float moveSpeed = 5f;

    // World-space x coordinate past which the cactus removes itself.
    // NOTE(review): TrainingArea spawns cacti at a *local* position under targetParent;
    // if that parent is not at world x = 0 this threshold may need adjusting — confirm.
    [SerializeField]
    private float despawnX = -3f;

    void Update()
    {
        this.transform.Translate(Vector2.left * (this.moveSpeed * Time.deltaTime));
        if (this.transform.position.x < this.despawnX)
        {
            Destroy(this.gameObject);
        }
    }
}

# ML-Agents trainer configuration: PPO plus imitation learning from a recorded demonstration.
behaviors:
  # Presumably must match the Behavior Name on the agent's Behavior Parameters — confirm in the scene.
  MLDino:
    trainer_type: ppo
    hyperparameters:
      batch_size: 64
      buffer_size: 12000
      learning_rate: 0.0003
      beta: 0.001
      epsilon: 0.2
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      # Environment reward (the AddReward calls in the agent script), full weight.
      extrinsic:
        gamma: 0.99
        strength: 1.0
      # GAIL: extra reward for behaving like the demonstration, weighted lower than extrinsic.
      gail:
        strength: 0.1
        demo_path: ./demos/DinoRun.demo

    # Behavioral cloning: imitates the demonstration's actions directly, using the same .demo file as GAIL.
    behavioral_cloning:
      demo_path: ./demos/DinoRun.demo
      strength: 0.5

    keep_checkpoints: 5
    max_steps: 500000
    time_horizon: 1000
    summary_freq: 12000
이미테이션 러닝

    # Imitation-learning part of the MLDino trainer config.
    reward_signals:
      # Environment reward from the agent script, full weight.
      extrinsic:
        gamma: 0.99
        strength: 1.0
      # GAIL: extra reward for acting like the recorded demonstration.
      gail:
        strength: 0.1
        demo_path: ./demos/DinoRun.demo

    # Behavioral cloning: directly imitates the demonstration's actions.
    behavioral_cloning:
      demo_path: ./demos/DinoRun.demo
      strength: 0.5

 

 


이미테이션 러닝 없이 더 좋은 학습 방법

 

고정된(절대적인) 보상 대신, 그 에피소드에서 얻은 점수에 비례하는(상대적인) 보상과 감점을 준다.

 

바닥에 붙어 있는 동안 계속 리워드(0.03)를 준다.

 

리워드와 동일한 수치를 score 변수에 누적한다.

 

선인장에 닿으면 그 에피소드에서 누적된 스코어의 절반만큼 감점을 주고 에피소드를 종료한다.

 


 


 

동전을 먹으면 동전은 사라지지만 리워드는 주지 않는다.


using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using UnityEngine.Events;

public class DinoAgent : Agent
{
    // Cached physics body used for jumping and for the vertical-velocity observation.
    private Rigidbody2D rBody;

    // Upward force applied by Jump(); tunable in the Inspector.
    [SerializeField]
    private float jumpForce = 300f;

    // Invoked when the agent hits a cactus (TrainingArea uses it to stop spawning and clear cacti).
    public UnityAction dieAction;

    // Invoked at the start of every episode (TrainingArea uses it to restart cactus spawning).
    public UnityAction episodeBeginAction;

    // Sum of the ground-contact rewards earned in the current episode.
    // Hitting a cactus costs half of it, making the penalty relative to progress.
    private float score;

    /// <summary>
    /// One-time setup. ML-Agents calls Initialize() once when the agent is first enabled,
    /// so the Rigidbody2D is cached before any episode logic runs.
    /// </summary>
    public override void Initialize()
    {
        this.rBody = this.GetComponent<Rigidbody2D>();
    }

    /// <summary>
    /// Resets position, motion and the per-episode score, then notifies listeners.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        this.transform.localPosition = new Vector3(4.5f, 2f, 0);

        // The agent may have died mid-jump; don't carry that motion into the new episode.
        this.rBody.velocity = Vector2.zero;
        this.rBody.angularVelocity = 0f;

        // Reset here rather than only on death, so episodes that end for any other
        // reason (e.g. reaching MaxStep) also start from zero.
        this.score = 0f;

        this.episodeBeginAction?.Invoke();
    }

    /// <summary>
    /// Observes the dino's own height and vertical speed (2 floats).
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.transform.localPosition.y);
        sensor.AddObservation(this.rBody.velocity.y);
    }

    /// <summary>
    /// Discrete branch 0: value 1 requests a jump; any other value does nothing.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var wantsJump = actions.DiscreteActions[0] == 1;

        // A vertical velocity of exactly zero is used as a cheap "standing still on
        // the ground" test so the dino cannot jump again while airborne.
        if (wantsJump && this.rBody.velocity.y == 0)
        {
            this.Jump();
        }
    }

    private void Jump()
    {
        this.rBody.AddForce(Vector2.up * this.jumpForce);
    }

    /// <summary>
    /// Hitting a cactus: remove it, subtract half of this episode's score,
    /// notify listeners and end the episode.
    /// </summary>
    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (!collision.collider.CompareTag("Cactus"))
        {
            return;
        }

        Destroy(collision.gameObject);

        // Must read score before it is reset; resetting first made the penalty always 0.
        this.AddReward(-this.score / 2);
        this.dieAction?.Invoke();
        // The next OnEpisodeBegin resets score.
        this.EndEpisode();
    }

    /// <summary>
    /// Coins (triggers tagged "Reward") are simply removed; they give no reward in this version.
    /// </summary>
    private void OnTriggerEnter2D(Collider2D collision)
    {
        if (collision.CompareTag("Reward"))
        {
            Destroy(collision.gameObject);
        }
    }

    /// <summary>
    /// While touching the ground: add 0.03 to both the reward and the tracked score.
    /// </summary>
    private void OnCollisionStay2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Ground"))
        {
            this.score += 0.03f;
            this.AddReward(0.03f);
        }
    }

    /// <summary>
    /// Manual control: hold Space to jump.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        // Always write the value so a previous "jump" can never linger in the buffer.
        action[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }
}
if (collision.collider.CompareTag("Ground"))
        {
            // Track the same amount in score so the cactus penalty can be relative to it.
            score += 0.03f;
            AddReward(0.03f);
        }

------------------------------------------------

if (collision.collider.CompareTag("Cactus"))
        {
            Destroy(collision.gameObject);

            // Penalize by half of the accumulated score *before* clearing it;
            // clearing first would make the penalty always 0.
            this.AddReward(-this.score / 2);
            this.score = 0;
            this.dieAction?.Invoke();
            this.EndEpisode();
        }