aboutsummaryrefslogtreecommitdiff
path: root/src/rewrite.c
blob: e3cea018cbf1f5c5baa686759388d37a2d90a085 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <fwd/rewrite.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

/*
 * Context carried through ast_visit() by rewrite_types(): every type
 * identifier equal to `orig` is replaced with a copy of `new`.
 */
struct type_helper {
	char *orig, *new;
};

/*
 * Recursively replace every TYPE_ID identifier in @type that equals @orig
 * with a freshly strdup()ed copy of @new. The walk follows both the t0 and
 * n links of each type.
 *
 * Types are simple enough (for now) that plain recursion works fine.
 *
 * Returns 0 on success, or -1 if duplicating @new fails. On failure the
 * identifier currently being rewritten is left untouched. Rewrites already
 * performed earlier in the walk are not rolled back.
 */
static int rewrite_type_ids(struct type *type, char *orig, char *new)
{
	if (type->k == TYPE_ID && type->id && strcmp(type->id, orig) == 0) {
		char *r = strdup(new);
		/* don't free the old id until we know we have a replacement */
		if (!r)
			return -1;

		free(type->id);
		type->id = r;
	}

	if (type->t0 && rewrite_type_ids(type->t0, orig, new))
		return -1;

	if (type->n && rewrite_type_ids(type->n, orig, new))
		return -1;

	return 0;
}

/*
 * ast_visit() callback used by rewrite_types(): if @node has a type
 * (node->t2), rewrite the identifiers in it using the orig/new pair from
 * the struct type_helper passed as @data.
 *
 * Returns the result of rewrite_type_ids(), or 0 when the node has no type.
 */
static int rewrite_type_visit(struct ast *node, void *data)
{
	/*
	 * Accept the user data as void * and convert here. Calling a function
	 * through a pointer of incompatible type (e.g. one whose last parameter
	 * is struct type_helper * when the callback type expects void *) is
	 * undefined behavior.
	 * NOTE(review): presumably ast_callback_t is int (*)(struct ast *, void *);
	 * confirm in the ast header.
	 */
	struct type_helper *helper = data;

	if (!node->t2)
		return 0;

	return rewrite_type_ids(node->t2, helper->orig, helper->new);
}

int rewrite_types(struct ast *node, char *orig, char *new)
{
	struct type_helper helper = {.orig = orig, .new = new};
	return ast_visit((ast_callback_t)rewrite_type_visit, NULL, node, &helper);
}

/*
 * Build a new string in which the leading "<>" of @old is replaced by @new,
 * e.g. old = "<>foo", new = "bar" gives "barfoo".
 *
 * @old is expected to start with "<>"; only its length is checked here.
 * Not the fastest thing in the world but should work well enough for now.
 *
 * Returns a freshly malloc()ed string the caller must free(). Returns NULL
 * if @old is shorter than the two-character prefix or if allocation fails.
 */
static char *rewrite_hole(const char *old, const char *new)
{
	size_t olen = strlen(old);

	/* guard against size_t underflow in olen - 2 below */
	if (olen < 2)
		return NULL;

	/* length of the part of old that follows "<>" */
	size_t on = olen - 2;
	size_t nn = strlen(new);

	/* +1 for null terminator */
	char *r = malloc(nn + on + 1);
	if (!r)
		return NULL;

	memcpy(r, new, nn);

	/* +2 to skip "<>", on + 1 to also copy old's null terminator */
	memcpy(r + nn, old + 2, on + 1);
	return r;
}

/*
 * Recursively replace the leading "<>" of every identifier in @type with
 * @new. The walk follows both the t0 and n links of each type.
 *
 * Returns 0 on success, or -1 if building a replacement string fails. On
 * failure the identifier currently being rewritten is left untouched.
 * Rewrites already performed earlier in the walk are not rolled back.
 */
static int rewrite_type_holes(struct type *type, char *new)
{
	if (type->id && strncmp(type->id, "<>", 2) == 0) {
		char *r = rewrite_hole(type->id, new);
		/* keep the old id if no replacement could be built */
		if (!r)
			return -1;

		free(type->id);
		type->id = r;
	}

	if (type->t0 && rewrite_type_holes(type->t0, new))
		return -1;

	if (type->n && rewrite_type_holes(type->n, new))
		return -1;

	return 0;
}

/*
 * ast_visit() callback used by rewrite_holes(). It replaces the leading "<>"
 * of the node's name (node->s) and of identifiers in its type (node->t2)
 * with the string passed as @data.
 *
 * AST_CONST_STR nodes are left completely untouched, presumably because
 * their s field holds literal text rather than a name.
 *
 * Returns 0 on success, or -1 if building a replacement string fails.
 */
static int rewrite_holes_visit(struct ast *node, void *data)
{
	/*
	 * Accept the user data as void * and convert here, so this function's
	 * type matches the callback type it is cast to in rewrite_holes().
	 * NOTE(review): presumably ast_callback_t is int (*)(struct ast *, void *);
	 * confirm in the ast header.
	 */
	char *new = data;

	if (node->k == AST_CONST_STR)
		return 0;

	if (node->s && strncmp(node->s, "<>", 2) == 0) {
		char *r = rewrite_hole(node->s, new);
		/* keep the old name if no replacement could be built */
		if (!r)
			return -1;

		free(node->s);
		node->s = r;
	}

	if (node->t2 && rewrite_type_holes(node->t2, new))
		return -1;

	return 0;
}

int rewrite_holes(struct ast *node, char *new)
{
	return ast_visit((ast_callback_t)rewrite_holes_visit, NULL, node, new);
}