학습 목표
Dino가 점프를 해서 선인장을 피한다
훈련 내용
선인장 위에 있는 동전을 먹으면 리워드를 준다.
Dino가 선인장에 닿으면 감점을 하고 에피소드를 종료한다.
바닥에 닿아 있으면 소량의 리워드를 준다.
Demonstration Recorder 컴포넌트를 이용하여 모방 학습을 한다.



using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using UnityEngine.Events;
public class DinoAgent : Agent
{
private Rigidbody2D rBody;
[SerializeField]
private float jumpForce = 300f;
public UnityAction dieAction;
public UnityAction episodeBeginAction;
void Start()
{
this.rBody = this.GetComponent<Rigidbody2D>();
}
public override void OnEpisodeBegin()
{
this.transform.localPosition = new Vector3(4.5f, 2f, 0);
this.episodeBeginAction();
}
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(this.transform.localPosition.y);
sensor.AddObservation(this.rBody.velocity.y);
}
public override void OnActionReceived(ActionBuffers actions)
{
var action = actions.DiscreteActions;
if (action[0] == 1)
{
if (this.rBody.velocity.y == 0)
{
this.Jump();
}
}
}
private void Jump()
{
this.rBody.AddForce(Vector2.up * this.jumpForce);
}
private void OnCollisionEnter2D(Collision2D collision)
{
if (collision.collider.CompareTag("Cactus"))
{
Destroy(collision.gameObject);
this.AddReward(-2f);
this.dieAction();
this.EndEpisode();
}
}
private void OnTriggerEnter2D(Collider2D collision)
{
if (collision.tag == "Reward")
{
Destroy(collision.gameObject);
this.AddReward(1f);
}
}
private void OnCollisionStay2D(Collision2D collision)
{
if (collision.collider.CompareTag("Ground"))
{
this.AddReward(0.001f);
}
}
public override void Heuristic(in ActionBuffers actionsOut)
{
if (Input.GetKey(KeyCode.Space))
{
var action = actionsOut.DiscreteActions;
action[0] = 1;
}
}
}
| 보상과 감점 if (collision.tag == "Reward") { Destroy(collision.gameObject); this.AddReward(1f); } ---------------------------------------------------------- if (collision.collider.CompareTag("Ground")) { this.AddReward(0.001f); } ---------------------------------------------------------- if (collision.collider.CompareTag("Cactus")) { Destroy(collision.gameObject); this.AddReward(-2f); this.dieAction(); this.EndEpisode(); } |
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Spawns cacti on a timer and wires episode lifecycle callbacks to the agent:
/// on death the spawner stops and all live cacti are removed; on episode begin
/// spawning restarts.
/// </summary>
public class TrainingArea : MonoBehaviour
{
    [SerializeField]
    private GameObject cactusPrefab;   // prefab spawned at the right edge
    [SerializeField]
    private float spawnTime = 3f;      // seconds between spawns
    [SerializeField]
    private Transform targetParent;    // parent transform for spawned cacti
    [SerializeField]
    private DinoAgent agent;

    private Coroutine routine;
    private List<GameObject> cactusGoList = new List<GameObject>();

    void Start()
    {
        this.agent.dieAction = () =>
        {
            // Stop spawning.
            this.StopGenerating();
            // Remove every cactus still alive in the scene.
            this.ClearCactusList();
        };
        this.agent.episodeBeginAction = () =>
        {
            // Restart spawning from a clean coroutine.
            this.StopGenerating();
            this.routine = StartCoroutine(this.GenerateCactus());
        };
        // Start spawning cacti on scene start.
        this.routine = StartCoroutine(this.GenerateCactus());
    }

    // Stops the spawn coroutine (if running) and clears the handle so it is
    // never stopped twice or left dangling.
    private void StopGenerating()
    {
        if (this.routine != null)
        {
            StopCoroutine(this.routine);
            this.routine = null;
        }
    }

    private void ClearCactusList()
    {
        foreach (var go in this.cactusGoList)
        {
            // Cacti destroy themselves once off-screen (CactusMove), leaving
            // stale entries here — skip those instead of destroying them again.
            if (go != null)
            {
                Destroy(go);
            }
        }
        // Empty the collection so destroyed references do not accumulate.
        this.cactusGoList.Clear();
    }

    // Infinite producer: one cactus every spawnTime seconds at the right edge.
    private IEnumerator GenerateCactus()
    {
        while (true)
        {
            yield return new WaitForSeconds(this.spawnTime);
            var go = Instantiate(this.cactusPrefab, targetParent);
            go.transform.localPosition = new Vector3(12f, 0, 0);
            // Prune self-destroyed cacti so the list does not grow without bound.
            this.cactusGoList.RemoveAll(x => x == null);
            this.cactusGoList.Add(go);
        }
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Moves a cactus leftward at a constant speed and destroys it once it
/// leaves the play area on the left side.
/// </summary>
public class CactusMove : MonoBehaviour
{
    [SerializeField]
    private float moveSpeed = 5f;  // units per second, moving left

    void Update()
    {
        // Frame-rate-independent leftward slide.
        var step = this.moveSpeed * Time.deltaTime;
        this.transform.Translate(Vector2.left * step);

        // Self-destruct once past the left edge of the play area.
        var isOffscreen = this.transform.position.x < -3;
        if (isOffscreen)
        {
            Destroy(this.gameObject);
        }
    }
}
# ML-Agents PPO training configuration for the Dino behavior.
# NOTE: the pasted original lost all indentation, which makes it invalid YAML;
# restored to the canonical ML-Agents config layout (behavioral_cloning is a
# trainer-level key, not a reward signal).
behaviors:
  MLDino:
    trainer_type: ppo
    hyperparameters:
      batch_size: 64
      buffer_size: 12000
      learning_rate: 0.0003
      beta: 0.001
      epsilon: 0.2
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      # GAIL imitation signal driven by the recorded demonstration.
      gail:
        strength: 0.1
        demo_path: ./demos/DinoRun.demo
    # Behavioral cloning pre-trains the policy on the same demo.
    behavioral_cloning:
      demo_path: ./demos/DinoRun.demo
      strength: 0.5
    keep_checkpoints: 5
    max_steps: 500000
    time_horizon: 1000
    summary_freq: 12000
| 이미테이션 러닝 reward_signals: extrinsic: gamma: 0.99 strength: 1.0 gail: strength: 0.1 demo_path: ./demos/DinoRun.demo behavioral_cloning: demo_path: ./demos/DinoRun.demo strength: 0.5 |
이미테이션 러닝 없이 더 좋은 학습 방법
절대적인 보상이 아닌 상대적인 보상을 준다
바닥에 붙어있는 동안 리워드를 준다
리워드와 동일한 수치를 스코어 변수에 담는다
선인장에 닿을 경우 스코어의 반절만큼 감점을 준다.

동전에는 리워드가 없다.
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using UnityEngine.Events;
/// <summary>
/// Relative-reward variant of the Dino agent: accumulates a score while
/// grounded and, on a cactus hit, is penalized by half of that score
/// (coins give no reward in this version).
/// </summary>
public class DinoAgent : Agent
{
    // Cached physics body used for jumping and for velocity observations.
    private Rigidbody2D rBody;

    [SerializeField]
    private float jumpForce = 300f;  // upward impulse applied on jump

    // Running total of ground-contact reward earned this episode; the cactus
    // penalty is half of this value.
    private float score;

    // Raised when the agent hits a cactus (episode failure).
    public UnityAction dieAction;
    // Raised at the start of every training episode.
    public UnityAction episodeBeginAction;

    void Start()
    {
        this.rBody = this.GetComponent<Rigidbody2D>();
    }

    /// <summary>Resets the agent to its spawn point and notifies listeners.</summary>
    public override void OnEpisodeBegin()
    {
        this.transform.localPosition = new Vector3(4.5f, 2f, 0);
        // Null-conditional invoke: avoids NullReferenceException if no listener was wired up.
        this.episodeBeginAction?.Invoke();
    }

    /// <summary>Observes vertical position and vertical velocity (2 floats).</summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.transform.localPosition.y);
        sensor.AddObservation(this.rBody.velocity.y);
    }

    /// <summary>Discrete action 1 = jump; only allowed while not already airborne.</summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        // Mathf.Approximately instead of ==: exact float equality on a physics velocity is fragile.
        if (action[0] == 1 && Mathf.Approximately(this.rBody.velocity.y, 0f))
        {
            this.Jump();
        }
    }

    private void Jump()
    {
        this.rBody.AddForce(Vector2.up * this.jumpForce);
    }

    // Hitting a cactus: penalize by half the accumulated score, then end the episode.
    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Cactus"))
        {
            Destroy(collision.gameObject);
            // BUG FIX: the original zeroed `score` BEFORE computing the penalty,
            // so AddReward(-score/2) was always 0. Apply the penalty first,
            // then reset the score for the next episode.
            this.AddReward(-this.score / 2f);
            this.score = 0f;
            this.dieAction?.Invoke();
            this.EndEpisode();
        }
    }

    // Coins are consumed but give no reward in this variant (by design).
    private void OnTriggerEnter2D(Collider2D collision)
    {
        // CompareTag is allocation-free and errors on undefined tags, unlike `tag ==`.
        if (collision.CompareTag("Reward"))
        {
            Destroy(collision.gameObject);
        }
    }

    // Reward every physics step spent on the ground and track the total so the
    // cactus penalty scales with how long the agent survived.
    private void OnCollisionStay2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Ground"))
        {
            this.score += 0.03f;
            this.AddReward(0.03f);
        }
    }

    /// <summary>Keyboard control for demo recording: Space = jump.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        // Write both states explicitly; the action buffer is not guaranteed to be cleared
        // between decisions, so leaving it untouched could replay a stale jump.
        action[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }
}
| if (collision.collider.CompareTag("Ground")) { this.score += 0.03f; this.AddReward(0.03f); } ------------------------------------------------ if (collision.collider.CompareTag("Cactus")) { Destroy(collision.gameObject); this.score = 0; this.AddReward(-this.score/2); this.EndEpisode(); } |
'Unity > 게임 인공지능 프로그래밍(ML-Agent)' 카테고리의 다른 글
| [펌]Understanding PPO Plots in TensorBoard (0) | 2021.11.16 |
|---|---|
| [ML-Agents] 펭귄 강화학습 (Ray를 사용한 강화학습)(진행중) (0) | 2021.11.10 |
| [ML-Agents] RollerBall (TensorBoard 사용법) (0) | 2021.11.10 |
| [ML-Agents] 3DBall (0) | 2021.11.08 |
| 머신러닝 준비하기 : ML-Agents (0) | 2021.11.08 |