/
OS-Worldaa05f6c
import ast
import io
import re
import xml.etree.ElementTree as ET
from typing import List, Tuple

from PIL import Image, ImageDraw, ImageFont

from .deduplicate_node import filter_similar_nodes
# XML namespace URIs used by the accessibility-tree XML. Namespaced attributes
# are stored by ElementTree under "{uri}localname" keys, which the functions in
# this module build with "{{{:}}}name".format(<namespace>).
# NOTE(review): attributes_ns_ubuntu holds the *windows* URI, while every other
# *_ns_ubuntu constant uses accessibility.ubuntu.example.org. No attribute lookup
# in this module uses its value; confirm against the tree producer before
# relying on it anywhere.
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
# State namespace: read for "showing", "visible" and "active".
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
# Component namespace: read for "screencoord" and "size".
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
# Value namespace: read for "value" (text of edit controls).
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
# Windows class namespace: read for "class" (e.g. "...EditWrapper").
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
def find_leaf_nodes(xlm_file_str):
    """Return every element without children in an XML string, in document order.

    Args:
        xlm_file_str: serialized XML (parameter name kept for compatibility).
            Falsy input (``None`` or ``""``) yields ``[]``.

    Returns:
        List of leaf ``Element`` objects. The root itself is returned when it
        has no children.
    """
    if not xlm_file_str:
        return []
    root = ET.fromstring(xlm_file_str)
    # Element.iter() walks depth-first in document order (the same order the
    # former recursive pre-order traversal produced) but without Python
    # recursion, so very deep trees cannot hit RecursionError.
    return [node for node in root.iter() if len(node) == 0]
def judge_node(node: ET.Element, platform="Ubuntu", check_image=False) -> bool:
    """Decide whether an accessibility-tree element should be kept.

    An element is kept when all of the following hold:

    * its tag starts with ``document``, ends with one of the listed suffixes,
      or is one of the listed exact tag names;
    * it is on screen: ``visible == "true"`` in the state namespace, and on
      Ubuntu additionally ``showing == "true"``;
    * it has something to display: a non-empty ``name`` attribute, non-empty
      text, or (when ``check_image`` is truthy) ``image == "true"``;
    * its ``screencoord`` pair is non-negative and its ``size`` pair is
      strictly positive (both in the component namespace).

    Args:
        node: element to judge.
        platform: ``"Ubuntu"`` or ``"Windows"``; selects the XML namespaces.
        check_image: also accept unnamed, text-less elements flagged as images.

    Returns:
        ``True`` if the element should be kept. Elements whose geometry
        attributes cannot be parsed as an indexable pair are rejected.

    Raises:
        ValueError: if ``platform`` is not ``"Ubuntu"`` or ``"Windows"``.
    """
    if platform == "Ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
    elif platform == "Windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    tag = node.tag
    keeps: bool = (
        tag.startswith("document")
        or tag.endswith(
            (
                "item",
                "button",
                "heading",
                "label",
                "scrollbar",
                "searchbox",
                "textbox",
                "link",
                "tabelement",
                "textfield",
                "textarea",
                "menu",
            )
        )
        or tag
        in {
            "alert",
            "canvas",
            "check-box",
            "combo-box",
            "entry",
            "icon",
            "image",
            "paragraph",
            "scroll-bar",
            "section",
            "slider",
            "static",
            "table-cell",
            "terminal",
            "text",
            "netuiribbontab",
            "start",
            "trayclockwclass",
            "traydummysearchcontrol",
            "uiimage",
            "uiproperty",
            "uiribboncommandbar",
        }
    )
    if not keeps:
        return False

    visible = node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
    if platform == "Ubuntu":
        # Ubuntu trees additionally require the "showing" state.
        on_screen = visible and node.get("{{{:}}}showing".format(_state_ns), "false") == "true"
    else:
        on_screen = visible
    if not on_screen:
        return False

    has_content = (
        node.get("name", "") != ""
        or bool(node.text)
        or (check_image and node.get("image", "false") == "true")
    )
    if not has_content:
        return False

    # Geometry is parsed with ast.literal_eval rather than eval(): the attribute
    # text comes from the XML input, and eval() would execute any expression
    # embedded there. Parsing happens only for otherwise-kept elements, so a
    # malformed attribute on an irrelevant element no longer raises.
    try:
        coordinates = ast.literal_eval(node.get("{{{:}}}screencoord".format(_component_ns), "(-1, -1)"))
        sizes = ast.literal_eval(node.get("{{{:}}}size".format(_component_ns), "(-1, -1)"))
        return coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
    except (ValueError, SyntaxError, TypeError, IndexError):
        return False
def filter_nodes(root: ET, platform="Ubuntu", check_image=False):
    """Return the elements under ``root`` (root included) accepted by ``judge_node``.

    Elements are visited in ``root.iter()`` order; ``platform`` and
    ``check_image`` are forwarded to ``judge_node`` unchanged.
    """
    return [element for element in root.iter() if judge_node(element, platform, check_image)]
def _parse_int_pair(value):
    """Parse a ``"(a, b)"`` attribute string into a tuple of two ints.

    Surrounding parentheses are stripped and the rest is split on ``", "``.

    Raises:
        ValueError: if the string does not hold exactly two integers.
    """
    parts = value.strip("()").split(", ")
    if len(parts) != 2:
        raise ValueError(f"Expected two integers, got: {value!r}")
    return int(parts[0]), int(parts[1])


def _node_text_cell(node, value_ns):
    """Return the ``text`` column for ``node`` in the bounding-box text table.

    Uses the element text when present. Otherwise, for elements whose Windows
    ``class`` attribute ends with ``EditWrapper``, uses the non-empty ``value``
    attribute from ``value_ns``. Text containing ``"`` is wrapped in quotes
    with inner quotes doubled. Elements with neither yield ``'""'``.
    """
    if node.text:
        text = node.text
    else:
        value = node.get("{{{:}}}value".format(value_ns))
        is_edit_wrapper = node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper")
        if not (is_edit_wrapper and value):
            return '""'
        text = value
    return text if '"' not in text else '"{:}"'.format(text.replace('"', '""'))


def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0, platform="Ubuntu"):
    """Draw numbered bounding boxes for accessibility-tree ``nodes`` on a screenshot.

    Args:
        nodes: elements carrying ``screencoord``/``size`` attributes in the
            platform's component namespace. Nodes without both, with
            unparsable or non-positive geometry, or whose region is a single
            colour in the screenshot are skipped.
        image_file_content: encoded screenshot bytes readable by PIL.
        down_sampling_ratio: factor applied to the image size and to the box
            geometry before drawing.
        platform: ``"Ubuntu"`` or ``"Windows"``; selects the XML namespaces.

    Returns:
        ``(marks, drew_nodes, text_table, png_bytes)``:

        * ``marks``: ``[x, y, w, h]`` per drawn node, in the ORIGINAL
          (un-scaled) coordinates;
        * ``drew_nodes``: the drawn elements, in drawing order;
        * ``text_table``: tab-separated ``index/tag/name/text`` rows under a
          header row;
        * ``png_bytes``: the annotated image encoded as PNG.

    Raises:
        ValueError: if ``platform`` is not ``"Ubuntu"`` or ``"Windows"``.
    """
    if platform == "Ubuntu":
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "Windows":
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    needs_scaling = float(down_sampling_ratio) != 1.0
    image = Image.open(io.BytesIO(image_file_content))
    if needs_scaling:
        image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
    # Un-annotated snapshot: the single-colour test below must inspect the
    # screen content, not red outlines / black labels already drawn on
    # ``image`` for earlier nodes (which made the result order-dependent).
    pristine = image.copy()
    draw = ImageDraw.Draw(image)

    try:
        # Adjust the path to the font file you have or use a default one
        font = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        # Fall back to a basic font if the specified font can't be loaded
        font = ImageFont.load_default()

    coord_key = "{{{:}}}screencoord".format(_component_ns)
    size_key = "{{{:}}}size".format(_component_ns)

    marks = []
    drew_nodes = []
    text_informations: List[str] = ["index\ttag\tname\ttext"]
    index = 1
    for _node in nodes:
        coords_str = _node.attrib.get(coord_key)
        size_str = _node.attrib.get(size_key)
        if not (coords_str and size_str):
            continue
        try:
            original_coords = _parse_int_pair(coords_str)
            original_size = _parse_int_pair(size_str)
            if needs_scaling:
                coords = tuple(int(c * down_sampling_ratio) for c in original_coords)
                size = tuple(int(s * down_sampling_ratio) for s in original_size)
            else:
                coords, size = original_coords, original_size
            if size[0] <= 0 or size[1] <= 0:
                raise ValueError(f"Size must be positive, got: {size}")
            bottom_right = (coords[0] + size[0], coords[1] + size[1])

            # Skip regions that are a single colour in the screenshot itself.
            if len(set(pristine.crop((*coords, *bottom_right)).getdata())) == 1:
                continue

            draw.rectangle([coords, bottom_right], outline="red", width=1)
            # Index label anchored left-bottom ("lb") at the box's bottom-left
            # corner, drawn white on a black background.
            text_position = (coords[0], bottom_right[1])
            text_bbox: Tuple[int, int, int, int] = draw.textbbox(text_position, str(index), font=font, anchor="lb")
            draw.rectangle(text_bbox, fill="black")
            draw.text(text_position, str(index), font=font, anchor="lb", fill="white")

            # Each mark is an (x, y, w, h) box in un-scaled coordinates.
            marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
            drew_nodes.append(_node)
            text_informations.append(
                "{:d}\t{:}\t{:}\t{:}".format(
                    index, _node.tag, _node.get("name", ""), _node_text_cell(_node, _value_ns)
                )
            )
            index += 1
        except ValueError:
            # Best effort: unparsable or degenerate geometry (or a font that
            # cannot render the label) just skips this node.
            continue

    output_image_stream = io.BytesIO()
    image.save(output_image_stream, format="PNG")
    return marks, drew_nodes, "\n".join(text_informations), output_image_stream.getvalue()
def print_nodes_with_indent(nodes, indent=0):
    """Print each element's tag and attributes, then its children two spaces deeper.

    ``nodes`` is any iterable of elements; an ``Element`` iterates over its own
    children, which is what drives the recursion. Output goes to stdout.
    """
    padding = " " * indent
    for element in nodes:
        print(padding, element.tag, element.attrib)
        print_nodes_with_indent(element, indent=indent + 2)
def find_active_applications(tree, state_ns):
    """Return the application names to keep when pruning the accessibility tree.

    Each direct child of the tree's root is an application; its ``name`` is
    collected once per child frame whose ``{state_ns}active`` attribute is
    ``"true"`` (so duplicates are possible). If any were found, ``"gnome-shell"``
    is appended; otherwise the fallback ``["gjs", "gnome-shell"]`` is returned.
    """
    active_attr = "{{{:}}}active".format(state_ns)
    active_names = [
        app.attrib.get("name")
        for app in tree.getroot()
        for frame in app
        if frame.attrib.get(active_attr, "false") == "true"
    ]
    if not active_names:
        return ["gjs", "gnome-shell"]
    return active_names + ["gnome-shell"]
def linearize_accessibility_tree(accessibility_tree, platform="Ubuntu"):
    """Flatten an accessibility-tree XML string into a tab-separated table.

    Applications not returned by ``find_active_applications`` are removed from
    the root, the rest is filtered with ``filter_nodes(..., check_image=True)``,
    and each surviving element becomes one row
    ``tag<TAB>text<TAB>(center_x, center_y)<TAB>size`` below a header row.
    The table is then passed through ``filter_similar_nodes``.

    The text column is the stripped element text, or the ``name`` attribute
    when the text is empty, or ``"name (text)"`` when both differ. Newlines and
    tabs are escaped so every element stays on one row with aligned columns.
    Elements whose geometry does not match ``(int, int)`` are skipped.

    Args:
        accessibility_tree: serialized accessibility-tree XML.
        platform: ``"Ubuntu"`` or ``"Windows"``; selects the XML namespaces.

    Returns:
        The result of ``filter_similar_nodes`` on the table, or ``""`` if any
        step raises (the error is printed).

    Raises:
        ValueError: if ``platform`` is not ``"Ubuntu"`` or ``"Windows"``.
    """
    if platform == "Ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
    elif platform == "Windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'Ubuntu' or 'Windows'")

    # Raw string: "\(" and "\d" are invalid escapes in a normal (f-)string
    # literal and emit DeprecationWarning / SyntaxWarning.
    pair_pattern = re.compile(r"\((\d+), (\d+)\)")
    coord_key = "{{{:}}}screencoord".format(_component_ns)
    size_key = "{{{:}}}size".format(_component_ns)

    try:
        tree = ET.ElementTree(ET.fromstring(accessibility_tree))
        keep_apps = find_active_applications(tree, _state_ns)
        root = tree.getroot()
        # Remove inactive applications; iterate over a snapshot while removing.
        for application in list(root):
            if application.get("name") not in keep_apps:
                root.remove(application)

        filtered_nodes = filter_nodes(root, platform, check_image=True)
        rows = ["tag\ttext\tposition (center x & y)\tsize (w & h)"]
        for node in filtered_nodes:
            text = (node.text or "").strip()
            name = node.get("name", "").strip()
            if text == "":
                text = name
            elif name != "" and text != name:
                text = f"{name} ({text})"
            # Keep each element on one row and the tab-separated columns aligned.
            text = text.replace("\n", "\\n").replace("\t", "\\t")

            size = node.get(size_key, "")
            pos_match = pair_pattern.match(node.get(coord_key, ""))
            size_match = pair_pattern.match(size)
            if pos_match is None or size_match is None:
                # Geometry the pattern cannot read (e.g. negative values) is skipped.
                continue
            x, y = (int(v) for v in pos_match.groups())
            w, h = (int(v) for v in size_match.groups())
            rows.append("{:}\t{:}\t{:}\t{:}".format(node.tag, text, f"({x + w // 2}, {y + h // 2})", size))

        # Filter out similar nodes
        linearized_accessibility_tree = filter_similar_nodes("\n".join(rows))
    except Exception as e:
        print(f"Error in linearize_accessibility_tree: {e}")
        linearized_accessibility_tree = ""
    return linearized_accessibility_tree
def trim_accessibility_tree(linearized_accessibility_tree, max_items):
    """Cap a linearized tree at ``max_items`` lines, marking the cut with ``"..."``.

    Lines are counted after stripping surrounding whitespace (any header row
    counts as a line). When no trimming is needed the input string is returned
    unchanged; otherwise the first ``max_items`` lines are joined and a final
    ``"\\n..."`` is appended.
    """
    rows = linearized_accessibility_tree.strip().split("\n")
    if len(rows) <= max_items:
        return linearized_accessibility_tree
    return "\n".join(rows[:max_items]) + "\n..."