zoukankan      html  css  js  c++  java
  • 井字棋小游戏 AI(蒙特卡洛树搜索)

    刚把《强化学习》的第一部分写完,突发奇想,想写一个井字棋小游戏 AI。算法采用 MCTS(蒙特卡洛树搜索):树内策略(tree policy)使用 UCT,树外的模拟策略(rollout/default policy)为等概率随机落子。

    代码:

    #include <bits/stdc++.h>
    using namespace std;
    // Size of the per-state tables; must exceed 3^9 = 19683, the number of
    // base-3 board codes (9 cells, each digit 0 = empty, 1 = 'b', 2 = 'w').
    const int maxn = 20010;
    // Exploration weight used by UCT(): it scales the sqrt(log(parent)/child)
    // bonus, so larger values favour less-visited children.
    double UCT_C = 2.0;
    // Accumulated MCTS statistics for one board state.
    struct node {
    	double x, y;	// x: reward collected, y: visit weight (both updated by MCTS back-propagation)

    	// Mean reward x / y; 0/0 yields NaN while the state is still unvisited.
    	double to_double(void) {
    		double mean = x / y;
    		return mean;
    	}

    	// Zero both accumulators.
    	void init() {
    		x = y = 0;
    	}
    };
    // Search statistics per board code, updated by MCTS(): x = accumulated
    // reward, y = accumulated visit weight.
    node V[maxn];
    // NOTE(review): eps appears to be unused in this file.
    double eps = 1e-10;
    // Legal successor codes of each non-terminal board, filled by init().
    vector<int> Next[maxn];
    // Successors already expanded into the search tree. MCTS() appends them in
    // Next[] order, so Tree[s] is always a prefix of Next[s].
    vector<int> Tree[maxn];
    // ed[s] is true when board s is terminal (a side has a line, or no cell is empty).
    bool ed[maxn];
    // Scratch board written by build() and read by rbuild(): 0 = empty,
    // 'b' = black, 'w' = white. Only rows/columns 0..2 are used.
    char table[5][5];
    // Random engine for the uniform rollouts in dfs(), seeded from the clock.
    // NOTE(review): on glibc, <stdlib.h> declares `long random(void)`, so a
    // global named `random` may fail to compile there; consider renaming it
    // (e.g. `rng`) together with its use in dfs().
    mt19937 random(time(0));
    
    // Count the non-zero digits among the 9 low base-3 digits of `x`, i.e. the
    // number of occupied cells (= moves played) on an encoded board.
    int dep(int x) {
    	int stones = 0;
    	for (int remaining = 9; remaining > 0; --remaining, x /= 3)
    		stones += (x % 3 != 0) ? 1 : 0;
    	return stones;
    }
    // Encode the global `table` as a base-3 integer. Cell (r, c) is digit
    // r*3 + c (least significant first); 0 = empty, 1 = 'b', any other
    // non-empty value = 2.
    int rbuild(void) {
    	int code = 0;
    	for (int cell = 8; cell >= 0; --cell) {
    		const char piece = table[cell / 3][cell % 3];
    		int digit = 2;
    		if (piece == 0)
    			digit = 0;
    		else if (piece == 'b')
    			digit = 1;
    		code = code * 3 + digit;	// Horner's rule, most significant cell first
    	}
    	return code;
    }
    
    // Decode the base-3 board code `st` into the global `table`
    // (digit 0 -> empty, 1 -> 'b', anything else -> 'w'); inverse of rbuild().
    void build(int st) {
    	for (int cell = 0; cell < 9; ++cell) {
    		const int digit = st % 3;
    		char &slot = table[cell / 3][cell % 3];
    		slot = (digit == 0) ? 0 : (digit == 1 ? 'b' : 'w');
    		st /= 3;
    	}
    }
    
    // All board codes reachable from `x` in one move, in cell order 0..8.
    // The side to move is black (digit 1) when an even number of stones is on
    // the board and white (digit 2) otherwise. Overwrites the global `table`.
    vector<int> find_next(int x) {
    	build(x);
    	const int digit = dep(x) % 2 + 1;
    	vector<int> successors;
    	int weight = 1;
    	for (int cell = 0; cell < 9; ++cell, weight *= 3) {
    		if (table[cell / 3][cell % 3] != 0)
    			continue;
    		successors.push_back(x + weight * digit);
    	}
    	return successors;
    }
    
    // True if white ('w') occupies a full row, column or diagonal of board `st`.
    // Decodes `st` into the global `table` as a side effect.
    bool lose(int st) {
    	build(st);
    	static const int kLines[8][3][2] = {
    		{{0, 0}, {0, 1}, {0, 2}}, {{1, 0}, {1, 1}, {1, 2}}, {{2, 0}, {2, 1}, {2, 2}},
    		{{0, 0}, {1, 0}, {2, 0}}, {{0, 1}, {1, 1}, {2, 1}}, {{0, 2}, {1, 2}, {2, 2}},
    		{{0, 0}, {1, 1}, {2, 2}}, {{2, 0}, {1, 1}, {0, 2}},
    	};
    	for (const auto &line : kLines) {
    		bool owned = true;
    		for (const auto &cell : line)
    			owned = owned && table[cell[0]][cell[1]] == 'w';
    		if (owned)
    			return true;
    	}
    	return false;
    }
    
    // True if black ('b') occupies a full row, column or diagonal of board `st`.
    // Decodes `st` into the global `table` as a side effect.
    bool vectory(int st) {
    	build(st);
    	static const int kLines[8][3][2] = {
    		{{0, 0}, {0, 1}, {0, 2}}, {{1, 0}, {1, 1}, {1, 2}}, {{2, 0}, {2, 1}, {2, 2}},
    		{{0, 0}, {1, 0}, {2, 0}}, {{0, 1}, {1, 1}, {2, 1}}, {{0, 2}, {1, 2}, {2, 2}},
    		{{0, 0}, {1, 1}, {2, 2}}, {{2, 0}, {1, 1}, {0, 2}},
    	};
    	for (const auto &line : kLines) {
    		bool owned = true;
    		for (const auto &cell : line)
    			owned = owned && table[cell[0]][cell[1]] == 'b';
    		if (owned)
    			return true;
    	}
    	return false;
    }
    
    // Uniformly random playout from board `x` until a terminal board is reached.
    // Returns 1 if black ('b') has a line, 2 if white ('w') has a line, 3 otherwise (draw).
    int dfs(int x) {
    	int cur = x;
    	while (!ed[cur]) {
    		const vector<int> &moves = Next[cur];
    		cur = moves[random() % moves.size()];
    	}
    	if (vectory(cur))
    		return 1;
    	if (lose(cur))
    		return 2;
    	return 3;
    }
    
    // UCB1-style score of child board `x` given its parent's visit weight `tot`:
    // mean reward plus an exploration bonus scaled by UCT_C.
    double UCT(int x, double tot) {
    	const double mean = V[x].to_double();
    	const double bonus = UCT_C * sqrt(log(tot) / V[x].y);
    	return mean + bonus;
    }
    
    // One MCTS iteration rooted at board `root`.
    //   flag: side to move at `root` (0 = black, 1 = white), as passed by solve().
    // Selection descends through fully expanded nodes by UCT score. Expansion
    // adds the next unexpanded successor (in Next[] order). The playout is a
    // uniform random game (dfs). Back-propagation credits every node on the
    // path from the point of view of the player who moved INTO that node:
    // win = 2, draw = 1, loss = 0 reward, each out of 2 visit weight.
    void MCTS(int root, int flag) {
    	int now = root;
    	stack<int> path;
    	path.push(now);
    	// Selection.
    	while(!ed[now] && Tree[now].size() == Next[now].size()) {
    		// Seed the maximum with a real child rather than board 0, so a NaN or
    		// non-positive score can never send the descent to an unrelated board.
    		int best = Tree[now].front();
    		double best_score = UCT(best, V[now].y);
    		for (auto t : Tree[now]) {
    			const double score = UCT(t, V[now].y);
    			if(score > best_score) {
    				best_score = score;
    				best = t;
    			}
    		}
    		now = best;
    		flag ^= 1;
    		path.push(now);
    	}
    	// Expansion.
    	if(!ed[now]) {
    		int x = Next[now][Tree[now].size()];
    		Tree[now].push_back(x);
    		flag ^= 1;
    		// V is indexed by board state, so x may already carry statistics
    		// gathered through another move order or an earlier search. Those
    		// statistics describe the same position, so they are kept rather than
    		// reset (resetting would also desync V[x] from its expanded children).
    		path.push(x);
    		now = x;
    	}
    	// Simulation: 1 = black wins, 2 = white wins, 3 = draw.
    	int res = dfs(now);
    	// Back-propagation. At a node, flag == 1 means white is to move there,
    	// i.e. black made the move into it, so a black win is credited to it.
    	while(!path.empty()) {
    		now = path.top();
    		path.pop();
    		if(res == 3) {
    			V[now].x += 1;	// draw: half credit, so a winning line beats a drawing one
    		} else if((res == 1 && flag) || (res == 2 && flag == 0)) {
    			V[now].x += 2;	// win for the player who moved into this node
    		}
    		V[now].y += 2;
    		flag ^= 1;
    	}
    }
    
    // Pick the AI's move from board `root`.
    //   flag:       side to move at `root` (false = black, true = white).
    //   iterations: number of MCTS iterations to run first (default 500).
    // Returns the successor board with the highest average reward, or -1 if no
    // successor has been visited yet. Statistics in V persist across calls.
    int solve(int root, bool flag, int iterations = 500) {
    	for (int i = 0; i < iterations; i++) {
    		MCTS(root, flag);
    	}
    	int res = -1;
    	double mx = -1;
    	for (auto x : Next[root]) {
    		if(V[x].y <= 0)
    			continue;	// unvisited: x / y would be NaN
    		const double value = V[x].to_double();
    		if(value > mx) {
    			mx = value;
    			res = x;
    		}
    	}
    	return res;
    }
    
    void init() {
    	int tmp;
    	for (int i = 0; i < 19683; i++) {
    		bool x1 = vectory(i), x2 = lose(i);
    		tmp = 9 - dep(i);
    		if(x1 || x2 || (tmp == 0))
    			ed[i] = true;
    		else {
    			Next[i] = find_next(i);
    		}
    	}
    }
    
    void print_table() {
    	printf("请落子(比如 0 0):
    ");
    	printf("----------
    ");
    	for (int i = 0; i < 3; i++) {
    		for (int j = 0; j < 3; j++) {
    			if(table[i][j] == 0) printf(" ");
    			else printf("%c", table[i][j]);
    			if(j < 2) printf("-"); 
    		}
    		if(i < 2) {
    			printf("
    ");
    			for (int j = 0; j < 5; j++) {
    				if(j % 2 == 0) printf("|");
    				else printf(" ");
    			}
    		}
    		
    		printf("
    ");
    	}
    	printf("----------
    ");
    }
    
    void play() {
    	int s = 0;
    	int e = 0, l = 0, a = 0;
    	bool flag;
    	int round = 0;
    	int T = 10;
    	while(T--) {
    		round++;
    		printf("第%d回合:
    ", round);
    //		printf("----------
    
    
    
    
    round %d
    
    
    
    --------
    ", round);
    		memset(table, 0, sizeof(table));
    		int p = 0;
    		printf("请决定执黑还是执白:
    0: 黑棋; 1: 白棋
    ");
    		scanf("%d", &p);
    //		print_table();
    		s = 0;
    		flag = 0;
    		print_table();
    		while(!ed[s]) {
    			int x, y;
    			if(p == 0) {
    				scanf("%d %d", &x, &y);
    				table[x][y] = 'b';
    				s = rbuild();
    			} else {
    				s = solve(s, flag);
    				build(s);
    			}
    			print_table();
    			if(ed[s]) {
    				if(vectory(s)) {
    					printf("黑方胜利!
    ");
    					a++;
    				}
    				else {
    					printf("平局!
    ");
    					e++;
    				}	
    				break;
    			}
    			flag ^= 1;
    			if(p) {
    				scanf("%d %d", &x, &y);
    				table[x][y] = 'w';
    				s = rbuild();
    			} else {
    				s = solve(s, flag);
    				build(s);
    			}
    			print_table();
    			if(ed[s]) {
    				if(lose(s)) {
    					printf("白方胜利!
    ");
    					l++;
    				}
    				else {
    					printf("平局!
    ");
    					e++;
    				}
    				break;
    			}
    			flag ^= 1;
    		}
    	}
    	printf("AI wins: %d
    player wins: %d
    equals: %d
    ", a, l, e);
    }
    
    // Entry point: precompute the board graph, then run the console session.
    int main() {
    	srand(time(0));	// NOTE(review): nothing in this file calls rand(); rollouts use the mt19937 engine
    	init();
    	play();
    	return 0;
    }
    

      每步进行 500 次 MCTS 迭代时,基本已经能下出最优解;相比穷举整棵博弈树的暴力搜索,MCTS 的计算量要小得多。

  • 相关阅读:
    借助magicwindow sdk plugin快速集成sdk
    Deeplink做不出效果,那是你不会玩!
    iOS/Android 浏览器(h5)及微信中唤起本地APP
    C#回顾 Ado.Net C#连接数据库进行增、删、改、查
    C# 文件操作(全部) 追加、拷贝、删除、移动文件、创建目录 修改文件名、文件夹名
    C#中的静态方法|如何调用静态方法
    SpringBoot实体类对象和json格式的转化
    SpringBoot + kaptcha 生成、校对 验证码
    SpringBoot配置自定义美化Swagger2
    Spring Boot关于layui的通用返回类
  • 原文地址:https://www.cnblogs.com/pkgunboat/p/14249574.html
Copyright © 2011-2022 走看看