해결된 질문
작성
·
277
0
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using PA_DronePack;
/// <summary>
/// ML-Agents agent that steers a drone toward a goal object.
/// Observations are the agent-to-goal offset, the rigidbody velocity and the
/// rigidbody angular velocity; the three continuous actions are forwarded to
/// the drone controller as drive, strafe and lift inputs.
/// </summary>
public class DroneAgent : Agent
{
    // Drone controller from PA_DronePack; resolved from this GameObject in Initialize().
    public PA_DroneController dcoScript;

    // Environment helper whose AreaSetting() is called at the start of every episode.
    public DroneSetting area;

    // Target object the drone has to reach (must be assigned before Initialize() runs).
    public GameObject goal;

    // Distance to the goal measured on the previous step; used for the shaping reward.
    float preDist;

    private Transform agentTrans;
    private Transform goalTrans;
    private Rigidbody agent_Rigidbody;

    // Seconds between decision requests when no trainer/communicator is connected.
    public float DecisionWaitingTime = 5f;

    // Time accumulated since the last decision request (only used without a communicator).
    float m_currentTime = 0f;

    /// <summary>
    /// Caches the components used every step and hooks the decision scheduler
    /// into the Academy's pre-step event.
    /// </summary>
    public override void Initialize()
    {
        base.Initialize();

        dcoScript = gameObject.GetComponent<PA_DroneController>();
        agentTrans = gameObject.transform;
        goalTrans = goal.transform;
        agent_Rigidbody = gameObject.GetComponent<Rigidbody>();

        Academy.Instance.AgentPreStep += WaitTimeInterference;
    }

    /// <summary>
    /// Removes the pre-step subscription made in <see cref="Initialize"/> so the
    /// Academy does not keep invoking a destroyed agent.
    /// </summary>
    private void OnDestroy()
    {
        // Touching Academy.Instance after disposal would create a new Academy,
        // so only unsubscribe while one is still alive.
        if (Academy.IsInitialized)
        {
            Academy.Instance.AgentPreStep -= WaitTimeInterference;
        }
    }

    /// <summary>
    /// Adds three vector observations: agent position minus goal position,
    /// linear velocity and angular velocity.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Distance vector (agent position - goal position)
        sensor.AddObservation(agentTrans.position - goalTrans.position);
        // Velocity vector
        sensor.AddObservation(agent_Rigidbody.velocity);
        // Angular velocity vector
        sensor.AddObservation(agent_Rigidbody.angularVelocity);
    }

    /// <summary>
    /// Applies the three continuous actions to the drone controller and assigns
    /// the reward for this step.
    /// </summary>
    /// <param name="actionBuffers">Actions chosen for this step; continuous
    /// indices 0, 1 and 2 are used.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Small per-step penalty.
        AddReward(-0.01f);

        var actions = actionBuffers.ContinuousActions;
        float moveX = Mathf.Clamp(actions[0], -1, 1f);
        float moveY = Mathf.Clamp(actions[1], -1, 1f);
        float moveZ = Mathf.Clamp(actions[2], -1, 1f);

        dcoScript.DriveInput(moveX);
        dcoScript.StrafeInput(moveY);
        dcoScript.LiftInput(moveZ);

        float distance = Vector3.Magnitude(goalTrans.position - agentTrans.position);

        if (distance <= 0.5f)
        {
            // Within 0.5 units of the goal: final reward and end of episode.
            SetReward(1f);
            EndEpisode();
        }
        else if (distance > 10f)
        {
            // More than 10 units away from the goal: penalty and end of episode.
            SetReward(-1f);
            EndEpisode();
        }
        else
        {
            // Shaping reward: positive when the drone got closer since the last step.
            float reward = preDist - distance;
            AddReward(reward);
            preDist = distance;
        }
    }

    /// <summary>
    /// Calls <c>area.AreaSetting()</c> and re-measures the starting distance to
    /// the goal so the first shaping reward of the episode is computed against it.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        // NOTE(review): AreaSetting() presumably repositions the agent and goal - confirm in DroneSetting.
        area.AreaSetting();
        preDist = Vector3.Magnitude(goalTrans.position - agentTrans.position);
    }

    /// <summary>
    /// Manual control: fills the continuous actions from the "Vertical",
    /// "Horizontal" and "Mouse ScrollWheel" input axes (indices 0, 1, 2).
    /// </summary>
    /// <param name="actionsOut">Action buffer to fill.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        continuousActionsOut[0] = Input.GetAxis("Vertical");
        continuousActionsOut[1] = Input.GetAxis("Horizontal");
        continuousActionsOut[2] = Input.GetAxis("Mouse ScrollWheel");
    }

    /// <summary>
    /// Handler for the Academy's <c>AgentPreStep</c> event. Requests a decision on
    /// every step while a communicator is connected; otherwise requests one each
    /// time <see cref="DecisionWaitingTime"/> seconds have accumulated.
    /// </summary>
    /// <param name="action">Value supplied by the AgentPreStep event; not used here.</param>
    public void WaitTimeInterference(int action)
    {
        if (Academy.Instance.IsCommunicatorOn)
        {
            RequestDecision();
            return;
        }

        if (m_currentTime >= DecisionWaitingTime)
        {
            m_currentTime = 0f;
            RequestDecision();
        }
        else
        {
            // Accumulate elapsed time using Unity's fixed timestep.
            m_currentTime += Time.fixedDeltaTime;
        }
    }
}
아, 제가 코드를 작성하던 부분에서 실수가 있었던 것 같습니다. GitHub에 있는 코드를 그대로 복사해 붙여 넣으니 오류가 해결되었습니다!