해결된 질문
작성
·
101
0
선생님
#3. 정책 개선
#policy_stable <- true
policy_stable = True
old_pi = pi
#각 s에 대해:
for s in range(num_states):
# pi_s <- argmax_a(sum(p(s',r|s,a)*[r + gamma*V(s')]))
new_action_values = np.zeros(num_actions)
for a in range(num_actions):
for prob, s_, r, _ in transitions[s][a]:
new_action_values[a] += prob * (r + GAMMA * V[s_])
new_action = np.argmax(new_action_values)
pi[s] = np.eye(num_actions)[new_action]
if old_pi.all() != pi.all():
policy_stable = False
여기 최적 정책을 업데이트 하는 과정을 디버깅 하는 과정에서
old.pi = pi 에서 메모리를 공유 해서
pi[s]를 업데이트를 같이 해서 무조건 아래서
if old_pi.all() != pi.all():
구문은 True가 나오게 되어 있는데
의도한 바는
old_pi = copy.deepcopy(pi)
로 코드를 바꾸는게 맞나요?
답변 1
1
코드에 버그가 있었네요. 그래도 결과가 수렴이 되어서 모르고 지나쳤습니다 ^^
다음 내용 수정 했습니다.
GAMMA = 1.0 --> 0.9 로 수정
old_pi = copy.deepcopy(pi) 로 수정
비교문을
if np.array_equal(old_pi, pi): policy_stable = True else: policy_stable = False
로 수정
수정된 코드 github 에 update 해 놓았습니다. 좋은 질문 감사합니다.
확인 감사합니다!